def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets,
                       hparams, name, task=None):
  """Original Transformer decoder.

  Dispatches between a text ("translate") decoder and an image decoder.

  Args:
    encoder_output: encoder representation fed to encoder-decoder attention.
    encoder_decoder_attention_bias: bias/mask for encoder-decoder attention.
    targets: target tensor; flattened to 3-D for "translate", reshaped to
      [b, img_len, img_len, channels * hidden] for "image".
    hparams: model hyperparameters.
    name: variable-scope name for the decoder.
    task: "translate" or "image"; defaults to hparams.task when None.

  Returns:
    4-D decoder output tensor [batch, length, 1, hidden_size].
  """
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)
      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))
      # Input dropout before the decoder stack.
      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # Have to reshape targets as [b, img_len, img_len, channels * hidden]
      # because otherwise prepare_decoder will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])
      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      # NOTE(review): `inputs` is None at this point, so this reshape would
      # raise whenever hparams.drop_inputs is False — confirm where `inputs`
      # (the class-label embedding?) is supposed to come from.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      # Decoder-only (no encoder attention) image transformer stack.
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1,
                                                   1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
def decode_transformer(encoder_output, encoder_decoder_attention_bias, targets,
                       hparams, name):
  """Run the standard Transformer decoder over `targets`.

  Args:
    encoder_output: encoder representation for encoder-decoder attention.
    encoder_decoder_attention_bias: bias/mask for encoder-decoder attention.
    targets: 4-D target tensor; flattened to 3-D before decoding.
    hparams: model hyperparameters.
    name: variable-scope name for the decoder.

  Returns:
    4-D decoder output tensor [batch, length, 1, hidden_size].
  """
  with tf.variable_scope(name):
    flat_targets = common_layers.flatten4d3d(targets)
    dec_in, self_bias = transformer.transformer_prepare_decoder(
        flat_targets, hparams)
    keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
    dec_in = tf.nn.dropout(dec_in, keep_prob)
    out = transformer.transformer_decoder(
        dec_in, encoder_output, self_bias, encoder_decoder_attention_bias,
        hparams)
    out = tf.expand_dims(out, axis=2)
    # Downstream t2t code expects 4d tensors: [batch, length, 1, hidden].
    batch = common_layers.shape_list(out)[0]
    return tf.reshape(out, [batch, -1, 1, hparams.hidden_size])
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None,
                       causal=True):
  """Original Transformer decoder.

  Dispatches between a text ("translate") decoder and an image decoder.

  Args:
    encoder_output: encoder representation fed to encoder-decoder attention.
    encoder_decoder_attention_bias: bias/mask for encoder-decoder attention.
    targets: target tensor; flattened to 3-D for "translate", reshaped to
      [b, img_len, img_len, channels * hidden] for "image".
    hparams: model hyperparameters.
    name: variable-scope name for the decoder.
    task: "translate" or "image"; defaults to hparams.task when None.
    causal: if False, the causal self-attention mask is zeroed out so the
      decoder may attend to all positions (translate task only).

  Returns:
    4-D decoder output tensor [batch, length, 1, hidden_size].
  """
  # Fix: removed the dead `orig_hparams` save/restore — `hparams` was never
  # reassigned in this scope, so the round-trip had no effect.
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)
      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))
      decoder_input = tf.nn.dropout(
          decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
      if not causal:
        # Zeroing the bias removes the causal mask -> full self-attention.
        decoder_self_bias *= 0.
      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # Have to reshape targets as [b, img_len, img_len, channels * hidden]
      # because otherwise prepare_decoder will choke.
      targets = tf.reshape(targets, [
          tf.shape(targets)[0], hparams.img_len, hparams.img_len,
          hparams.num_channels * hparams.hidden_size
      ])
      # Prepare decoder inputs and bias.
      # TODO(nikip): Make prepare_decoder return bias
      decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
      bias = None
      # Add class label to decoder input.
      # NOTE(review): `inputs` is None here, so this reshape would raise when
      # hparams.drop_inputs is False — confirm the intended source of `inputs`.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(inputs, [
            common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size
        ])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          encoder_output=None,
          num_layers=hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams=hparams,
          self_attention_bias=bias,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(
          decoder_output,
          [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
def decode(self,
           decoder_input,
           encoder_output,
           encoder_decoder_attention_bias,
           decoder_self_attention_bias,
           hparams,
           cache=None,
           nonpadding=None):
  """Decode Transformer outputs from encoder representation.

  Args:
    decoder_input: inputs to bottom of the model.
        [batch_size, decoder_length, hidden_dim]
    encoder_output: Encoder representation.
        [batch_size, input_length, hidden_dim]
    encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder
        attention. [batch_size, input_length]
    decoder_self_attention_bias: Bias and mask weights for decoder
        self-attention. [batch_size, decoder_length]
    hparams: hyperparmeters for model.
    cache: dict, containing tensors which are the results of previous
        attentions, used for fast decoding.
    nonpadding: optional Tensor with shape [batch_size, decoder_length]

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
    (expanded to 4-D off-TPU; see below).
  """
  keep_prob = 1.0 - hparams.layer_prepostprocess_dropout
  x = tf.nn.dropout(decoder_input, keep_prob)
  x = transformer_decoder(
      x,
      encoder_output,
      decoder_self_attention_bias,
      encoder_decoder_attention_bias,
      hparams,
      cache=cache,
      nonpadding=nonpadding,
      save_weights_to=self.attention_weights)
  on_tpu_train = (common_layers.is_on_tpu() and
                  hparams.mode == tf.estimator.ModeKeys.TRAIN)
  if on_tpu_train:
    # TPU does not react kindly to extra dimensions.
    # TODO(noam): remove this once TPU is more forgiving of extra dims.
    return x
  # Mix the decoder output with a sentence-cache memory lookup.
  flat = tf.reshape(x, [hparams.batch_size, -1, hparams.hidden_size])
  memory = self.sentence_cache.Query(flat)
  mix = self.calculate_mixing_weight(flat, memory)
  memory = tf.reshape(memory, tf.shape(x))
  mix = tf.reshape(mix, (tf.shape(x)[0], -1, hparams.hidden_size))
  # Expand since t2t expects 4d tensors.
  if self.hparams.use_cache:
    return tf.expand_dims(mix * x + (1.0 - mix) * memory, axis=2)
  return tf.expand_dims(x, axis=2)
def encode_decode_task(features, hparams, train, attention_weights=None):
  """Model core graph for the one-shot action.

  Encodes the instruction tokens, builds span (area) representations over the
  encoder output, and decodes action references against them.

  Args:
    features: a dictionary contains "inputs" that is a tensor in shape of
        [batch_size, num_tokens], "verb_id_seq" that is in shape of
        [batch_size, num_actions], "object_spans" and "param_span" tensor
        in shape of [batch_size, num_actions, 2]. 0 is used as padding or
        non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode (unused).
    attention_weights: the dict to keep attention weights for analysis.

  Returns:
    decoder_output: the decoder representation.
    decoder_nonpadding: float mask of non-padding decoder positions.
    areas: dict with the area "encodings", "starts", "ends" and "bias".
    scope: the embedding scope.

  Raises:
    ValueError: on an unsupported instruction encoder/decoder or span_rep.
  """
  del train
  input_embeddings, scope = common_embed.embed_tokens(
      features["task"], hparams.task_vocab_size, hparams.hidden_size, hparams)
  with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
    # Zero out embeddings of padding tokens.
    encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
    input_embeddings = tf.multiply(
        tf.expand_dims(encoder_nonpadding, 2), input_embeddings)
    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            input_embeddings, None, hparams, features=None))
    encoder_input = tf.nn.dropout(
        encoder_input, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_encoder == "transformer":
      encoder_output = transformer.transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          save_weights_to=attention_weights,
          make_image_summary=not common_layers.is_xla_compiled())
    else:
      raise ValueError("Unsupported instruction encoder %s" % (
          hparams.instruction_encoder))
    span_rep = hparams.get("span_rep", "area")
    # Sum-image encodings also provide the span start/end indices used below.
    area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
        encoder_output, max_area_width=hparams.max_span)
    current_shape = tf.shape(area_encodings)
    # Fix: the "area" branch previously recomputed compute_sum_image with
    # identical arguments; the result above is already the "area" encoding.
    if span_rep == "basic":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output,
          input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=False)
    elif span_rep == "coref":
      area_encodings = area_utils.compute_alternative_span_rep(
          encoder_output,
          input_embeddings,
          max_area_width=hparams.max_span,
          hidden_size=hparams.hidden_size,
          advanced=True)
    elif span_rep != "area":
      # Fix: replaced the placeholder ValueError("xyz") message.
      raise ValueError("Unsupported span_rep %s" % (span_rep))
    areas = {}
    areas["encodings"] = area_encodings
    areas["starts"] = area_starts
    areas["ends"] = area_ends
    # Debug-time checks: print the encoder shape and assert the span
    # representation kept the [batch, num_areas, ...] layout.
    with tf.control_dependencies([
        tf.print("encoder_output", tf.shape(encoder_output)),
        tf.assert_equal(current_shape, tf.shape(area_encodings), summarize=100)
    ]):
      paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
      padding_sum, _, _ = area_utils.compute_sum_image(
          tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
          max_area_width=hparams.max_span)
    num_areas = common_layers.shape_list(area_encodings)[1]
    # A span that touches any padding token is masked out.
    area_paddings = tf.reshape(
        tf.minimum(tf.to_float(padding_sum), 1.0), [-1, num_areas])
    areas["bias"] = area_paddings
    decoder_nonpadding = tf.to_float(
        tf.greater(features["verb_refs"][:, :, 1],
                   features["verb_refs"][:, :, 0]))
    if hparams.instruction_encoder == "lstm":
      # LSTM encoders provide no positional signal for the decoder to reuse.
      hparams_decoder = copy.copy(hparams)
      hparams_decoder.set_hparam("pos", "none")
    else:
      hparams_decoder = hparams
    decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
        area_encodings, decoder_nonpadding, features, hparams_decoder,
        embed_scope=scope)
    decoder_input = tf.nn.dropout(
        decoder_input, keep_prob=1.0 - hparams.layer_prepostprocess_dropout)
    if hparams.instruction_decoder == "transformer":
      decoder_output = transformer.transformer_decoder(
          decoder_input=decoder_input,
          encoder_output=encoder_output,
          decoder_self_attention_bias=decoder_self_attention_bias,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
          hparams=hparams_decoder)
    else:
      raise ValueError("Unsupported instruction encoder %s" % (
          hparams.instruction_encoder))
    return decoder_output, decoder_nonpadding, areas, scope
def decode_srcs_to_trgs(self, trg_emb, trg_input_ids=None, outputs=None):
  """Decode target embeddings against the cached source representations.

  Dispatches on self.flags.model_mode ('gpt2' / 't2t' / 'bert') and reads
  source-side tensors from self.shared_tensors. `outputs` is unused here.

  Args:
    trg_emb: target embeddings; timing signal is added before decoding.
    trg_input_ids: target token ids (used by the BERT path only).
    outputs: unused.

  Returns:
    The decoder output tensor of the selected backend.

  Raises:
    ValueError: if model_mode matches none of the known backends.
  """
  trg_emb = common_attention.add_timing_signal_1d(trg_emb)
  trg_emb_fn = None
  control_flatten_outputs = None
  control_flatten_bias = None
  # Control-vector conditioning: either fold the control vector into the
  # target embeddings, or (in "flatten" mode) pass it as an extra attention
  # memory to the t2t decoder below.
  if 'control_vec' in self.shared_tensors and self.flags.control_mode:
    if "flatten" not in self.flags.control_mode:
      if 'bert' in self.flags.model_mode:
        # In BERT, update trg emb inside bert via a callback.
        trg_emb_fn = lambda trg_emb: self.update_embedding(
            trg_emb, )
      else:
        trg_emb = self.update_embedding(trg_emb)
    else:
      control_flatten_outputs = self.shared_tensors['control_vec']
      # Zero bias: attend to the control vector without masking.
      control_flatten_bias = tf.zeros([1, 1, 1, 1])
  control_outputs, control_bias = None, None
  if 'control_outputs' in self.shared_tensors:
    control_outputs = self.shared_tensors['control_outputs']
    control_bias = self.shared_tensors['control_bias']
  trg_length = tf.shape(trg_emb)[1]
  if 'gpt2' in self.flags.model_mode:
    trg_outputs = model.gpt2_decoder(
        self.hparams,
        trg_emb,
        encoder_outputs=self.shared_tensors['src_outputs'],
        encoder_bias=self.shared_tensors['src_bias'])
  elif 't2t' in self.flags.model_mode:
    trg_self_attention_bias = common_attention.attention_bias_lower_triangle(
        trg_length)
    # NOTE(review): the external_output*/external_bias* kwargs are extensions
    # of the project's transformer_decoder — confirm against its signature.
    trg_outputs = transformer.transformer_decoder(
        decoder_input=trg_emb,
        decoder_self_attention_bias=trg_self_attention_bias,
        encoder_output=self.shared_tensors['src_outputs'],
        encoder_decoder_attention_bias=self.shared_tensors['src_bias'],
        hparams=self.hparams,
        external_output=control_outputs,
        external_bias=control_bias,
        external_output2=control_flatten_outputs,
        external_bias2=control_flatten_bias,
        external_output3=self.shared_tensors['template_simp_outputs'],
        external_bias3=self.shared_tensors['template_simp_bias'],
        name='trg_decoder')
  elif 'bert' in self.flags.model_mode:
    trg_mask = common_attention.attention_bias_bert(trg_length, -1, 0)
    # NOTE(review): `encoder_outpus` (sic) is the keyword the project's
    # BertModel actually declares — do not "fix" the spelling here.
    bert_model = BertModel(
        config=BertConfig.from_json_file(self.flags.bert_config_file),
        is_training=self.is_training,
        input_ids=trg_input_ids,
        input_mask=trg_mask,
        embeddings=self.shared_tensors['word_embedding_table'],
        encoder_ids=self.shared_tensors['src_ids'],
        encoder_outpus=self.shared_tensors['src_outputs'],
        encoder_mask=1.0 - self.shared_tensors['src_mask'],
        trg_emb_fn=trg_emb_fn)
    trg_outputs = bert_model.get_sequence_output()
  else:
    raise ValueError('model_mode not known')
  return trg_outputs
def vae_transformer_internal(inputs, targets, target_space, hparams):
  """VAE Transformer, main step used for training.

  Encodes `inputs`, compresses `targets` into a latent z with a VAE, and
  randomly mixes VAE-derived and autoregressive decoder inputs during
  training (scheduled by global step at inference).

  Args:
    inputs: 4-D input tensor; flattened to 3-D internally.
    targets: 4-D target tensor; padded to a length divisible by
      2**hparams.num_compress_steps.
    target_space: target space id for encoder preparation.
    hparams: model hyperparameters.

  Returns:
    A (output, kl_loss) pair.
  """
  with tf.variable_scope("vae_transformer"):
    is_training = hparams.mode == tf.contrib.learn.ModeKeys.TRAIN
    # Prepare inputs, targets, and k.
    inputs = common_layers.flatten4d3d(inputs)
    targets = common_layers.flatten4d3d(targets)
    k = 2**hparams.num_compress_steps
    _, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=k)

    # Transformer preparations and encoder.
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = transformer.transformer_prepare_encoder(
         inputs, target_space, hparams)
    residual_fn = transformer.get_residual_fn(hparams)
    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.residual_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, residual_fn, encoder_self_attention_bias, hparams)

    def get_decoder_autoregressive():
      """Decoder input for autoregressive computation."""
      (a, b) = transformer.transformer_prepare_decoder(targets, hparams)
      return (a, b, tf.constant(0.0))

    # 10% of the time we compress all-zeros, as will be at decoding start.
    prob_targets = 0.9 if is_training else 1.0
    to_compress = tf.cond(tf.less(tf.random_uniform([]), prob_targets),
                          lambda: targets, lambda: tf.zeros_like(targets))
    z, kl_loss = compress_vae(to_compress, hparams, "vae")
    # Decompress the latent back to target length.
    # Fix: `xrange` is Python-2-only; `range` behaves identically here.
    for i in range(hparams.num_compress_steps):
      j = hparams.num_hidden_layers - i - 1
      z = decompress(z, hparams, "decompress_%d" % j)

    def get_decoder_from_vae():
      """Decoder input computed by VAE."""
      # Return decoder stuff.
      (a, b) = transformer.transformer_prepare_decoder(
          tf.squeeze(z, axis=2), hparams)
      return (a, b, kl_loss)

    # Randomize decoder inputs between VAE and autoregressive paths.
    prob_do_vae = common_layers.inverse_exp_decay(40000) * 0.7
    step = tf.to_float(tf.contrib.framework.get_global_step())
    if not is_training:
      # At inference: autoregressive before step 40k, VAE afterwards.
      prob_do_vae = tf.cond(tf.less(step, 40000.0),
                            lambda: tf.constant(0.0),
                            lambda: tf.constant(1.0))
    (decoder_input, decoder_self_attention_bias, kl_loss2) = tf.cond(
        tf.less(tf.random_uniform([]), prob_do_vae),
        get_decoder_from_vae, get_decoder_autoregressive)

    # Transformer decoder.
    decoder_output = transformer.transformer_decoder(
        decoder_input, encoder_output, residual_fn,
        decoder_self_attention_bias, encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, 2)

    # Occasionally return the raw latent path instead of the decoder output.
    cond_self = tf.cond(tf.less(step, 30000.0),
                        lambda: tf.constant(1.0),
                        lambda: tf.constant(0.0))
    prob_self = 0.4 if is_training else cond_self
    (ret, kl_loss) = tf.cond(tf.less(tf.random_uniform([]), prob_self),
                             lambda: (z, kl_loss),
                             lambda: (decoder_output, kl_loss2))
    # Anneal the KL term in over the first 50k steps.
    kl_loss *= common_layers.inverse_exp_decay(50000) * 2.0
    return ret, kl_loss