def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model with bidirectional encoder."""
  with tf.variable_scope("lstm_seq2seq_bid_encoder"):
    if inputs is not None:
      inputs_length = common_layers.length_from_embedding(inputs)
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm_bid_encoder(
          inputs, inputs_length, hparams, train, "encoder")
    else:
      inputs_length = None
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        targets_length,
        hparams_decoder,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    if inputs is not None:
      inputs_length = common_layers.length_from_embedding(inputs)
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
      _, final_encoder_state = lstm(inputs, inputs_length, hparams, train,
                                    "encoder")
    else:
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        targets_length,
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

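# Hedged illustration (not part of the original file; relies on this module's
# tf import): why targets_length above is length_from_embedding + 1.
# shift_right pads one all-zero step on the left; an embedding-based length
# count misses that step, yet the decoder must still consume it to emit the
# first token. All names below are invented for this sketch.
def _shift_right_length_sketch():
  # One sequence of true length 3, padded to 5: [batch, length, 1, hidden].
  toy_targets = tf.concat(
      [tf.ones([1, 3, 1, 4]), tf.zeros([1, 2, 1, 4])], axis=1)
  shifted = tf.pad(toy_targets,
                   [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :-1, :, :]
  nonzero = tf.to_float(tf.reduce_any(tf.not_equal(shifted, 0.0), axis=[2, 3]))
  # Counts 3 non-zero steps, but the decoder must run 4 steps, hence the +1.
  return tf.reduce_sum(nonzero, axis=1) + 1
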
def body(self, features):
  hparams = self._hparams
  targets = features["targets"]
  inputs = features["inputs"]
  target_space = features["target_space_id"]

  inputs = common_layers.flatten4d3d(inputs)
  targets = common_layers.flatten4d3d(targets)

  (encoder_input, encoder_self_attention_bias,
   encoder_decoder_attention_bias) = (transformer.transformer_prepare_encoder(
       inputs, target_space, hparams))
  (decoder_input,
   decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
       targets, hparams)

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  decoder_input = tf.nn.dropout(decoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer_revnet_encoder(
      encoder_input, encoder_self_attention_bias, hparams)

  decoder_output = transformer_revnet_decoder(
      decoder_input, encoder_output, decoder_self_attention_bias,
      encoder_decoder_attention_bias, hparams)
  decoder_output = tf.expand_dims(decoder_output, 2)

  return decoder_output

def body(self, features):
  hp = self.hparams
  # pylint: disable=eval-used
  if hp.image_input_type == "image":
    image_feat = vqa_layers.image_embedding(
        features["inputs"],
        model_fn=eval(hp.image_model_fn),
        trainable=hp.train_resnet,
        is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
  else:
    image_feat = features["inputs"]

  image_feat = common_layers.flatten4d3d(image_feat)
  image_feat = common_layers.dense(image_feat, hp.hidden_size)
  utils.collect_named_outputs("norms", "image_feat_after_proj",
                              tf.norm(image_feat, axis=-1))

  question = common_layers.flatten4d3d(features["question"])
  utils.collect_named_outputs("norms", "question_embedding",
                              tf.norm(question, axis=-1))
  (encoder_input, encoder_self_attention_bias,
   encoder_decoder_attention_bias) = prepare_image_question_encoder(
       image_feat, question, hp)

  encoder_input = tf.nn.dropout(
      encoder_input, keep_prob=1. - hp.layer_prepostprocess_dropout)

  encoder_output, _ = recurrent_transformer_decoder(
      encoder_input,
      None,
      encoder_self_attention_bias,
      None,
      hp,
      name="encoder")
  utils.collect_named_outputs("norms", "encoder_output",
                              tf.norm(encoder_output, axis=-1))

  # Scale query by sqrt(hidden_size).
  query = tf.get_variable("query", [hp.hidden_size]) * hp.hidden_size**0.5
  query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
  batch_size = common_layers.shape_list(encoder_input)[0]
  query = tf.tile(query, [batch_size, 1, 1])
  query = tf.nn.dropout(
      query, keep_prob=1. - hp.layer_prepostprocess_dropout)

  decoder_output, _ = recurrent_transformer_decoder(
      query,
      encoder_output,
      None,
      encoder_decoder_attention_bias,
      hp,
      name="decoder")
  utils.collect_named_outputs("norms", "decoder_output",
                              tf.norm(decoder_output, axis=-1))

  norm_tensors = utils.convert_collection_to_dict("norms")
  vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

  # Expand dimension 1.
  return tf.expand_dims(decoder_output, axis=1)

def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs)
    return tf.expand_dims(decoder_outputs, axis=2)

def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training."""
  with tf.variable_scope("bytenet"):
    # Flatten inputs and extend length by 50%.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    inputs_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
    inputs_shape[1] = None
    inputs.set_shape(inputs_shape)  # Don't lose the other shapes when padding.
    # Pad inputs and targets to be the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat,
                                          "SAME", "encoder", hparams)

    shifted_targets = common_layers.shift_right(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([final_encoder, shifted_targets], axis=3),
        hparams.hidden_size, [((1, 1), kernel)],
        padding="LEFT")

    return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                 "LEFT", "decoder", hparams)

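# Hedged shape sketch (illustrative only, relies on this module's tf import):
# the 50% length extension above pads the time axis so the decoder has room
# to emit outputs longer than the source before both sides are padded to a
# common length divisible by 50.
def _bytenet_extend_sketch():
  toy_inputs = tf.zeros([2, 10, 1, 8])  # [batch, length, 1, hidden].
  extend = tf.to_int32(0.5 * tf.to_float(tf.shape(toy_inputs)[1]))  # 5 steps.
  return tf.pad(toy_inputs, [[0, 0], [0, extend], [0, 0], [0, 0]])  # Len 15.
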
def slicenet_internal(inputs, targets, target_space, hparams,
                      run_decoder=True):
  """The slicenet model, main step used for training."""
  with tf.variable_scope("slicenet"):
    # Project to hidden size if necessary.
    if inputs.get_shape().as_list()[-1] != hparams.hidden_size:
      inputs = common_layers.conv_block(
          inputs,
          hparams.hidden_size, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    # Flatten inputs and encode.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    inputs = common_layers.add_timing_signal(inputs)  # Add position info.
    target_space_emb = embed_target_space(target_space, hparams.hidden_size)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", extra_layers, hparams, mask=inputs_mask)
    if not run_decoder:
      return inputs_encoded
    # Do the middle part.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)
    # Decode.
    decoder_final = multi_conv_res(
        decoder_start,
        "LEFT",
        "decoder",
        hparams.num_hidden_layers,
        hparams,
        mask=inputs_mask,
        source=inputs_encoded)
    return decoder_final, tf.reduce_mean(similarity_loss)

def encode(self, inputs, target_space, hparams, features=None):
  """Encode transformer inputs.

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  encoder_output = transformer_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)

  return encoder_output, encoder_decoder_attention_bias

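# Hedged sketch of what encoder_decoder_attention_bias carries (it mirrors
# common_attention.attention_bias_ignore_padding; the toy tensor below is
# invented here): padded source positions get a large negative bias that is
# added to attention logits, so softmax attention effectively ignores them.
def _attention_bias_sketch():
  memory_padding = tf.constant([[0., 0., 1.]])  # 1.0 marks a padded position.
  bias = tf.expand_dims(tf.expand_dims(memory_padding * -1e9, axis=1), axis=1)
  return bias  # Shape [batch, 1, 1, length].
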
def preprocess_targets(targets, i):
  """Performs preprocessing steps on the targets to prepare for the decoder.

  This includes:
    - Embedding the ids.
    - Flattening to 3D tensor.
    - Optionally adding timing signals.

  Args:
    targets: inputs ids to the decoder. [batch_size, 1]
    i: scalar, Step number of the decoding loop.

  Returns:
    Processed targets [batch_size, 1, hidden_dim]
  """
  # _shard_features called to ensure that the variable names match.
  targets = self._shard_features({"targets": targets})["targets"]
  with tf.variable_scope(target_modality.name):
    targets = target_modality.targets_bottom_sharded(targets, dp)[0]
  targets = common_layers.flatten4d3d(targets)

  # TODO(llion): Explain! Is this even needed?
  targets = tf.cond(
      tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

  if hparams.pos == "timing":
    targets += timing_signal[:, i:i + 1]
  return targets

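# Hedged sketch (my reading of the TODO above, not an authoritative answer):
# training feeds shift_right(targets), whose first step is all zeros, so
# greedy decoding zeroes the embedded input at i == 0 to stay consistent.
# Toy helper; uses this module's tf import.
def _step_zero_sketch(i, prev_target_emb):
  return tf.cond(
      tf.equal(i, 0),
      lambda: tf.zeros_like(prev_target_emb),
      lambda: prev_target_emb)
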
def transformer_text_encoder(inputs, target_space, hparams, name=None):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name, default_name="transformer_text_encoder"):
    inputs = common_layers.flatten4d3d(inputs)
    [
        encoder_input,
        encoder_self_attention_bias,
        ed,
    ] = transformer_layers.transformer_prepare_encoder(
        inputs, target_space=target_space, hparams=hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer_layers.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed

def model_fn_body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  inputs = features.get("inputs")
  encoder_output, encoder_decoder_attention_bias = (None, None)
  if inputs is not None:
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features)

  targets = features["targets"]
  targets = common_layers.flatten4d3d(targets)

  decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
      targets, hparams, features=features)

  return self.decode(decoder_input, encoder_output,
                     encoder_decoder_attention_bias,
                     decoder_self_attention_bias, hparams,
                     nonpadding=_features_to_nonpadding(features, "targets"))

def transformer_text_encoder(x,
                             space_id,
                             hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed

def body(self, features):
  hp = self.hparams
  # pylint: disable=eval-used
  if hp.image_input_type == "image":
    image_feat = vqa_layers.image_embedding(
        features["inputs"],
        model_fn=eval(hp.image_model_fn),
        trainable=hp.train_resnet,
        is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
  else:
    image_feat = features["inputs"]

  image_feat = common_layers.flatten4d3d(image_feat)
  # Image feature self attention.
  image_feat = tf.nn.dropout(image_feat, keep_prob=1. - hp.dropout)
  image_feat = image_encoder(image_feat, hp)
  utils.collect_named_outputs("norms", "image_feat_encoded",
                              tf.norm(image_feat, axis=-1))
  image_feat = common_layers.l2_norm(image_feat)
  utils.collect_named_outputs("norms", "image_feat_encoded_l2",
                              tf.norm(image_feat, axis=-1))

  query = question_encoder(features["question"], hp)
  utils.collect_named_outputs("norms", "query", tf.norm(query, axis=-1))

  image_ave = attn(image_feat, query, hp)
  utils.collect_named_outputs("norms", "image_ave",
                              tf.norm(image_ave, axis=-1))

  image_question = tf.concat([image_ave, query], axis=1)
  utils.collect_named_outputs("norms", "image_question",
                              tf.norm(image_question, axis=-1))

  image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

  output = mlp(image_question, hp)
  utils.collect_named_outputs("norms", "output", tf.norm(output, axis=-1))

  norm_tensors = utils.convert_collection_to_dict("norms")
  vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

  # Expand dimensions 1 and 2.
  return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)

def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input, encoder_output, decoder_self_bias,
          encoder_decoder_attention_bias, hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # Have to reshape targets as [b, 32, 32, 3 * hidden_size] because
      # otherwise prepare_image will choke.
      targets = tf.reshape(targets, [
          tf.shape(targets)[0], hparams.img_len, hparams.img_len,
          hparams.num_channels * hparams.hidden_size
      ])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)

      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      decoder_output = tf.reshape(
          decoder_output,
          [decoder_output_shape[0], -1, 1, hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output

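# Hedged shape sketch for the image branch of decode_transformer above (toy
# sizes, same element count on both sides): the flat target sequence
# [batch, img_len * img_len * channels, 1, hidden] is regrouped into the
# image layout [batch, img_len, img_len, channels * hidden] that
# cia.prepare_decoder expects.
def _image_reshape_sketch():
  batch, img_len, channels, hidden = 2, 4, 3, 8
  flat = tf.zeros([batch, img_len * img_len * channels, 1, hidden])
  return tf.reshape(flat, [batch, img_len, img_len, channels * hidden])
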
def body(self, features):
  if self._hparams.initializer == "orthogonal":
    raise ValueError("LSTM models fail with orthogonal initializer.")
  train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
  inputs = features.get("inputs")
  # Flatten inputs.
  inputs = common_layers.flatten4d3d(inputs)
  # LSTM encoder.
  encoder_output, _ = lstm(
      tf.reverse(inputs, axis=[1]), self._hparams, train, "encoder")
  return tf.expand_dims(encoder_output, axis=2)

def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    if inputs is not None:
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm(
          tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    else:
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

def lstm_seq2seq_internal_attention(inputs, targets, hparams, train,
                                    inputs_length, targets_length):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)

    # LSTM encoder.
    inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        inputs, inputs_length, hparams, train, "encoder")

    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = targets_length + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)

def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams,
                                                train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"):
    inputs_length = common_layers.length_from_embedding(inputs)
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm_bid_encoder(
        inputs, inputs_length, hparams, train, "encoder")
    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams_decoder, train,
        "decoder", final_encoder_state, encoder_outputs,
        inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)

def _prepare_decoder(self, targets):
  """Process the transformer decoder input."""
  targets = common_layers.flatten4d3d(targets)
  output = transformer.transformer_prepare_decoder(
      targets, self._hparams, features=None)
  deco_input, deco_self_attention_bias = output
  deco_input = tf.nn.dropout(
      deco_input, 1.0 - self._hparams.layer_prepostprocess_dropout)
  return deco_input, deco_self_attention_bias

def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model with bidirectional encoder."""
  with tf.variable_scope("lstm_seq2seq_bid_encoder"):
    if inputs is not None:
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm_bid_encoder(
          tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    else:
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        hparams_decoder,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)

def body(self, features):
  if self._hparams.initializer == "orthogonal":
    raise ValueError("LSTM models fail with orthogonal initializer.")
  train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
  inputs = features.get("inputs")
  inputs_length = common_layers.length_from_embedding(inputs)
  # Flatten inputs.
  inputs = common_layers.flatten4d3d(inputs)
  # LSTM encoder.
  inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
  encoder_output, _ = lstm(inputs, inputs_length, self._hparams, train,
                           "encoder")
  return tf.expand_dims(encoder_output, axis=2)

def question_encoder(question, hparams, name="encoder"):
  """Question encoder, run LSTM encoder and get the last output as encoding."""
  with tf.variable_scope(name, "encoder", values=[question]):
    question = common_layers.flatten4d3d(question)
    padding = common_attention.embedding_to_padding(question)
    length = common_attention.padding_to_length(padding)

    max_question_length = hparams.max_question_length
    question = question[:, :max_question_length, :]
    actual_question_length = common_layers.shape_list(question)[1]
    length = tf.minimum(length, max_question_length)
    padding = [[0, 0],
               [0, max_question_length - actual_question_length],
               [0, 0]]
    question = tf.pad(question, padding)
    question_shape = question.get_shape().as_list()
    question_shape[1] = max_question_length
    question.set_shape(question_shape)

    # Apply tanh and dropout on question embedding.
    question = tf.tanh(question)
    question = tf.nn.dropout(question, keep_prob=1. - hparams.dropout)

    question = [question[:, i, :] for i in range(max_question_length)]

    rnn_cell = _get_rnn_cell(hparams)
    _, state = tf.nn.static_rnn(
        rnn_cell, question, sequence_length=length, dtype=tf.float32)
    return state.h

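# Hedged sketch of the fixed-length question handling above (toy tensors,
# illustrative only): questions longer than max_question_length are truncated
# and shorter ones zero-padded, so the static RNN unroll always sees exactly
# max_question_length steps while sequence_length masks the padded tail.
def _pad_or_truncate_sketch():
  max_len = 5
  q = tf.ones([2, 3, 8])                      # Actual length 3 < max_len.
  q = q[:, :max_len, :]                       # No-op here; truncates if longer.
  return tf.pad(q, [[0, 0], [0, 2], [0, 0]])  # Zero-pad to [2, 5, 8].
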
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
  """Middle part of slicenet, connecting encoder and decoder."""

  def norm_fn(x, name):
    with tf.variable_scope(name, default_name="norm"):
      return common_layers.apply_norm(x, hparams.norm_type,
                                      hparams.hidden_size,
                                      hparams.norm_epsilon)

  # Flatten targets and embed target_space_id.
  targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
  target_space_emb = tf.tile(target_space_emb,
                             [tf.shape(targets_flat)[0], 1, 1, 1])

  # Calculate similarity loss (but don't run if not needed).
  if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
    targets_timed = common_layers.add_timing_signal(targets_flat)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                       extra_layers, hparams)
    with tf.variable_scope("similarity_loss"):
      similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
      similarity_loss *= hparams.sim_loss_mult
  else:
    similarity_loss = 0.0

  # Use attention from each target to look at input and retrieve.
  targets_shifted = common_layers.shift_right(
      targets_flat, pad_value=target_space_emb)
  if hparams.attention_type == "none":
    targets_with_attention = tf.zeros_like(targets_shifted)
  else:
    inputs_padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
    targets_with_attention = attention(
        targets_shifted,
        inputs_encoded,
        norm_fn,
        hparams,
        bias=inputs_padding_bias)

  # Positional targets: merge attention and raw.
  kernel = (hparams.kernel_height, hparams.kernel_width)
  targets_merged = common_layers.subseparable_conv_block(
      tf.concat([targets_with_attention, targets_shifted], axis=3),
      hparams.hidden_size, [((1, 1), kernel)],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separability=4,
      name="targets_merge")

  return targets_merged, similarity_loss

def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # This is a temporary fix for varying-length sequences within a batch.
    # A more complete fix should pass a length tensor from outside so that
    # all the lstm variants can use it.
    inputs_length = common_layers.length_from_embedding(inputs)
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        inputs, inputs_length, hparams, train, "encoder")
    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)

def targets_bottom(self, inputs):
  with tf.variable_scope(self.name):
    # Reshape inputs to 2-d tensor and embed the RGB pixel values.
    ret = common_layers.embedding(
        tf.to_int32(common_layers.flatten4d3d(inputs)),
        self.top_dimensionality,
        self._body_input_depth,
        name="input_rgb_embedding")
    if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
      ret *= self._body_input_depth**0.5

    reshape_shape = common_layers.shape_list(inputs)[:3]
    reshape_shape.append(self._body_input_depth * 3)
    ret = tf.reshape(ret, reshape_shape)
    return tf.layers.dense(ret, self._body_input_depth)

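# Hedged shape walk-through for targets_bottom above (toy sizes, names
# invented here): each of the three RGB channel ids is embedded separately,
# the per-channel embeddings are regrouped along depth, and a dense layer
# projects back down to the body input depth.
def _rgb_embedding_sketch():
  batch, height, width, depth = 2, 4, 4, 8
  pixel_ids = tf.zeros([batch, height, width, 3], dtype=tf.int32)
  table = tf.get_variable("toy_rgb_embedding", [256, depth])
  emb = tf.gather(table, pixel_ids)                # [b, h, w, 3, depth].
  emb = tf.reshape(emb, [batch, height, width, 3 * depth])
  return tf.layers.dense(emb, depth)               # [b, h, w, depth].
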
def _prepare_encoder(self, inputs, target_space):
  """Process the transformer encoder inputs."""
  inputs = common_layers.flatten4d3d(inputs)
  output = transformer.transformer_prepare_encoder(
      inputs, target_space, self._hparams, features=None)
  enco_input, enco_self_att_bias, enco_deco_att_bias = output
  enco_input = tf.nn.dropout(
      enco_input, 1.0 - self._hparams.layer_prepostprocess_dropout)
  return enco_input, enco_self_att_bias, enco_deco_att_bias

def encode(self, inputs, target_space, hparams, features=None, losses=None):
  """Encode Universal Transformer inputs.

  It is similar to "transformer.encode", but it uses
  "universal_transformer_util.universal_transformer_encoder" instead of
  "transformer.transformer_encoder".

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused.

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
        encoder_extra_output: which is extra encoder output used in some
            variants of the model (e.g. in ACT, to pass the ponder-time to
            body)
  """
  del losses

  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, encoder_extra_output) = (
      universal_transformer_util.universal_transformer_encoder(
          encoder_input,
          self_attention_bias,
          hparams,
          nonpadding=transformer.features_to_nonpadding(features, "inputs"),
          save_weights_to=self.attention_weights))

  return encoder_output, encoder_decoder_attention_bias, encoder_extra_output

def body(self, features):
  hparams = self._hparams
  inputs = features["inputs"]
  target_space = features["target_space_id"]

  inputs = common_layers.flatten4d3d(inputs)

  (encoder_input, encoder_self_attention_bias, _) = (
      transformer_prepare_encoder(inputs, target_space, hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "inputs"))

  encoder_output = tf.expand_dims(encoder_output, 2)

  return encoder_output

def body(self, features): """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "tragets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id" Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams if self.has_input: inputs = features["inputs"] target_space = features["target_space_id"] encoder_output, encoder_decoder_attention_bias = self.encode( inputs, target_space, hparams, features=features) else: encoder_output, encoder_decoder_attention_bias = (None, None) targets = features["targets"] targets = common_layers.flatten4d3d(targets) decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams, features=features) decoder_output = self.decode( decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=features_to_nonpadding(features, "targets")) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights) return decoder_output, {"attention_loss": attention_loss} return decoder_output
def encode(self, features, input_key):
  hparams = self._hparams
  inputs = common_layers.flatten4d3d(features[input_key])

  (encoder_input, encoder_self_attention_bias, _) = (
      transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                              hparams))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)
  encoder_output = transformer.transformer_encoder(
      encoder_input,
      encoder_self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, input_key))
  encoder_output = tf.reduce_mean(encoder_output, axis=1)

  return encoder_output

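# Hedged usage sketch: encode() above mean-pools over time, producing one
# vector per example, which makes the output usable for similarity scoring.
# The two stand-in tensors below play the role of two pooled encoder outputs;
# they are invented for this sketch.
def _cosine_similarity_sketch():
  emb_a = tf.random_normal([2, 8])  # Stand-in for encode(features, "inputs").
  emb_b = tf.random_normal([2, 8])  # Stand-in for a second encoded text.
  return tf.reduce_sum(
      tf.nn.l2_normalize(emb_a, axis=-1) * tf.nn.l2_normalize(emb_b, axis=-1),
      axis=-1)  # Cosine similarity per example.
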
def internal(self, features, real_features):
  """Main procedure for both training and inference."""
  inputs = common_layers.flatten4d3d(features["inputs"])
  targets = common_layers.flatten4d3d(features["targets"])
  target_space = features["target_space_id"]
  hparams = self._hparams

  inputs_mask = ops.embedding_to_non_padding(inputs)
  inputs_length = tf.reduce_sum(inputs_mask, axis=-1)
  encoder_output, encoder_decoder_attention_bias = ops.encoder(
      "encoder", hparams, inputs, target_space)
  kwargs = {
      "encoder_output": encoder_output,
      "encoder_decoder_attention_bias": encoder_decoder_attention_bias
  }

  losses, monitor = {}, {}
  log_abs_det = tf.constant(0.0)

  if not self.is_predicting:  # Training.
    targets_mask = ops.embedding_to_non_padding(targets)
    targets_length = tf.reduce_sum(targets_mask, axis=-1)
    length_diff = targets_length - inputs_length
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))

    z_q, log_q_z, q_dist = self.sample_q(
        targets, targets_mask, decoder_self_attention_bias,
        n_samples=1, temp=1.0, **kwargs)

    body_output = ops.decoder(
        "decoder", z_q, hparams, decoder_self_attention_bias, **kwargs)
    logits = self.top(body_output, real_features)
    numerator, denominator = self.loss(logits, real_features)

    if not (self.is_evaluating and
            (hparams.compute_kl_refinement or hparams.compute_iw_marginal)):
      targets_length_pred, lenpred_loss = ops.predict_target_lengths(
          encoder_output, inputs_mask, hparams, length_diff)
      log_p_z_base, log_abs_det = self.compute_prior_log_prob(
          z_q, targets_mask, decoder_self_attention_bias,
          check_invertibility=False, **kwargs)
      losses, monitor = ops.save_log_loss(
          hparams, targets_mask, numerator, denominator, log_q_z,
          log_abs_det, log_p_z_base, z_q, lenpred_loss,
          targets_length_pred, targets_length)

    if self.is_evaluating:
      if hparams.compute_kl_refinement:
        z_p, _ = self.sample_p(
            targets_length, temp=self._decode_hparams.temp,
            check_invertibility=False, targets_mask=targets_mask, **kwargs)
        z_dq = self.delta_posterior(
            z_p, targets_mask, decoder_self_attention_bias,
            self._decode_hparams.n_gibbs_steps, **kwargs)
        log_q_z_ = q_dist.log_prob(z_dq)
        log_q_z_ = gops.reduce_mean_over_bl_sum_over_c(log_q_z_, targets_mask)
        losses = {"training": log_q_z_}
      if hparams.compute_iw_marginal:
        log_p_y_x = self.compute_iw_marginal(
            targets, targets_mask, decoder_self_attention_bias, real_features,
            self._decode_hparams.n_samples, **kwargs)
        losses = {"training": log_p_y_x}

    return logits, losses, monitor, targets_mask
  else:  # Inference.
    targets_length, _ = ops.predict_target_lengths(
        encoder_output, inputs_mask, hparams)
    targets_mask = ops.sequence_mask(targets_length, hparams)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - targets_mask))

    z_p, _ = self.sample_p(
        targets_length, temp=self._decode_hparams.temp,
        check_invertibility=False, **kwargs)
    z_q = self.delta_posterior(
        z_p, targets_mask, decoder_self_attention_bias,
        self._decode_hparams.n_gibbs_steps, **kwargs)
    body_output = ops.decoder(
        "decoder", z_q, hparams, decoder_self_attention_bias, **kwargs)

    return body_output, losses, monitor, targets_mask

def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  if inputs is not None:
    batch_size = common_layers.shape_list(inputs)[0]
  else:
    batch_size = common_layers.shape_list(targets)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    inputs_ex, ed_ex = inputs, ed
  else:
    ed, inputs_ex, ed_ex = None, None, None

  # Autoencoding.
  losses = {
      "extra": tf.constant(0.0),
      "latent_pred": tf.constant(0.0),
      "neg_q_entropy": tf.constant(0.0)
  }
  if hparams.do_ae:
    # Flatten here.
    original_targets_shape = tf.shape(targets)
    if hparams.task == "image":
      cia.maybe_reshape_4d_to_3d(targets)
    if hparams.task == "translate":
      if inputs is not None:
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
      else:
        max_targets_len_from_inputs = targets
    else:
      assert hparams.task == "image"
      max_targets_len_from_inputs = targets
    if hparams.word_shuffle:
      tf.logging.info("Using word shuffle with rate = {}".format(
          hparams.word_shuffle))
      targets_idx = tf.range(
          start=0, limit=common_layers.shape_list(targets)[1], delta=1)
      targets_idx = tf.to_float(targets_idx)
      noise = tf.random_uniform(
          shape=common_layers.shape_list(targets_idx),
          minval=0,
          maxval=1 + hparams.word_shuffle)
      targets_idx += noise
      permutation = tf.contrib.framework.argsort(targets_idx)
      targets_permuted = tf.gather(targets, indices=permutation, axis=1)
      targets = targets_permuted
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    if hparams.word_dropout:
      mask = tf.random_uniform(
          shape=common_layers.shape_list(targets), minval=0.0, maxval=1.0)
      targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                               tf.zeros_like(targets))
    else:
      targets_noisy = targets
    targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
          hparams.bottleneck(
              inputs=targets_c,
              filter_size=hparams.compress_filter_size,
              mode=hparams.mode,
              name="vc"))
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps)
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before
      # mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex, embed(latents_discrete), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        # Scale by latent dimension for summary so we can compare across
        # batches.
        if _DO_SUMMARIES:
          tf.summary.scalar("latent_pred_loss_mean",
                            tf.reduce_mean(latent_pred_loss))
        if hparams.sum_over_latents:
          latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
        losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            (inputs_c - targets_c)**2) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _, _ = hparams.bottleneck(
                inputs=inputs_c,
                filter_size=hparams.compress_filter_size,
                mode=hparams.mode,
                name="vc")
          return bn

        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc), latents_dense,
            inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _, _ = hparams.bottleneck(
            inputs=inputs_c,
            filter_size=hparams.compress_filter_size,
            mode=hparams.mode,
            name="vc")
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed, _ = hparams.bottleneck(
            inputs=targets_c,
            filter_size=hparams.compress_filter_size,
            name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex, embed, 16,
                                   hparams)
        latents_dense = embed(cache)

    # Postprocess.
    d = latents_dense
    latent_len = common_layers.shape_list(latents_dense)[1]
    if isinstance(latent_len, tf.Tensor):
      # TODO(trandustin): Fix this in a better manner.
      latent_len = max(1000, hparams.max_length)
    pos = tf.get_variable("pos", [1, latent_len + 1, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Decompressing the dense latents.
    for i in range(hparams.num_compress_steps):
      j = hparams.num_compress_steps - i - 1
      d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
      if inputs is not None and hparams.do_attend_decompress:
        d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
      d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * hparams.unmasked_percentage
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.use_predict_mask:
        masking = predict_mask
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)

      # targets is always [batch, length, 1, depth].
      targets = mask * targets + (1.0 - mask) * d
      # Reshape back to 4d here.
      if hparams.task == "image":
        targets = tf.reshape(targets, original_targets_shape)

  res = decode_transformer(
      inputs, ed, targets, hparams, "decoder", causal=hparams.causal)
  if hparams.do_ae:
    if hparams.do_mask and hparams.do_refine:

      def refine_res():
        # return residual_conv(res, 1, (5, 1), hparams, "refine")
        r, _ = encode(
            tf.squeeze(res, axis=[2]), target_space, hparams, "refine_enc")
        return tf.expand_dims(r, axis=2)

      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)

    # We'll start training the extra model of latents after
    # mask_startup_steps.
    nonlatent_steps = hparams.mask_startup_steps
    latent_time = tf.less(nonlatent_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)

  # res was generated from padded targets, which means it has some extra
  # elements. These can cause shape problems when computing loss with respect
  # to the original (unpadded) targets. So we remove their extra elements
  # here.
  res = res[:, :original_targets_shape[1], :, :]
  return res, losses, cache

def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0,
                            means=None,
                            ema_count=None,
                            ema_means=None):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2 * 2048, "vc", means, ema_count, ema_means)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before
      # mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(
            latents_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            (inputs_c - targets_c)**2) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc", means,
                                     ema_count, ema_means)
          return bn

        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(
            tf.less(tf.random_uniform([]), pbn), bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc), latents_dense,
            inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc",
                                            means, ema_count, ema_means)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc", means,
                                    ema_count, ema_means)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8,
                                   hparams)
        latents_dense = embed(cache)

    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking,
                     tf.random_uniform(common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
      targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:

      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")

      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)

    # We'll start training only the extra model of latents after 400K steps.
    # Before we train only this, we decrease lr for other weights.
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache

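# Hedged sketch of the masking schedule used above: inverse_lin_decay-style
# curves ramp from 0 to 1 over a step budget, so early in training most
# target positions keep the gold embedding and, as masking grows, more
# positions are replaced by the decompressed latents d. Reimplemented inline
# for illustration; toy shapes, invented names.
def _masking_schedule_sketch(max_step=100000):
  step = tf.to_float(tf.train.get_or_create_global_step())
  masking = tf.minimum(1.0, step / float(max_step))
  keep_gold = tf.less(masking, tf.random_uniform([2, 5]))  # True keeps gold.
  return tf.to_float(keep_gold)
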
def body(self, features): """Seq2Edits main model_fn. Args: features: Feature dictionary. Should contain the following fields: "inputs": [batch_size, input_length, 1, hidden_dim] float tensor with input token embeddings. "targets": [batch_size, target_length, 1, hidden_dim] float tensor with target token embeddings. "targets_error_tag": [batch_size, target_length, 1, hidden_dim] float tensor with target error tag embeddings. "target_space_id": A scalar int from data_generators.problem.SpaceID. Returns: Final decoder representation. Dictionary containing the following fields: "targets": [batch_size, target_length, hidden_dim] float tensor with decoder outputs "targets_error_tag": [batch_size, target_length, hidden_dim] float tensor with decoder outputs """ hparams = self._hparams losses = [] if self.has_input: target_space = features['target_space_id'] encoder_output, encoder_decoder_attention_bias = self.encode( features['inputs'], target_space, hparams, features=features, losses=losses, ) else: encoder_output, encoder_decoder_attention_bias = (None, None) targets = features['targets'] targets_shape = common_layers.shape_list(targets) targets = common_layers.flatten4d3d(targets) decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn( targets, hparams, features=features) nonpadding = features_to_nonpadding(features, 'targets') # Add edit ops layer to condition on start_token, end_token, and error_tag decoder_input = transformer_edit_ops_layer( decoder_input, hparams, encoder_output, features, nonpadding=nonpadding, losses=losses, ) if hparams.middle_prediction: num_decoder_layers = (hparams.num_decoder_layers or hparams.num_hidden_layers) hparams.num_decoder_layers = int( num_decoder_layers / hparams.middle_prediction_layer_factor) decode_kwargs = {} decoder_output = self.decode(decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=nonpadding, losses=losses, **decode_kwargs) loss_mask = common_layers.weights_nonzero( maybe_flatten4d2d(features['targets_raw'])) self.loss_den = tf.reduce_sum(loss_mask) decoder_output = self._prediction_cascade( hparams=hparams, features=features, losses=losses, loss_mask=loss_mask, nonpadding=nonpadding, encoder_decoder_attention_bias=encoder_decoder_attention_bias, encoder_output=encoder_output, decoder_output=decoder_output, ) if hparams.middle_prediction: with tf.variable_scope('after_prediction'): decoder_output = self.decode(decoder_input + decoder_output, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=nonpadding, losses=losses, **decode_kwargs) ret = {'targets': tf.reshape(decoder_output, targets_shape)} ret.update(self.logits) if losses: return ret, {'extra_loss': tf.add_n(losses)} else: return ret
def render2cmd_v3_internal(self, features, hparams, train):
  # inputs and targets are both sequences with
  # shape = [batch, seq_len, 1, hparams.problem.feature_dim]
  targets = features['targets']
  losses = {}

  sampled_bottleneck = self.pretrained_visual_encoder(features, hparams)
  if hparams.sg_bottleneck:
    sampled_bottleneck = tf.stop_gradient(sampled_bottleneck)

  with tf.variable_scope('render2cmd_v3_internal'):
    # Override bottleneck, or return it, if requested.
    if 'bottleneck' in features:
      if common_layers.shape_list(features['bottleneck'])[0] == 0:
        # Return sampled_bottleneck.
        # Set losses['training'] = 0 so self.top() doesn't get called on it.
        return sampled_bottleneck, {'training': 0.0}
      else:
        # We want to use the given bottleneck.
        sampled_bottleneck = features['bottleneck']

    # Finalize bottleneck.
    unbottleneck_dim = hparams.hidden_size * 2  # Twice because using LSTM.
    if hparams.twice_decoder:
      unbottleneck_dim = unbottleneck_dim * 2

    # Unbottleneck back to LSTMStateTuple.
    dec_initial_state = []
    for hi in range(hparams.num_hidden_layers):
      unbottleneck = self.unbottleneck(
          sampled_bottleneck, unbottleneck_dim, name_append='_{}'.format(hi))
      dec_initial_state.append(
          tf.nn.rnn_cell.LSTMStateTuple(
              c=unbottleneck[:, :unbottleneck_dim // 2],
              h=unbottleneck[:, unbottleneck_dim // 2:]))
    dec_initial_state = tuple(dec_initial_state)

    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1

    # LSTM decoder.
    hparams_decoder = copy.copy(hparams)
    if hparams.twice_decoder:
      hparams_decoder.hidden_size = 2 * hparams.hidden_size

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      decoder_outputs, _ = self.lstm_decoder_infer(
          common_layers.flatten4d3d(shifted_targets),
          targets_length,
          hparams_decoder,
          features['targets_cls'],
          train,
          initial_state=dec_initial_state,
          bottleneck=sampled_bottleneck)
    else:
      decoder_outputs, _ = self.lstm_decoder(
          common_layers.flatten4d3d(shifted_targets),
          targets_length,
          hparams_decoder,
          features['targets_cls'],
          train,
          initial_state=dec_initial_state,
          bottleneck=sampled_bottleneck)

    ret = tf.expand_dims(decoder_outputs, axis=2)

  return ret, losses

def adv_transformer_internal(inputs, targets, target_space, hparams):
  """Adversarial Transformer, main step used for training."""
  with tf.variable_scope("adv_transformer"):
    batch_size = tf.shape(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1])
    intermediate = tf.constant(34 * 1024 - 1)
    intermediate += tf.zeros_like(targets)
    targets = tf.concat([targets, intermediate], axis=2)
    targets = tf.reshape(targets, [batch_size, -1, 1])
    embedding = tf.get_variable("embedding", [34 * 1024, hparams.hidden_size])
    targets_emb = tf.gather(embedding, targets)

    # Noisy embedded targets.
    targets_noisy = tf.one_hot(targets, 34 * 1024)
    noise_val = hparams.noise_val
    targets_noisy += tf.random_uniform(
        tf.shape(targets_noisy), minval=-noise_val, maxval=noise_val)
    targets_emb_noisy = softmax_embed(
        targets_noisy, embedding, batch_size, hparams)

    # Encoder.
    if inputs is not None:
      inputs_emb = common_layers.flatten4d3d(inputs)
      inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
    else:
      ed = None

    # Masking.
    masking = common_layers.inverse_lin_decay(200000)
    masking *= common_layers.inverse_exp_decay(50000)  # Not much at start.
    masking -= tf.random_uniform([]) * 0.4
    masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
    mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
    mask = tf.expand_dims(tf.to_float(mask), 3)
    noise = tf.random_uniform(tf.shape(targets_emb))
    targets_emb = mask * targets_emb + (1.0 - mask) * noise

    # Decoder.
    res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
    res = tf.layers.dense(res_dec, 34 * 1024, name="res_sm")
    res_emb = softmax_embed(res, embedding, batch_size, hparams)

    # Extra steps.
    extra_step_prob = masking * 0.6 + 0.3
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      extra_step_prob = 1.0
    for _ in xrange(hparams.extra_steps):

      def another_step(emb):
        res_dec = decode(inputs, ed, emb, hparams, "decoder", reuse=True)
        res = tf.layers.dense(res_dec, 34 * 1024, name="res_sm", reuse=True)
        return softmax_embed(res, embedding, batch_size, hparams), res

      res_emb, res = tf.cond(
          tf.less(tf.random_uniform([]), extra_step_prob),
          lambda e=res_emb: another_step(e), lambda: (res_emb, res))

    # Adversary.
    delta = masking * hparams.delta_max
    true_logit = adversary(
        tf.stop_gradient(targets_emb_noisy),
        tf.stop_gradient(inputs + inputs_emb), hparams, "adversary")
    gen_logit = adversary(
        reverse_gradient(res_emb, delta),
        tf.stop_gradient(inputs + inputs_emb), hparams, "adversary",
        reuse=True)
    losses = {"adv": gen_logit - true_logit}
    res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
    return res, losses

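# Hedged sketch of one common way to implement the reverse_gradient used for
# the adversary above (not necessarily this repo's version): the forward pass
# is the identity, while the backward pass scales gradients by -delta, so the
# generator is pushed against the adversary's objective.
def _toy_reverse_gradient(x, delta):
  return tf.stop_gradient((1.0 + delta) * x) - delta * x
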
def render2cmd_v3_internal(self, features, hparams, train):
  # inputs and targets are both sequences with
  # shape = [batch, seq_len, 1, hparams.problem.feature_dim]
  all_targets = features['targets']
  all_targets_cls = features['targets_cls']
  all_targets_font_cls = features['targets_fnt']
  all_targets_psr = features['targets_psr']

  all_batch_size = common_layers.shape_list(all_targets)[0]
  batch_size = all_batch_size // 2

  sources = all_targets[:batch_size, ...]
  sources_cls = all_targets_cls[:batch_size, ...]
  sources_fnt = all_targets_font_cls[:batch_size, ...]
  sources_psr = all_targets_psr[:batch_size, ...]

  targets = all_targets[batch_size:, ...]
  targets_cls = all_targets_cls[batch_size:, ...]
  targets_fnt = all_targets_font_cls[batch_size:, ...]
  targets_psr = all_targets_psr[batch_size:, ...]

  losses = {}

  vis_embd = self.vis_encoder(sources_psr, targets_psr, targets_cls)
  sampled_bottleneck = vis_embd

  with tf.variable_scope('render2cmd_v3_internal'):
    # Finalize bottleneck.
    unbottleneck_dim = hparams.hidden_size * 2  # Twice because using LSTM.
    if hparams.twice_decoder:
      unbottleneck_dim = unbottleneck_dim * 2

    dec_initial_state = []

    # LSTM encoder.
    _, encoder_output_states = self.lstm_encoder(
        common_layers.flatten4d3d(sources), hparams)

    for hi in range(hparams.num_hidden_layers):
      unbottleneck = self.unbottleneck(
          sampled_bottleneck, unbottleneck_dim, name_append='_{}'.format(hi))
      c, h = encoder_output_states[hi]
      dec_initial_state.append(
          tf.nn.rnn_cell.LSTMStateTuple(
              c=tf.concat([unbottleneck[:, :unbottleneck_dim // 2], c], 1),
              h=tf.concat([unbottleneck[:, unbottleneck_dim // 2:], h], 1)))

    dec_initial_state = tuple(dec_initial_state)

    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right.
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1

    # LSTM decoder.
    hparams_decoder = copy.copy(hparams)
    if hparams.twice_decoder:
      hparams_decoder.hidden_size = 2 * hparams.hidden_size

    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      decoder_outputs, _ = self.lstm_decoder_infer(
          common_layers.flatten4d3d(shifted_targets),
          targets_length,
          hparams_decoder,
          targets_cls,
          train,
          initial_state=dec_initial_state,
          bottleneck=sampled_bottleneck)
    else:
      decoder_outputs, _ = self.lstm_decoder(
          common_layers.flatten4d3d(shifted_targets),
          targets_length,
          hparams_decoder,
          targets_cls,
          train,
          initial_state=dec_initial_state,
          bottleneck=sampled_bottleneck)

    ret = tf.expand_dims(decoder_outputs, axis=2)

  return ret, losses

def transformer_autoencoder(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """AE Transformer, main step used for training."""
  # Define losses.
  losses = {"extra": tf.constant(0.0),
            "extra_loss": tf.constant(0.0),
            "latent_pred": tf.constant(0.0)}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Encoder decoder attention bias.
  ed_attention_bias = None

  # Input Encoder if present.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates an exponentially decaying variable based on which
  # we rescale the loss values.
  batch_size = common_layers.shape_list(targets_c)[0]
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
  # Call bottleneck layer to get the latents.
  # Returns embedded latents, discrete latents, loss and the embedding
  # function.
  latents_dense, latents_discrete, extra_loss, embed = (
      bottleneck_layer(targets_c, hparams))
  extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

  # Call the autoregressive latent prediction model.
  _, latents_pred_loss = latent_prediction_model(
      targets_c, ed_attention_bias, latents_discrete, embed, hparams,
      name="latent_pred")
  latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

  # Assign latent losses.
  losses["latent_pred"] = latents_pred_loss
  losses["extra_loss"] = extra_loss

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
    latents_decoder = tf.reshape(
        latents_decoder,
        [batch_size, cmp_img_len, cmp_img_len,
         hparams.num_latents * hparams.hidden_size])

  # Decompress either using 1D or 2D upconvs.
  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # If we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels.
  latents_decoder = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    targets = mask * targets + (1.0 - mask) * latents_decoder
  else:
    targets = latents_decoder

  # Reshape back to 4d here.
  targets = tf.reshape(targets, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder, that goes from inputs->targets.
    res = transformer_image_decoder(
        targets, inputs, ed_attention_bias, hparams, "decoder")
  else:
    res = targets

  # We'll start training the extra model of latents after mask_startup_steps.
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return res, losses, cache
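# NOTE: a sketch of the warm-up schedule used above. In tensor2tensor,
# `common_layers.inverse_exp_decay(max_step)` rises exponentially from a small
# value to 1.0 at `max_step`; the per-example `cond` then keeps the latent
# losses with that probability. Assuming min_value=0.01 (the t2t default),
# the schedule is roughly:
def inverse_exp_decay_sketch(max_step, min_value=0.01):
  step = tf.to_float(tf.train.get_global_step())
  inv_base = tf.exp(tf.log(min_value) / float(max_step))
  return inv_base ** tf.maximum(float(max_step) - step, 0.0)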
def body(self, features): hp = self.hparams # pylint: disable=eval-used if hp.image_input_type == "image": image_feat = vqa_layers.image_embedding( features["inputs"], model_fn=eval(hp.image_model_fn), trainable=hp.train_resnet, is_training=hp.mode == tf.estimator.ModeKeys.TRAIN) else: image_feat = features["inputs"] image_feat = common_layers.flatten4d3d(image_feat) image_hidden_size = hp.image_hidden_size or hp.hidden_size if hp.image_feat_preprocess_proj: image_feat = common_layers.dense(image_feat, image_hidden_size) utils.collect_named_outputs("norms", "image_feat_after_proj", tf.norm(image_feat, axis=-1)) else: assert image_hidden_size == 2048 image_feat = tf.nn.dropout(image_feat, keep_prob=1. - hp.layer_prepostprocess_dropout) if hp.image_feat_encode: image_feat = image_encoder(image_feat, hp) utils.collect_named_outputs("norms", "image_feat_encoded", tf.norm(image_feat, axis=-1)) else: image_feat = common_layers.layer_norm(image_feat) utils.collect_named_outputs("norms", "image_feat_after_layer", tf.norm(image_feat, axis=-1)) question = common_layers.flatten4d3d(features["question"]) utils.collect_named_outputs("norms", "question_embedding", tf.norm(question, axis=-1)) question, question_self_attention_bias = prepare_question_encoder( question, hp) question = tf.nn.dropout(question, keep_prob=1. - hp.layer_prepostprocess_dropout) query = question_encoder(question, question_self_attention_bias, hp) utils.collect_named_outputs("norms", "query_encode", tf.norm(query, axis=-1)) query = (query + tf.expand_dims( tf.squeeze(question_self_attention_bias, [1, 2]), axis=2)) query = tf.reduce_max(query, axis=1) utils.collect_named_outputs("norms", "query_maxpool", tf.norm(query, axis=-1)) # query = common_layers.l2_norm(query) # utils.collect_named_outputs("norms", "query_after_l2", # tf.norm(query, axis=-1)) image_ave = attn(image_feat, query, hp) utils.collect_named_outputs("norms", "image_ave", tf.norm(image_ave, axis=-1)) if hp.multimodal_combine == "concat": image_question = tf.concat([image_ave, query], axis=1) elif hp.multimodal_combine == "sum": image_question = image_ave + query elif hp.multimodal_combine == "product": image_question = image_ave * query utils.collect_named_outputs("norms", "image_question", tf.norm(image_question, axis=-1)) image_question = tf.nn.dropout(image_question, 1. - hp.dropout) output = mlp(image_question, hp) utils.collect_named_outputs("norms", "output", tf.norm(output, axis=-1)) norm_tensors = utils.convert_collection_to_dict("norms") vqa_layers.summarize_tensors(norm_tensors, tag="norms/") # Expand dimension 1 and 2 return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
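# NOTE: the max-pool over the encoded question above relies on the additive
# attention-bias convention: padded positions carry a large negative bias, so
# adding the (squeezed) bias before reduce_max effectively excludes padding
# from the pool. A stripped-down sketch of the same trick (hypothetical
# helper name):
def masked_max_pool_sketch(x, pad_mask):
  # x: [batch, length, hidden]; pad_mask: [batch, length], 1.0 at padding.
  neg_bias = pad_mask * -1e9
  return tf.reduce_max(x + tf.expand_dims(neg_bias, axis=2), axis=1)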
def baseline_perf_transformer_encode(encoder_function, inputs, target_space, hparams, attention_weights=None, features=None, losses=None, prepare_encoder_fn=None, **kwargs): """Encoding for baseline performance transformer, no mean-aggregation. Args: encoder_function: the encoder function inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which will be flattened along the two spatial dimensions. target_space: scalar, target space ID. hparams: hyperparameters for model. attention_weights: weight to store attention to. features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. losses: optional list onto which to append extra training losses prepare_encoder_fn: optional, alternative to transformer_prepare_encoder. **kwargs: additional arguments to pass to encoder_function Returns: Tuple of: encoder_output: Encoder representation. [batch_size, input_length, hidden_dim] encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder attention. [batch_size, input_length] """ inputs = common_layers.flatten4d3d(inputs) if not prepare_encoder_fn: prepare_encoder_fn = transformer_prepare_encoder encoder_input, self_attention_bias, encoder_decoder_attention_bias = ( prepare_encoder_fn(inputs, target_space, hparams, features=features, reuse_target_embedding=tf.AUTO_REUSE)) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT, value=hparams.layer_prepostprocess_dropout, hparams=hparams) encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) attn_bias_for_padding = None # Otherwise the encoder will just use encoder_self_attention_bias. if hparams.unidirectional_encoder: attn_bias_for_padding = encoder_decoder_attention_bias encoder_output = encoder_function( encoder_input, self_attention_bias, hparams, name="encoder", nonpadding=features_to_nonpadding(features, "inputs"), save_weights_to=attention_weights, make_image_summary=not common_layers.is_xla_compiled(), losses=losses, attn_bias_for_padding=attn_bias_for_padding, **kwargs) # no aggregation --> just return everything normally return encoder_output, encoder_decoder_attention_bias
def transformer_autoencoder(inputs, targets, target_space, hparams,
                            cache=None, predict_mask=1.0):
  """Auto-encoder using transformer decoder and prior over latents."""
  # Define losses.
  losses = {"extra": tf.constant(0.0),
            "extra_loss": tf.constant(0.0),
            "latent_pred": tf.constant(0.0)}

  # Reshape image targets as 4d tensor.
  original_targets_shape = common_layers.shape_list(targets)
  batch_size = original_targets_shape[0]
  if len(original_targets_shape) == 4:
    compress_fn = compress_encoder_2d
    decompress_fn = decompress_decoder_2d
  else:
    compress_fn = compress_encoder_1d
    decompress_fn = decompress_decoder_1d

  # Encoder decoder attention bias.
  ed_attention_bias = None

  # Input Encoder if present.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed_attention_bias = transformer_text_encoder(
        inputs, target_space, hparams, "input_enc")

  # Encode targets to compute targets compressed.
  targets_c = compress_fn(targets, hparams, "compress")
  targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

  # The following creates an exponentially decaying variable based on which
  # we rescale the loss values.
  pc = common_layers.inverse_exp_decay(hparams.startup_steps)
  pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
  cond = tf.less(tf.random_uniform([batch_size]), pc)

  # Call bottleneck layer, that takes encoder output and outputs the latents.
  # Returns embedded latents, discrete latent codes, loss.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    latents_dense, latents_discrete, extra_loss = (
        bottleneck_layer(targets_c, hparams))
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    # Call the autoregressive latent prediction model.
    _, latents_pred_loss = latent_prediction_model(
        inputs, ed_attention_bias, latents_discrete, latents_dense, hparams,
        name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    # Latent dropout.
    latents_shape = common_layers.shape_list(latents_dense)
    latents_dense = tf.nn.dropout(
        latents_dense, 1 - hparams.latent_dropout,
        noise_shape=[latents_shape[0], latents_shape[1], 1])

    # Assign latent losses.
    losses["latent_pred"] = latents_pred_loss
    losses["extra_loss"] = extra_loss
  else:
    latent_len = ((hparams.img_len * hparams.img_len * hparams.num_latents) //
                  2**hparams.num_compress_steps)
    embed = functools.partial(
        discretization.parametrized_unbottleneck, hparams=hparams)
    latents_dense = tf.zeros([batch_size, latent_len, 1, hparams.hidden_size])
    if cache is None:
      cache = ae_latent_sample_beam(
          latents_dense, inputs, ed_attention_bias, embed, hparams)
    latents_dense = embed(
        tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
        hparams.hidden_size)

  latents_decoder = latents_dense
  if len(original_targets_shape) == 4:
    cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
    latents_decoder = tf.reshape(
        latents_decoder,
        [batch_size, cmp_img_len, cmp_img_len,
         hparams.num_latents * hparams.hidden_size])

  # Decompress either using 1D or 2D upconvs.
  latents_decoder = decompress_fn(latents_decoder, hparams, name="decompress")
  # If we're operating in 2d space on images, then we're assuming that the
  # last dimension will not be a multiple of channels.
  output = tf.reshape(
      latents_decoder,
      shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

  if hparams.use_gold_targets:
    latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
    masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      masking = predict_mask
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 2)
    output = mask * targets + (1.0 - mask) * output

  # Reshape back to 4d here.
  output = tf.reshape(output, original_targets_shape)
  if hparams.decode_autoregressive:
    # Transformer decoder, that goes from inputs->targets.
    decoder_output = transformer_image_decoder(
        output, inputs, ed_attention_bias, hparams, "decoder")
  else:
    decoder_output = output

  # We'll start training the extra model of latents after mask_startup_steps.
  latent_time = tf.less(hparams.mask_startup_steps,
                        tf.to_int32(tf.train.get_global_step()))
  losses["latent_pred"] *= tf.to_float(latent_time)
  return decoder_output, losses, cache
def body(self, features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs.
            [batch_size, input_length, 1, hidden_dim].
        "targets": Target decoder outputs.
            [batch_size, decoder_length, 1, hidden_dim]
        "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams
  losses = []

  if self.has_input:
    # Use melody-only as input features.
    inputs = features["melody"]
    target_space = features["target_space_id"]
    encoder_output, encoder_decoder_attention_bias = self.encode(
        inputs, target_space, hparams, features=features, losses=losses)
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
      targets, hparams, features=features)

  # Not all subclasses of Transformer support keyword arguments related to
  # recurrent memory, so only pass these arguments if memory is enabled.
  decode_kwargs = {}
  if self.recurrent_memory_by_layer is not None:
    # TODO(kitaev): The chunk_number feature currently has the same shape as
    # "targets", but this is only for the purposes of sharing sharding code.
    # In fact every token within an example must have the same chunk number.
    chunk_number_each_token = tf.squeeze(features["chunk_number"], (-1, -2))
    chunk_number_each_example = chunk_number_each_token[:, 0]
    # Uncomment the code below to verify that tokens within a batch share the
    # same chunk number:
    # with tf.control_dependencies([
    #     tf.assert_equal(chunk_number_each_token,
    #                     chunk_number_each_example[:, None])
    # ]):
    #   chunk_number_each_example = tf.identity(chunk_number_each_example)
    decode_kwargs = dict(
        recurrent_memory_by_layer=self.recurrent_memory_by_layer,
        chunk_number=chunk_number_each_example,
    )

  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses,
      **decode_kwargs)

  expected_attentions = features.get("expected_attentions")
  if expected_attentions is not None:
    attention_loss = common_attention.encoder_decoder_attention_loss(
        expected_attentions, self.attention_weights,
        hparams.expected_attention_loss_type,
        hparams.expected_attention_loss_multiplier)
    return decoder_output, {"attention_loss": attention_loss}

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret
def main():
  FLAGS = Args()

  # Enable TF Eager execution.
  tfe = tf.contrib.eager
  tfe.enable_eager_execution()

  # Sample sentence (from "Jabberwocky").
  input_str = ('Twas brillig, and the slithy toves Did gyre and gimble in '
               'the wabe; All mimsy were the borogoves, And the mome raths '
               'outgrabe.')

  # Convert the sentence into indices in the vocab.
  wmt_problem = problems.problem(FLAGS.problem)
  encoders = wmt_problem.feature_encoders(FLAGS.data_dir)
  inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
  features = {"inputs": batch_inputs}

  # Initialize the translation model.
  hparams_set = FLAGS.hparams_set
  Modes = tf.estimator.ModeKeys
  hparams = trainer_lib.create_hparams(hparams_set,
                                       data_dir=FLAGS.data_dir,
                                       problem_name=FLAGS.problem)
  translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

  # Recover parameters and run the recurrent encoder steps.
  ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)
  with tfe.restore_variables_on_create(ckpt_dir):
    with variable_scope.EagerVariableStore().as_default():
      with tf.variable_scope('universal_transformer'):
        # Convert word indices to word embeddings.
        features = translate_model.bottom(features)

      with tf.variable_scope('universal_transformer/body'):
        input_tensor = tf.convert_to_tensor(features['inputs'])
        input_tensor = common_layers.flatten4d3d(input_tensor)
        encoder_input, self_attention_bias, _ = (
            transformer.transformer_prepare_encoder(
                input_tensor, tf.convert_to_tensor([0]),
                translate_model.hparams, features=None))

      with tf.variable_scope('universal_transformer/body/encoder'):
        ffn_unit = functools.partial(
            universal_transformer_util.transformer_encoder_ffn_unit,
            hparams=translate_model.hparams)
        attention_unit = functools.partial(
            universal_transformer_util.transformer_encoder_attention_unit,
            hparams=translate_model.hparams,
            encoder_self_attention_bias=None,
            attention_dropout_broadcast_dims=[],
            save_weights_to={},
            make_image_summary=True)

      storing_list = []
      transformed_state = encoder_input
      for step_index in range(1024):
        storing_list.append(transformed_state.numpy())

        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}'
            .format(FLAGS.ut_type)):
          transformed_state = universal_transformer_util.step_preprocess(
              transformed_state,
              tf.convert_to_tensor(step_index % FLAGS.step_num),
              translate_model.hparams)
        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'
            .format(FLAGS.ut_type)):
          transformed_new_state = ffn_unit(attention_unit(transformed_state))
        with tf.variable_scope('universal_transformer/body/encoder'):
          if (step_index + 1) % FLAGS.step_num == 0:
            transformed_new_state = common_layers.layer_preprocess(
                transformed_new_state, translate_model.hparams)

        transformed_state = transformed_new_state

      storing_list = np.asarray(storing_list)
      np.save(FLAGS.save_dir, storing_list)
def body(self, features): hp = self.hparams # pylint: disable=eval-used if hp.image_input_type == "image": image_feat = vqa_layers.image_embedding( features["inputs"], model_fn=eval(hp.image_model_fn), trainable=hp.train_resnet, is_training=hp.mode == tf.estimator.ModeKeys.TRAIN) else: image_feat = features["inputs"] image_feat = common_layers.flatten4d3d(image_feat) image_feat = common_layers.dense(image_feat, hp.hidden_size) utils.collect_named_outputs("norms", "image_feat_after_proj", tf.norm(image_feat, axis=-1)) question = common_layers.flatten4d3d(features["question"]) utils.collect_named_outputs("norms", "question_embedding", tf.norm(question, axis=-1)) (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias) = prepare_image_question_encoder( image_feat, question, hp) encoder_input = tf.nn.dropout(encoder_input, keep_prob=1. - hp.layer_prepostprocess_dropout) encoder_output, _ = recurrent_transformer_decoder( encoder_input, None, encoder_self_attention_bias, None, hp, name="encoder") utils.collect_named_outputs("norms", "encoder_output", tf.norm(encoder_output, axis=-1)) # scale query by sqrt(hidden_size) query = tf.get_variable("query", [hp.hidden_size]) * hp.hidden_size**0.5 query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0) batch_size = common_layers.shape_list(encoder_input)[0] query = tf.tile(query, [batch_size, 1, 1]) query = tf.nn.dropout(query, keep_prob=1. - hp.layer_prepostprocess_dropout) decoder_output, _ = recurrent_transformer_decoder( query, encoder_output, None, encoder_decoder_attention_bias, hp, name="decoder") utils.collect_named_outputs("norms", "decoder_output", tf.norm(decoder_output, axis=-1)) norm_tensors = utils.convert_collection_to_dict("norms") vqa_layers.summarize_tensors(norm_tensors, tag="norms/") # Expand dimension 1 and 2 return tf.expand_dims(decoder_output, axis=1)
def encode(self, inputs_context, inputs, target_space, hparams,
           features=None, losses=None):
  """Encode transformer inputs.

  Args:
    inputs_context: contextual input [batch_size, input_length, hidden_dim]
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses

  Returns:
    Tuple of:
        encoder_output_context: Contextual encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  encoder_output = transformer_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights,
      losses=losses)

  if inputs_context is None:
    return None, encoder_output, encoder_decoder_attention_bias

  inputs_context = common_layers.flatten4d3d(inputs_context)
  (encoder_input_context, self_attention_bias_context,
   encoder_decoder_attention_bias_context) = (
       transformer_prepare_encoder(
           inputs_context, target_space, hparams, features=features))

  encoder_input_context = tf.nn.dropout(
      encoder_input_context, 1.0 - hparams.layer_prepostprocess_dropout)

  encoder_output_context_0 = transformer_encoder(
      encoder_input_context,
      self_attention_bias_context,
      hparams,
      name="encoder_context",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights,
      losses=losses)

  encoder_output_context = transformer_decoder(
      encoder_input,
      encoder_output_context_0,
      encoder_decoder_attention_bias,
      encoder_decoder_attention_bias_context,
      hparams,
      name="decoder_input_context")

  return encoder_output_context, encoder_output, encoder_decoder_attention_bias
def infer_step(logits_so_far, current_hidden): """Inference step of LSTM while loop.""" # unflatten hidden: current_hidden = tuple( tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1]) for s in current_hidden) # put logits_so_far through top tm = self._problem_hparams.modality['targets'] # need to reuse top params reset_scope = tf.variable_scope(tf.VariableScope( tf.AUTO_REUSE, ''), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False) top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm), reuse=tf.AUTO_REUSE) with reset_scope, top_scope: samples_so_far = self.hparams.top['targets']( logits_so_far, None, self.hparams, self.problem_hparams.vocab_size) # append a zero pad to the samples. this effectively shifts the samples # right, but, unlike shift_right, by not removing the last element, we # allow an empty samples_so_far to not be empty after padding samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1) shifted_targets = common_layers.flatten4d3d(samples_so_far) # now take the very last one here, will be the actual input to the rnn shifted_targets = shifted_targets[:, -1:, :] # tile and append the bottleneck to inputs sln_offset = 0 if hparams.condition_on_sln: sln_offset = 51 pre_tile_y = tf.reshape(bottleneck, [ common_layers.shape_list(bottleneck)[0], 1, hparams.bottleneck_bits + hparams.num_categories + sln_offset ]) overlay_x = tf.tile( pre_tile_y, [1, common_layers.shape_list(shifted_targets)[1], 1]) inputs = tf.concat([shifted_targets, overlay_x], -1) seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]]) # RUN PRE-LSTM LAYER with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE): inputs = tf.layers.dense(inputs, hparams.hidden_size, name='bottom') inputs = tf.nn.tanh(inputs) # RUN LSTM with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE): next_step, next_state = tf.nn.dynamic_rnn( layers, inputs, seq_len_batch, initial_state=current_hidden, dtype=tf.float32, time_major=False) next_step = tf.expand_dims(next_step, [1]) logits_so_far = tf.concat([logits_so_far, next_step], 1) # flatten state next_state = tuple((s.c, s.h) for s in next_state) return logits_so_far, next_state
def body(self, features): """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "targets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id": A scalar int from data_generators.problem.SpaceID. Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams losses = [] if self.has_input: inputs_context = features.get("inputs_context") inputs = features["inputs"] target_space = features["target_space_id"] encoder_output_context, encoder_output, encoder_decoder_attention_bias = self.encode( inputs_context, inputs, target_space, hparams, features=features, losses=losses) else: encoder_output_context, encoder_output, encoder_decoder_attention_bias = (None, None, None) targets = features["targets"] targets_shape = common_layers.shape_list(targets) targets = common_layers.flatten4d3d(targets) decoder_input, decoder_self_attention_bias = transformer_prepare_decoder( targets, hparams, features=features) decoder_output = self.decode( decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, name="decoder_output_input", nonpadding=features_to_nonpadding(features, "targets"), losses=losses) if encoder_output_context is not None: decoder_output_context = self.decode( decoder_input, encoder_output_context, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, name="decoder_output_input_context", nonpadding=features_to_nonpadding(features, "targets"), losses=losses) decoder_output = self.cat_and_compress(decoder_output_context, decoder_output, hparams) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights, hparams.expected_attention_loss_type, hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} ret = tf.reshape(decoder_output, targets_shape) if losses: return ret, {"extra_loss": tf.add_n(losses)} else: return ret
def flatten(inputs): return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
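# Example usage of `flatten`: the two spatial axes are merged by flatten4d3d
# and a dummy axis 2 is re-inserted.
x = tf.zeros([8, 10, 4, 64])   # [batch, length, height, depth]
y = flatten(x)                 # -> shape [8, 40, 1, 64]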
def encode(self, inputs, target_space, hparams, features=None, losses=None,
           **kwargs):
  """Encode Universal Transformer inputs.

  It is similar to "transformer.encode", but it uses
  "universal_transformer_util.universal_transformer_encoder" instead of
  "transformer.transformer_encoder".

  Args:
    inputs: Transformer inputs [batch_size, input_length, input_height,
      hidden_dim] which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: Unused.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
        encoder_extra_output: which is extra encoder output used in some
            variants of the model (e.g. in ACT, to pass the ponder-time to
            body)
  """
  del losses

  inputs = common_layers.flatten4d3d(inputs)

  encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
      transformer.transformer_prepare_encoder(
          inputs, target_space, hparams, features=features))

  encoder_input = tf.nn.dropout(encoder_input,
                                1.0 - hparams.layer_prepostprocess_dropout)

  (encoder_output, encoder_extra_output) = invertible_UT_encoder(
      encoder_input,
      self_attention_bias,
      hparams,
      nonpadding=transformer.features_to_nonpadding(features, "inputs"),
      save_weights_to=self.attention_weights)

  return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
def testFlatten4D3D(self): x = np.random.random_integers(1, high=8, size=(3, 5, 2)) y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7)) self.evaluate(tf.global_variables_initializer()) res = self.evaluate(y) self.assertEqual(res.shape, (3, 5 * 2, 7))
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None, predict_mask=1.0): """AE Transformer, main step used for training.""" # Summaries break with the do_refine cond, turn them off in that case. global _DO_SUMMARIES if hparams.do_refine: _DO_SUMMARIES = False # Change hyperparameters for the latent prediction model. hparams_ex = copy.copy(hparams) hparams_ex.filter_size *= 2 hparams_ex.hidden_size *= 2 hparams_ex.dropout = 0.0 hparams_ex.relu_dropout = 0.0 hparams_ex.z_dropout = 0.0 hparams_ex.layer_prepostprocess_dropout = 0.0 hparams_ex.symbol_dropout = 0.0 hparams.ex = hparams_ex # Prepare. if inputs is not None: batch_size = common_layers.shape_list(inputs)[0] else: batch_size = common_layers.shape_list(targets)[0] targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size]) # Encoder. if inputs is not None: inputs = common_layers.flatten4d3d(inputs) inputs_ex = tf.layers.dense(tf.stop_gradient(inputs), hparams_ex.hidden_size, name="extra_embed") inputs, ed = encode(inputs, target_space, hparams, "input_enc") inputs_ex, ed_ex = encode(inputs_ex, target_space, hparams_ex, "extra_ienc") else: ed, inputs_ex, ed_ex = None, None, None # Autoencoding. losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)} if hparams.do_ae: # flatten here original_targets_shape = tf.shape(targets) if hparams.task == "image": cia.maybe_reshape_4d_to_3d(targets) if hparams.task == "translate": max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) else: assert hparams.task == "image" max_targets_len_from_inputs = targets targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) targets_c = compress(targets, inputs, False, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck( x=targets_c, filter_size=hparams.compress_filter_size, name="vc", mode=hparams.mode) if _DO_SUMMARIES: tf.summary.histogram( "b0", tf.reshape(latents_discrete[:, 0, :], [-1])) pc = common_layers.inverse_exp_decay(hparams.startup_steps) pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 cond = tf.less(tf.random_uniform([batch_size]), pc) latents_dense = tf.where(cond, latents_dense, targets_c) # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean. losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond)) # Extra loss predicting latent code from input. Discrete only. 
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            inputs_ex, ed_ex,
            tf.stop_gradient(embed(latents_discrete)), hparams, "extra",
            task="translate")
        _, latent_pred_loss = ae_latent_softmax(
            latents_pred, tf.stop_gradient(latents_discrete), hparams)
        losses["latent_pred"] = tf.reduce_mean(
            latent_pred_loss * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean(
            (inputs_c - targets_c)**2) * 20

        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = hparams.bottleneck(
                x=inputs_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
          return bn

        inputs_c = bn_inputs()
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(
            tf.less(tf.random_uniform([batch_size]), ptc),
            latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = hparams.bottleneck(
            x=inputs_c,
            filter_size=hparams.compress_filter_size,
            name="vc",
            mode=hparams.mode)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = hparams.bottleneck(
            x=targets_c, filter_size=hparams.compress_filter_size, name="vc")
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(
              latents_dense, inputs_ex, ed_ex, embed, 16, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos
    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
      masking *= common_layers.inverse_exp_decay(
          hparams.mask_startup_steps // 4)  # Not much at start.
if not hparams.do_refine: masking -= tf.random_uniform([]) * hparams.unmasked_percentage masking = tf.minimum(tf.maximum(masking, 0.0), 1.0) if hparams.use_predict_mask: masking = predict_mask if hparams.mode == tf.estimator.ModeKeys.PREDICT: masking = predict_mask mask = tf.less( masking, tf.random_uniform(common_layers.shape_list(targets)[:-1])) mask = tf.expand_dims(tf.to_float(mask), 3) for i in xrange(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j) if hparams.do_attend_decompress: d = attend(d, inputs, hparams, "decompress_attend_%d" % j) d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j) # targets is always [batch, length, 1, depth] targets = mask * targets + (1.0 - mask) * d # reshape back to 4d here if hparams.task == "image": targets = tf.reshape(targets, original_targets_shape) if hparams.task == "translate": targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1) res = decode_transformer(inputs, ed, targets, hparams, "decoder", causal=hparams.causal) if hparams.do_ae: if hparams.task == "translate": res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :] if hparams.do_mask and hparams.do_refine: def refine_res(): # return residual_conv(res, 1, (5, 1), hparams, "refine") r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams, "refine_enc") return tf.expand_dims(r, axis=2) masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3]) all_masked = tf.less(masked_batches, 0.1) res = tf.where(all_masked, refine_res(), res) # We'll start training the extra model of latents after mask_startup_steps. nonlatent_steps = hparams.mask_startup_steps latent_time = tf.less(nonlatent_steps, tf.to_int32(tf.train.get_global_step())) # Learning rate warmup for the latent model for 20K steps. latent_warmup = tf.to_float( tf.train.get_global_step()) - nonlatent_steps latent_warmup = tf.maximum(0.0, tf.minimum(1.0, latent_warmup / 20000.0)) losses["latent_pred"] *= tf.to_float(latent_time) * latent_warmup return res, losses, cache
def ae_transformer_internal(inputs, targets, target_space, hparams, cache=None): """Main step used for training.""" # Encoder. inputs = common_layers.flatten4d3d(inputs) inputs, ed = encode(inputs, target_space, hparams, "input_enc") # Autoencoding. losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)} max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) targets_c = compress(targets, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. latents_discrete_hot, extra_loss = vq_discrete_bottleneck( x=targets_c, hparams=hparams) latents_dense = vq_discrete_unbottleneck(latents_discrete_hot, hparams=hparams) latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c) latents_discrete = tf.argmax(latents_discrete_hot, axis=-1) tf.summary.histogram("codes", tf.reshape(latents_discrete[:, 0, :], [-1])) losses["extra"] = extra_loss # Extra loss predicting latent code from input. latents_pred = decode_transformer(inputs, ed, latents_dense, hparams, "extra") latent_pred_loss = get_latent_pred_loss(latents_pred, latents_discrete_hot, hparams) losses["latent_pred"] = tf.reduce_mean(latent_pred_loss) else: latent_len = common_layers.shape_list(targets_c)[1] embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams) latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :]) if cache is None: cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed, hparams) cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits) latents_dense = embed(cache_hot) # Postprocess. d = latents_dense pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size]) pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :] latents_dense = tf.pad(latents_dense, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos # Decompressing the dense latents for i in range(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j) d = decompress_step(d, hparams, i > 0, "decompress_%d" % j) masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps) masking *= common_layers.inverse_exp_decay(hparams.mask_startup_steps // 4) # Not much at start. masking = tf.minimum(tf.maximum(masking, 0.0), 1.0) if hparams.mode == tf.estimator.ModeKeys.PREDICT: masking = 1.0 mask = tf.less(masking, tf.random_uniform(common_layers.shape_list(targets)[:-1])) mask = tf.expand_dims(tf.to_float(mask), 3) # targets is always [batch, length, 1, depth] targets = mask * targets + (1.0 - mask) * d res = decode_transformer(inputs, ed, targets, hparams, "decoder") latent_time = tf.less(hparams.mask_startup_steps, tf.to_int32(tf.train.get_global_step())) losses["latent_pred"] *= tf.to_float(latent_time) return res, losses, cache
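# NOTE: `targets_c + tf.stop_gradient(latents_dense - targets_c)` above is the
# straight-through estimator used in VQ-VAEs: the forward value is the
# quantized `latents_dense`, while the gradient bypasses quantization and
# flows into `targets_c` as if quantization were the identity. A generic
# sketch (hypothetical helper name):
def straight_through_sketch(continuous, quantized):
  # Forward: quantized. Backward: gradient w.r.t. `continuous`, unchanged.
  return continuous + tf.stop_gradient(quantized - continuous)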
def lstm_decoder_infer(self, inputs, sequence_length, hparams, clss, train, initial_state=None, bottleneck=None): # IN PREDICT MODE, RUN tf.while RNN max_decode_length = 51 batch_size = common_layers.shape_list(inputs)[0] zero_pad, logits_so_far = self.create_initial_input_for_decode(batch_size) layers = contrib_rnn.MultiRNNCell([ self.lstm_cell(hparams, train) for _ in range(hparams.num_hidden_layers) ]) if initial_state is None: raise Exception('initial state should be init from bottleneck!') # append one-hot class to bottleneck, which will be given per step clss = tf.reshape(clss, [-1]) if not hparams.use_cls: clss = tf.zeros_like(clss) if hparams.condition_on_sln: sln = tf.reshape(sequence_length, [-1]) bottleneck = tf.concat((bottleneck, tf.one_hot(clss, hparams.num_categories), tf.one_hot(sln, max_decode_length)), -1) else: bottleneck = tf.concat((bottleneck, tf.one_hot(clss, hparams.num_categories)), -1) def infer_step(logits_so_far, current_hidden): """Inference step of LSTM while loop.""" # unflatten hidden: current_hidden = tuple(tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1]) for s in current_hidden) # put logits_so_far through top tm = self._problem_hparams.modality['targets'] # need to reuse top params reset_scope = tf.variable_scope(tf.VariableScope(tf.AUTO_REUSE, ''), reuse=tf.AUTO_REUSE, auxiliary_name_scope=False) top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm), reuse=tf.AUTO_REUSE) with reset_scope, top_scope: samples_so_far = self.hparams.top['targets']( logits_so_far, None, self.hparams, self.problem_hparams.vocab_size) # append a zero pad to the samples. this effectively shifts the samples # right, but, unlike shift_right, by not removing the last element, we # allow an empty samples_so_far to not be empty after padding samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1) shifted_targets = common_layers.flatten4d3d(samples_so_far) # now take the very last one here, will be the actual input to the rnn shifted_targets = shifted_targets[:, -1:, :] # tile and append the bottleneck to inputs sln_offset = 0 if hparams.condition_on_sln: sln_offset = 51 pre_tile_y = tf.reshape( bottleneck, [common_layers.shape_list(bottleneck)[0], 1, hparams.bottleneck_bits + hparams.num_categories + sln_offset]) overlay_x = tf.tile(pre_tile_y, [1, common_layers.shape_list(shifted_targets)[1], 1]) inputs = tf.concat([shifted_targets, overlay_x], -1) seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]]) # RUN PRE-LSTM LAYER with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE): inputs = tf.layers.dense(inputs, hparams.hidden_size, name='bottom') inputs = tf.nn.tanh(inputs) # RUN LSTM with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE): next_step, next_state = tf.nn.dynamic_rnn( layers, inputs, seq_len_batch, initial_state=current_hidden, dtype=tf.float32, time_major=False) next_step = tf.expand_dims(next_step, [1]) logits_so_far = tf.concat([logits_so_far, next_step], 1) # flatten state next_state = tuple((s.c, s.h) for s in next_state) return logits_so_far, next_state def while_exit_cond(logits_so_far, unused_current_hidden): length = common_layers.shape_list(logits_so_far)[1] return length < max_decode_length # passing state must be flattened: initial_state = tuple([(s.c, s.h) for s in initial_state]) # actually run tf.while: logits, final_state = tf.while_loop( while_exit_cond, infer_step, [logits_so_far, initial_state], shape_invariants=[ tf.TensorShape([None, None, 1, hparams.hidden_size]), tuple([(s[0].get_shape(), s[1].get_shape()) for s 
in initial_state]), ], back_prop=False, parallel_iterations=1 ) # logits should be returned in 3d mode: logits = common_layers.flatten4d3d(logits) return logits, final_state
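# NOTE: the loop above, reduced to its essentials: an autoregressive
# tf.while_loop that grows the output along the time axis and threads the RNN
# state through. This sketch assumes a single-tensor state and a hypothetical
# `cell(input, state) -> (output, state)` callable; the real loop above also
# runs the modality top, re-embeds samples, and carries flattened
# LSTMStateTuples.
def greedy_rnn_decode_sketch(cell, state, start_token, max_len):
  # start_token: [batch, hidden]; state: [batch, state_dim].
  def cond(outputs, unused_state):
    return tf.shape(outputs)[1] < max_len

  def step(outputs, state):
    out, state = cell(outputs[:, -1, :], state)  # feed back the last output
    return tf.concat([outputs, out[:, None, :]], axis=1), state

  outputs = start_token[:, None, :]  # [batch, 1, hidden]
  return tf.while_loop(
      cond, step, [outputs, state],
      shape_invariants=[tf.TensorShape([None, None, None]),
                        state.get_shape()])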
def body(self, features): """R-Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "targets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id" Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ hparams = self._hparams if hparams.add_position_timing_signal: # Turning off addition of positional embedding in the encoder/decoder # preparation as we do it in the beginning of each step. hparams.pos = None if self.has_input: inputs = features["inputs"] target_space = features["target_space_id"] (encoder_output, encoder_decoder_attention_bias, enc_extra_output) = self.encode(inputs, target_space, hparams, features=features) else: (encoder_output, encoder_decoder_attention_bias, enc_extra_output) = (None, None, (None, None)) targets = features["targets"] targets = common_layers.flatten4d3d(targets) (decoder_input, decoder_self_attention_bias ) = transformer.transformer_prepare_decoder(targets, hparams, features=features) decoder_output, dec_extra_output = self.decode( decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=transformer.features_to_nonpadding(features, "targets")) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights, hparams.expected_attention_loss_type, hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0: if self.has_input: enc_ponder_times, enc_remainders = enc_extra_output enc_act_loss = ( hparams.act_loss_weight * tf.reduce_mean(enc_ponder_times + enc_remainders)) else: enc_act_loss = 0.0 (dec_ponder_times, dec_remainders) = dec_extra_output dec_act_loss = (hparams.act_loss_weight * tf.reduce_mean(dec_ponder_times + dec_remainders)) act_loss = enc_act_loss + dec_act_loss tf.contrib.summary.scalar("act_loss", act_loss) return decoder_output, {"act_loss": act_loss} return decoder_output
def decode_transformer(encoder_output, encoder_decoder_attention_bias,
                       targets, hparams, name, task=None, causal=True):
  """Original Transformer decoder."""
  orig_hparams = hparams
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      if not causal:
        decoder_self_bias *= 0.

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # Have to reshape targets as [batch, img_len, img_len,
      # num_channels * hidden_size] because otherwise prepare_image will
      # choke.
      targets = tf.reshape(targets, [
          tf.shape(targets)[0], hparams.img_len, hparams.img_len,
          hparams.num_channels * hparams.hidden_size
      ])

      # Prepare decoder inputs and bias.
      # TODO(nikip): Make prepare_decoder return bias.
      decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
      bias = None

      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(inputs, [
            common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size
        ])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          encoder_output=None,
          num_layers=hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams=hparams,
          self_attention_bias=bias,
          attention_type=hparams.dec_attention_type,
          name="decoder")
      decoder_output_shape = common_layers.shape_list(decoder_output)
      # Expand since t2t expects 4d tensors.
      decoder_output = tf.reshape(
          decoder_output,
          [decoder_output_shape[0], -1, 1, hparams.hidden_size])
  hparams = orig_hparams
  return decoder_output
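# NOTE: zeroing `decoder_self_bias` above (causal=False) removes the causal
# mask, making the decoder self-attention bidirectional. The mask being
# removed is the usual lower-triangular additive bias, roughly (hypothetical
# helper name, following the t2t bias convention):
def causal_attention_bias_sketch(length):
  lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0)
  # 0 where attention is allowed, -1e9 at future positions.
  return tf.reshape(-1e9 * (1.0 - lower_triangle), [1, 1, length, length])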
def body(self, features, original_features):
  """Transformer main model_fn.

  Args:
    features: Map of features to the model. Should contain the following:
        "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
        "targets": Target decoder outputs.
            [batch_size, decoder_length, hidden_dim]
        "target_space_id"
    original_features: Map of the original (unprocessed) features, used to
        encode the snippets and the question.

  Returns:
    Final decoder representation. [batch_size, decoder_length, hidden_dim]
  """
  hparams = self._hparams

  snippets = features.get(searchqa_problem.FeatureNames.SNIPPETS)
  questions = features.get(searchqa_problem.FeatureNames.QUESTION)
  target_space = features["target_space_id"]

  with tf.variable_scope('input'):
    # [batch_size, search_results_len, embed_sz]
    encoded_snippets = self.inputs_encoding(
        input=snippets,
        original_input=original_features.get(
            searchqa_problem.FeatureNames.SNIPPETS),
        initializer=tf.constant_initializer(1.0),
        scope='snippets_encoding')

    # [batch_size, 1, embed_sz]
    encoded_question = self.inputs_encoding(
        input=questions,
        original_input=original_features.get(
            searchqa_problem.FeatureNames.QUESTION),
        initializer=tf.constant_initializer(1.0),
        scope='question_encoding')

  # Concat snippets and questions to create the inputs.
  inputs = tf.concat([encoded_snippets, encoded_question], axis=1)
  # The input is 4D by default and it gets squeezed from 4D to 3D in the
  # encode function, so we need to make it 4D by inserting the channel dim.
  inputs = tf.expand_dims(inputs, axis=2)

  losses = []
  encoder_output, encoder_decoder_attention_bias = self.encode(
      inputs, target_space, hparams, features=features, losses=losses)

  targets = features["targets"]
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)

  decoder_input, decoder_self_attention_bias = (
      transformer.transformer_prepare_decoder(targets, hparams,
                                              features=features))

  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=features_to_nonpadding(features, "targets"),
      losses=losses)

  ret = tf.reshape(decoder_output, targets_shape)
  if losses:
    return ret, {"extra_loss": tf.add_n(losses)}
  else:
    return ret
def body(self, features): """Transformer main model_fn. Args: features: Map of features to the model. Should contain the following: "inputs": Transformer inputs [batch_size, input_length, hidden_dim] "targets": Target decoder outputs. [batch_size, decoder_length, hidden_dim] "target_space_id": A scalar int from data_generators.problem.SpaceID. Returns: Final decoder representation. [batch_size, decoder_length, hidden_dim] """ tf.logging.info("Using PgScratch BODY function.") hparams = self._hparams losses = {} inputs = features["inputs"] target_space = features["target_space_id"] # encoder_output: <tf.float32>[batch_size, input_length, hidden_dim] # encoder_decoder_attention_bias: <tf.float32>[batch_size, input_length] encoder_output, encoder_decoder_attention_bias = self.encode( inputs, target_space, hparams, features=features, losses=losses) with tf.variable_scope("knowledge"): with tf.name_scope("knowledge_encoding"): # Encode knowledge. # <tf.float32>[batch_size, triple_num, emb_dim] fact_embedding, fact_lengths = self.encode_knowledge_bottom( features) tf.logging.info("Encoded knowledge") with tf.name_scope("knowledge_selection_and_loss"): # Compute knowledge selection and loss. triple_logits, avg_triple_selection_loss, knowledge_encoder_output, transe_loss = self.compute_knowledge_selection_and_loss( features, encoder_output, fact_embedding, fact_lengths, hparams.margin, hparams.num_negative_samples) losses["kb_loss"] = avg_triple_selection_loss losses["transe_loss"] = transe_loss if hparams.attend_kb: tf.logging.info("ATTEND_KB is ACTIVE") with tf.name_scope("knowledge_attention"): knowledge_padding = tf.zeros_like(triple_logits, dtype=tf.float32) knowledge_attention_bias = common_attention.attention_bias_ignore_padding( knowledge_padding) encoder_output = tf.concat( [knowledge_encoder_output, encoder_output], 1) encoder_decoder_attention_bias = tf.concat( [knowledge_attention_bias, encoder_decoder_attention_bias], -1) else: tf.logging.info("ATTEND_KB is INACTIVE") targets = features["targets"] targets_shape = common_layers.shape_list(targets) targets = common_layers.flatten4d3d(targets) (decoder_input, decoder_self_attention_bias ) = transformer.transformer_prepare_decoder(targets, hparams, features=features) decode_kwargs = {} decoder_output = self.decode( decoder_input, encoder_output, encoder_decoder_attention_bias, decoder_self_attention_bias, hparams, nonpadding=transformer.features_to_nonpadding(features, "targets"), losses=losses, **decode_kwargs) expected_attentions = features.get("expected_attentions") if expected_attentions is not None: attention_loss = common_attention.encoder_decoder_attention_loss( expected_attentions, self.attention_weights, hparams.expected_attention_loss_type, hparams.expected_attention_loss_multiplier) return decoder_output, {"attention_loss": attention_loss} ret = tf.reshape(decoder_output, targets_shape) if losses: return ret, losses else: return ret
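# NOTE: `attention_bias_ignore_padding` above converts a {0, 1} padding
# indicator into the additive attention-bias convention used throughout this
# file: 0 for real positions, a large negative value for padded ones,
# broadcastable over heads and query positions. Roughly:
def attention_bias_from_padding_sketch(padding):
  # padding: [batch, memory_length] with 1.0 at padded positions.
  # Returns: [batch, 1, 1, memory_length].
  return tf.expand_dims(tf.expand_dims(padding * -1e9, axis=1), axis=1)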
def ae_transformer_internal(inputs, targets, target_space, hparams, beam_size, cache=None, predict_mask=1.0): """AE Transformer, main step used for training.""" # Summaries break with the do_refine cond, turn them off in that case. global _DO_SUMMARIES if hparams.do_refine: _DO_SUMMARIES = False # Prepare. orig_targets = targets batch_size = tf.shape(orig_targets)[0] targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size]) # Encoder. if inputs is not None: inputs = common_layers.flatten4d3d(inputs) inputs, ed = encode(inputs, target_space, hparams, "input_enc") else: ed = None # Autoencoding. losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)} if hparams.do_ae: max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1) targets, _ = common_layers.pad_to_same_length( targets, max_targets_len_from_inputs, final_length_divisible_by=2**hparams.num_compress_steps) targets_c = compress(targets, False, hparams, "compress") if hparams.mode != tf.estimator.ModeKeys.PREDICT: # Compress and bottleneck. t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048, "vc") if _DO_SUMMARIES: tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1])) pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95 pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0 cond = tf.less(tf.random_uniform([]), pc) t_c = tf.cond(cond, lambda: t_c, lambda: targets_c) losses["extra"] = vc_loss * tf.to_float(cond) # Extra loss predicting latent code from input. Discrete only. if hparams.bottleneck_kind not in ["dense", "vae"]: t_pred = decode_transformer(inputs, ed, tf.stop_gradient(t_c), hparams, "extra") t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits") losses[ "latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=t_bit, logits=t_pred) losses["latent_pred"] = tf.reduce_mean( losses["latent_pred"]) * 0.5 * tf.to_float(cond) else: if hparams.bottleneck_kind in ["dense", "vae"]: targets_rand = tf.random_uniform(tf.shape(targets_c)) t_c, _, _, _ = bottleneck(targets_rand, hparams, 2 * 2048, "vc") else: latent_len = tf.shape(targets_c)[1] _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc") t_c = tf.zeros_like(targets_c[:, :latent_len, :, :]) if cache is None: cache = ae_latent_sample(t_c, inputs, ed, embed, 8, hparams) cache = cache[0, :, :] cache = tf.reshape(cache, [1, latent_len, 1]) cache = tf.tile(cache, [beam_size, 1, 1]) t_c = embed(cache) # Postprocess. d = t_c pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size]) pos = pos[:, :tf.shape(t_c)[1] + 1, :, :] t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos # Masking. if hparams.do_mask: masking = common_layers.inverse_lin_decay(100000) masking *= common_layers.inverse_exp_decay( 25000) # Not much at start. 
if not hparams.do_refine: masking -= tf.random_uniform([]) * 0.3 masking = tf.minimum(tf.maximum(masking, 0.0), 1.0) if hparams.mode == tf.estimator.ModeKeys.PREDICT: masking = predict_mask mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1])) mask = tf.expand_dims(tf.to_float(mask), 3) for i in xrange(hparams.num_compress_steps): j = hparams.num_compress_steps - i - 1 d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j) d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j) targets = mask * targets + (1.0 - mask) * d targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1) res = decode_transformer(inputs, ed, targets, hparams, "decoder") if hparams.do_ae: res = res[:, tf.shape(t_c)[1]:, :, :] if hparams.do_mask and hparams.do_refine: def refine_res(): return residual_conv(res, 1, (5, 1), hparams, "refine") all_masked = tf.less(tf.reduce_sum(mask), 0.1) res = tf.cond(all_masked, refine_res, lambda: res) return res, losses, cache
def mel_perf_transformer_encode(encoder_function, perf_inputs, mel_inputs, target_space, hparams, attention_weights=None, features=None, losses=None, prepare_encoder_fn=None, **kwargs): """Encode transformer inputs. Used for melody & performance autoencoder. Performance is mean-aggregated across time and combined with melody in a variety of different ways. Args: encoder_function: the encoder function perf_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which will be flattened along the two spatial dimensions. mel_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which will be flattened along the two spatial dimensions. target_space: scalar, target space ID. hparams: hyperparameters for model. attention_weights: weight to store attention to. features: optionally pass the entire features dictionary as well. This is needed now for "packed" datasets. losses: optional list onto which to append extra training losses prepare_encoder_fn: optional, alternative to transformer_prepare_encoder. **kwargs: additional arguments to pass to encoder_function Returns: Tuple of: encoder_output: Encoder representation. [batch_size, input_length, hidden_dim] encoder_decoder_attention_bias: Bias and mask weights for encoder-decoder attention. [batch_size, input_length] """ perf_inputs = common_layers.flatten4d3d(perf_inputs) mel_inputs = common_layers.flatten4d3d(mel_inputs) if not prepare_encoder_fn: prepare_encoder_fn = transformer_prepare_encoder perf_encoder_input, perf_self_attention_bias, perf_encdec_attention_bias = ( prepare_encoder_fn(perf_inputs, target_space, hparams, features=features, reuse_target_embedding=tf.AUTO_REUSE)) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT, value=hparams.layer_prepostprocess_dropout, hparams=hparams) perf_encoder_input = tf.nn.dropout( perf_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) perf_attn_bias_for_padding = None # Otherwise the encoder will just use encoder_self_attention_bias. if hparams.unidirectional_encoder: perf_attn_bias_for_padding = perf_encdec_attention_bias # do the same thing for melody mel_encoder_input, mel_self_attention_bias, mel_encdec_attention_bias = ( prepare_encoder_fn(mel_inputs, target_space, hparams, features=features, reuse_target_embedding=tf.AUTO_REUSE)) mlperf_log.transformer_print( key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT, value=hparams.layer_prepostprocess_dropout, hparams=hparams) mel_encoder_input = tf.nn.dropout( mel_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) mel_attn_bias_for_padding = None # Otherwise the encoder will just use encoder_self_attention_bias. 
  if hparams.unidirectional_encoder:
    mel_attn_bias_for_padding = mel_encdec_attention_bias

  # Use the proper encoder function for performance and melody.
  perf_encoder_output = encoder_function(
      perf_encoder_input,
      perf_self_attention_bias,
      hparams,
      name="perf_encoder",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=perf_attn_bias_for_padding,
      **kwargs)

  # Same thing for melody.
  mel_encoder_output = encoder_function(
      mel_encoder_input,
      mel_self_attention_bias,
      hparams,
      name="mel_encoder",
      nonpadding=features_to_nonpadding(features, "inputs"),
      save_weights_to=attention_weights,
      make_image_summary=not common_layers.is_xla_compiled(),
      losses=losses,
      attn_bias_for_padding=mel_attn_bias_for_padding,
      **kwargs)

  # Compute the global mean vector/bias term of the performance encoding, to
  # be combined with the full melody encoding.
  perf_mean_vector = tf.math.reduce_mean(perf_encoder_output,
                                         axis=1, keep_dims=True)

  # Different methods of aggregating over the performance + melody vectors.
  if hparams.aggregation == "sum":
    # Add both mean performance and melody vectors together.
    perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                         axis=-1, keep_dims=True)
    encoder_output = mel_encoder_output + perf_mean_vector
    encoder_decoder_attention_bias = mel_encdec_attention_bias + perf_mean_bias
  elif hparams.aggregation == "concat":
    # Concatenate melody with the mean-aggregated performance embedding.
    stop_token = tf.zeros((1, 1, 384))
    encoder_output = tf.concat(
        [mel_encoder_output, stop_token, perf_mean_vector], axis=1)
    perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                         axis=-1, keep_dims=True)
    stop_bias = tf.zeros((1, 1, 1, 1))
    encoder_decoder_attention_bias = tf.concat(
        [mel_encdec_attention_bias, stop_bias, perf_mean_bias], axis=-1)
  elif hparams.aggregation == "tile":
    # Tile the performance embedding across each position of the melody
    # embedding.
    dynamic_val = tf.shape(mel_encoder_output)[1]
    shp = tf.convert_to_tensor([1, dynamic_val, 1], dtype=tf.int32)
    tiled_mean = tf.tile(perf_mean_vector, shp)
    encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)
    encoder_decoder_attention_bias = mel_encdec_attention_bias
  else:
    raise NotImplementedError(
        "aggregation method must be in [sum, concat, tile].")

  return encoder_output, encoder_decoder_attention_bias
def maybe_flatten4d3d(x): xshape = common_layers.shape_list(x) return common_layers.flatten4d3d(x) if len(xshape) == 4 else x
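# NOTE: for reference, `common_layers.flatten4d3d` merges the two middle axes:
# [batch, length, height, depth] -> [batch, length * height, depth] (see the
# testFlatten4D3D case above). A minimal re-implementation under that
# assumption:
def flatten4d3d_sketch(x):
  b, l, h, d = common_layers.shape_list(x)
  return tf.reshape(x, [b, l * h, d])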