def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
  """Learned 2d positional embedding for images."""
  mesh = targets.mesh
  hparams = self._hparams
  activation_dtype = self.set_activation_type()
  rows_dim = mtf.Dimension("rows", hparams.img_len)
  cols_dim = mtf.Dimension("cols", hparams.img_len * hparams.num_channels)
  full_shape = mtf.Shape([rows_dim, cols_dim, model_dim])

  def _axis_embedding(var_name, axis_dim):
    # One learned table per image axis: embed every position along the
    # axis, then broadcast across the full (rows, cols, model) shape.
    table = mtf.get_variable(
        mesh, var_name,
        mtf.Shape([max_length_dim, model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)
    axis_positions = mtf.range(mesh, axis_dim, dtype=tf.int32)
    return mtf.broadcast(
        mtf.gather(table, axis_positions, max_length_dim), full_shape)

  # Sum of a row-wise and a column-wise learned embedding.
  position_x = _axis_embedding("positional_emb_rows", rows_dim)
  position_y = _axis_embedding("positional_emb_cols", cols_dim)
  return position_x + position_y
def create_positional_emb_2d(self, targets):
  """Learned 2d positional embedding for images."""
  mesh = targets.mesh
  full_shape = mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim])

  def _axis_embedding(var_name, axis_dim):
    # One learned table per image axis: embed every position along the
    # axis, then broadcast across the full (rows, cols, model) shape.
    table = mtf.get_variable(
        mesh, var_name,
        mtf.Shape([self.max_length_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)
    axis_positions = mtf.range(mesh, axis_dim, dtype=tf.int32)
    return mtf.broadcast(
        mtf.gather(table, axis_positions, self.max_length_dim), full_shape)

  # Sum of a row-wise and a column-wise learned embedding.
  position_x = _axis_embedding("positional_emb_rows", self.rows_dim)
  position_y = _axis_embedding("positional_emb_cols", self.cols_dim)
  return position_x + position_y
def my_gather(tensor):
  """Gather along the beam axis, widening it to double_beam."""
  # Output shape: same as the input except the beam axis becomes
  # the doubled-beam axis.
  widened_dims = []
  for dim in tensor.shape.dims:
    widened_dims.append(double_beam if dim == beam_dim else dim)
  return mtf.gather(tensor, top_beam_index, beam_dim,
                    output_shape=mtf.Shape(widened_dims))
def gather(tensor, name):
  """Gather the selected beams, renaming old_beam_dim to beam_dim."""
  with tf.name_scope(prefix + name):
    # Output shape: same as the input except the old beam axis is
    # replaced by the (possibly resized) new beam axis.
    result_dims = []
    for dim in tensor.shape.dims:
      result_dims.append(beam_dim if dim == old_beam_dim else dim)
    return mtf.gather(tensor, topk_indices, old_beam_dim,
                      output_shape=mtf.Shape(result_dims))
def logits_fn(step_num, ids, states):
  """Produce logits for this step, and new states."""
  # The cached decoder states are packed as [k_0..k_{n-1}, v_0..v_{n-1}].
  num_layers = hparams.num_decoder_layers
  prev_keys = states[:num_layers]
  prev_values = states[num_layers:]
  # Embed the token produced at the previous step, plus the current
  # position's embedding.
  prev_token = mtf.gather(ids, step_num - 1, self.length_dim)
  token_emb = mtf.gather(
      targets_embedding_var, prev_token, self.targets_vocab_dim)
  pos_emb = mtf.gather(
      positional_embedding_var, step_num, self.max_length_dim)
  hidden = token_emb + pos_emb
  with tf.variable_scope("decoder"):
    hidden, next_keys, next_values = (
        self._decoder_layer_stack_incremental(
            hidden, step_num, encdec_tensors, prev_keys, prev_values,
            encdec_attention_mask=encoder_attention_mask))
  logits = mtf.matmul(hidden, softmax_var)
  return logits, next_keys + next_values
def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                 finished_scores, finished_in_finished, *unused_states):
  """Checking termination condition.

  We terminate when we decoded up to decode_length or the lowest scoring item
  in finished has a greater score than the highest prob item in alive divided
  by the max length penalty.

  Args:
    i: loop index
    unused_alive_seq: unused
    alive_log_probs: probabilities of the beams. [batch_size, beam_size]
    unused_finished_seq: unused
    finished_scores: scores for each of these sequences.
      [batch_size, beam_size]
    finished_in_finished: finished bools for each of these sequences.
      [batch_size, beam_size]
    *unused_states: unused

  Returns:
    Bool.
  """
  # TODO(noam): support a different decode length...
  # decode_length = mtf.constant(mesh, length_dim.size, dtype=tf.int32)

  # del alive_log_probs, finished_scores, finished_in_finished
  # return mtf.less(i, length_dim.size)
  if not stop_early:
    # Without early stopping, run for the full decode_length.
    return mtf.less(i, decode_length)
  # Length penalty at the maximum possible length (same formula as the
  # per-sequence penalty used when scoring finished sequences).
  max_length_penalty = mtf.pow(((5. + mtf.to_float(decode_length)) / 6.),
                               alpha)
  # The best possible score of the most likely alive sequence.
  lower_bound_alive_scores = mtf.gather(
      alive_log_probs, mtf.constant(mesh, 0, dtype=tf.int32),
      beam_dim) / max_length_penalty

  # Now to compute the lowest score of a finished sequence in finished.
  # If the sequence isn't finished, we multiply its score by 0. Since
  # scores are all -ve, taking the min will give us the score of the lowest
  # finished item.
  lowest_score_of_finished_in_finished = mtf.reduce_min(
      finished_scores * mtf.to_float(finished_in_finished),
      reduced_dim=beam_dim)
  # If none of the sequences have finished, then the min will be 0 and
  # we have to replace it by -ve INF if it is. The score of any seq in alive
  # will be much higher than -ve INF and the termination condition will not
  # be met.
  lowest_score_of_finished_in_finished += ((1. - mtf.to_float(
      mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim))) * -INF)
  # Stop only when EVERY batch element's worst finished beam already beats
  # its best possible alive beam.
  bound_is_met = mtf.reduce_all(
      mtf.greater(lowest_score_of_finished_in_finished,
                  lower_bound_alive_scores))
  return mtf.logical_and(mtf.less(i, decode_length),
                         mtf.logical_not(bound_is_met))
def body_fn(step_num, ids, *states):
  """Body function for greedy decoding.

  Args:
    step_num: a mtf.Tensor
    ids: a mtf.Tensor
    *states: additional mtf.Tensors
  Returns:
    new_step_num, new_ids, *new_states
  """
  logits, updated_states = logits_fn(step_num, ids, states)
  vocab_dim = logits.shape.dims[-1]
  sampled = mtf.sample_with_temperature(logits, vocab_dim, temperature)
  if forced_ids is not None:
    # Where the partial targets specify a nonzero token, use it;
    # elsewhere keep the sampled token.
    forced = mtf.gather(forced_ids, step_num, length_dim)
    keep_sample = mtf.to_int32(mtf.equal(forced, 0))
    sampled = forced + sampled * keep_sample
  # Write the chosen token into position step_num of the output ids.
  ids += sampled * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
  return [step_num + 1, ids] + updated_states
def _sample(self, features, mesh):
  """Autoregressive decoding (greedy/sampling or beam search).

  Runs the encoder (if the model has an input) once, precomputes the
  encoder-decoder attention keys/values for every decoder layer, then
  decodes step by step with an incremental self-attention cache.
  Greedy/temperature decoding is used when hparams.beam_size == 1;
  otherwise beam search.

  Args:
    features: a dict of tf.Tensors; reads "inputs" and/or "targets".
    mesh: a mtf.Mesh.

  Returns:
    an mtf.Tensor of decoded ids.
  """
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if self.has_input:
    inputs = features["inputs"]
    # Squeeze away trailing singleton axes until inputs is 2-D.
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Pad batch and length up to the fixed hparams sizes.
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    # Token embeddings plus positional embeddings (the positional table is
    # reshaped to (length, model) so addition broadcasts over the batch).
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf_layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute per-layer enc-dec attention k/v from the encoder output so
    # the incremental decoding loop does not recompute them every step.
    encdec_tensors = []
    for layer_num in xrange(hparams.num_decoder_layers):
      with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
        q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
            mesh, self.heads_dim, self.model_dim,
            self.kv_dim, self.activation_dtype)
        k = mtf.einsum(
            [encoder_output, k_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
        v = mtf.einsum(
            [encoder_output, v_var],
            mtf.Shape(
                [self.batch_dim, self.heads_dim,
                 self.memory_length_dim, self.kv_dim]))
      encdec_tensors.append((q_var, o_var, k, v))
    partial_targets = None
  else:
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      # Pad batch and length up to the fixed hparams sizes.
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  # Shapes for the running ids and the cached self-attention k/v;
  # beam search adds a beam dimension to both.
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
    kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  # One key tensor and one value tensor per decoder layer, all zeros.
  initial_kv_states = (
      [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] *
      (2 * hparams.num_decoder_layers))

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    # states packs the per-layer caches as [k_0..k_{n-1}, v_0..v_{n-1}].
    self_attention_k = states[:hparams.num_decoder_layers]
    self_attention_v = states[hparams.num_decoder_layers:]
    # Embed the previously produced token plus the current position.
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_self_attention_k, new_self_attention_v = (
          self._decoder_layer_stack_incremental(
              x,
              step_num,
              encdec_tensors,
              self_attention_k,
              self_attention_v,
              encdec_attention_mask=encoder_attention_mask))
    logits = mtf.matmul(x, softmax_var)
    return logits, new_self_attention_k + new_self_attention_v

  if hparams.beam_size == 1:
    # Temperature 0.0 corresponds to argmax sampling.
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf_beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_kv_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if self.has_input:
      # Derive decode length from the longest (nonzero-token) input in the
      # batch: multiplier * max_input_length + constant.
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf_beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_kv_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu)
    # Return only the top-scoring beam.
    return mtf.gather(
        beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def _mtf_model_fn(self, features, mesh):
  """Build the training/eval graph: embeddings, encoder (optional), decoder.

  Args:
    features: a dict of tf.Tensors; reads "targets", optionally "inputs" and
      the packed-dataset keys ("*_segmentation", "*_position").
    mesh: a mtf.Mesh.

  Returns:
    (logits, loss): mtf.Tensors; loss is label-smoothed cross-entropy
    averaged with nonzero-target weighting, plus any extra layer losses.
  """
  features = copy.copy(features)
  hparams = self._hparams
  targets = tf.to_int32(features["targets"])
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])
  # pad targets to max_length
  def pad_to_max_length(x):
    # Right-pad the length axis to hparams.max_length and fix the static
    # shape to (batch_size, max_length).
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x
  targets = pad_to_max_length(targets)
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])
  # Decoder input: targets shifted right by one position.
  shifted_targets = common_layers.shift_right_2d(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  shifted_targets = self._import_to_batch_by_length(
      shifted_targets, "shifted_targets", mesh, hparams)
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position", mesh, hparams)
    # Autoregressive mask plus a same-segment mask so packed examples
    # cannot attend across example boundaries.
    decoder_self_attention_mask = (
        mtf_layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype) +
        mtf_layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype))
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
        targets_position, dtype=self.activation_dtype)
  def layer_prepostprocess_dropout(x):
    # Dropout with noise broadcast over the length axis (noise shape only
    # covers batch and model dims).
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))
  extra_losses = []
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if self.has_input:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      # NOTE(review): this branch uses targets_segmentation, which is only
      # defined when "targets_segmentation" is also in features — assumes
      # packed datasets always provide both; confirm against the input
      # pipeline.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position", mesh, hparams)
      encoder_self_attention_mask = (
          mtf_layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
      encoder_decoder_attention_mask = (
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      encoder_decoder_attention_mask = encoder_self_attention_mask
    # Encoder input: token embeddings plus positional embeddings.
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.num_encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)
    # Rename length -> memory_length so the decoder can attend over it.
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
  else:
    encoder_output = None
    encoder_decoder_attention_mask = None
  # DECODER
  # Decoder input: shifted-target embeddings plus positional embeddings.
  x = (mtf.gather(
      targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
       mtf.gather(
           positional_embedding_var, targets_position, self.max_length_dim))
  x = layer_prepostprocess_dropout(x)
  # Decoder
  with tf.variable_scope("decoder"):
    x = self._layer_stack(
        x,
        hparams.num_decoder_layers,
        encoder_output=encoder_output,
        self_attention_mask=decoder_self_attention_mask,
        encdec_attention_mask=encoder_decoder_attention_mask,
        losses=extra_losses)
  logits = mtf.matmul(x, softmax_var)
  # Label smoothing: spread hparams.label_smoothing mass uniformly over the
  # vocabulary, keeping the rest on the true label.
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
      dtype=self.activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  # Ignore padding positions (target id 0) when averaging the loss.
  weights = mtf_layers.weights_nonzero(
      targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  return logits, loss
def mtf_model_fn(self, features, mesh):
  """Build the Image Transformer decoder graph.

  Reshapes the image targets into a 1-D sequence, embeds them with a 2-D
  positional embedding, runs a stack of local-attention decoder layers, and
  returns per-position logits over 256 output values plus the mean
  cross-entropy loss.

  Args:
    features: a dict of tf.Tensors; reads "targets" and optionally "inputs".
    mesh: a mtf.Mesh.

  Returns:
    (logits, loss): mtf.Tensors.
  """
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()

  # We assume fixed vocab size for targets
  # NOTE(review): this value is unused — it is overwritten below by
  # 256 * hparams.num_channels; the attribute read appears to be dead code.
  targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
  targets = tf.to_int32(features["targets"])

  # Image preprocessing, reshape into a 1D sequence and shift right.
  length = hparams.img_len * hparams.img_len * hparams.num_channels
  targets = tf.reshape(targets, [hparams.batch_size, length])
  shifted_targets = common_layers.shift_right_2d(targets)

  # Declare all the dimensions
  model_dim = mtf.Dimension("d_model", hparams.hidden_size)
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  length_dim = mtf.Dimension("length", length)
  max_length_dim = mtf.Dimension("max_length", hparams.max_length)
  filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
  kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
  heads = mtf.Dimension("heads", hparams.num_heads)

  def import_to_batch_by_length(x, name):
    # Lift a (batch, length) tf.Tensor onto the mesh.
    return mtf.import_tf_tensor(
        mesh, x, mtf.Shape([batch_dim, length_dim]), name=name)

  def layer_prepostprocess_dropout(x):
    # Dropout with noise broadcast over the length axis.
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape([batch_dim, model_dim]))

  targets = import_to_batch_by_length(targets, "targets")
  shifted_targets = import_to_batch_by_length(
      shifted_targets, "shifted_targets")
  # NOTE(review): nothing in this function appends to extra_losses; the
  # accumulation loop at the bottom is currently a no-op.
  extra_losses = []

  # Create targets content and position embeddings.
  targets_vocab_size = 256 * hparams.num_channels
  targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
  outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

  # Create embedding var for targets and positions and do a gather.
  targets_embedding_var = mtf.get_variable(
      mesh, "targets_embedding",
      mtf.Shape([targets_vocab_dim, model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)
  x = mtf.gather(targets_embedding_var, shifted_targets, targets_vocab_dim)

  # Add positional embeddings
  # The 2d (rows, cols, model) embedding is flattened to (length, model) so
  # it broadcasts over the batch when added.
  x += mtf.reshape(
      self.create_positional_emb_2d(targets, max_length_dim, model_dim),
      [length_dim, model_dim])

  # If conditional and input is given, add the input embedding to the target.
  # TODO(nikip): Verify conditional.
  if self.has_input and not hparams.unconditional:
    vocab_size = hparams.num_classes
    inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = import_to_batch_by_length(inputs, "inputs")

    # Input embeddings
    inputs_embedding_var = mtf_layers.embedding(
        mesh, "input_embedding",
        mtf.Shape([inputs_vocab_dim, model_dim]),
        activation_dtype=activation_dtype)
    inputs_emb = mtf.gather(
        inputs_embedding_var, inputs, inputs_vocab_dim)
    x += inputs_emb

  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      # (second positional argument — presumably a memory antecedent — is
      # None here; self-attention only. TODO confirm against
      # masked_local_attention_1d's signature.)
      x += layer_prepostprocess_dropout(
          mtf_layers.masked_local_attention_1d(
              mtf_layers.layer_norm(x, model_dim, name="layer_norm_self_att"),
              None,
              kv_channels,
              heads,
              block_length=hparams.block_length,
              name="self_att"))
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf_layers.dense_relu_dense(
              mtf_layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              filter_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]))

  x = mtf_layers.layer_norm(x, model_dim, name="decoder_final_layer_norm")

  # Calculate the logits and loss.
  logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
  soft_targets = mtf.one_hot(
      targets, outputs_vocab_dim, dtype=activation_dtype)
  loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, outputs_vocab_dim)
  loss = mtf.reduce_mean(loss)
  for l in extra_losses:
    loss += l
  return logits, loss