def ids_to_embedding(self, ids):
    """Look up embeddings for `ids` along the vocabulary dimension.

    When the table is factorized, rows of `self._factor1` are gathered and
    then contracted with `self._factor2` over `self._inner_dim`. Otherwise
    rows of `self._embedding_weights` are gathered directly.

    Args:
      ids: the indices handed to `mtf.gather`.

    Returns:
      The gathered embedding (non-factorized case) or the einsum of the
      gathered first factor with the second factor (factorized case).
    """
    if not self._is_factorized:
        return mtf.gather(self._embedding_weights, ids, self._vocab_dim)

    gathered_factor = mtf.gather(self._factor1, ids, self._vocab_dim)
    return mtf.einsum(
        [gathered_factor, self._factor2],
        reduced_dims=[self._inner_dim])
def create_positional_emb_2d(self, targets):
    """Learned 2d positional embedding for images.

    Creates two variables, "positional_emb_rows" and "positional_emb_cols",
    each shaped [pos_dim, model_dim]. The row table is indexed with a range
    over `rows_dim` and the column table with a range over `cols_dim`; both
    lookups are broadcast to [rows_dim, cols_dim, model_dim] and added.

    Args:
      targets: a tensor; only its `mesh` is used here.

    Returns:
      The sum of the broadcast row and column embeddings.
    """
    mesh = targets.mesh
    table_shape = mtf.Shape([self.pos_dim, self.model_dim])
    grid_shape = mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim])

    def make_table(name):
        # Both tables share shape, initializer and activation dtype.
        return mtf.get_variable(
            mesh, name, table_shape,
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)

    # Creation order (rows, then cols) is kept as in the original.
    rows_table = make_table("positional_emb_rows")
    cols_table = make_table("positional_emb_cols")

    row_positions = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
    col_positions = mtf.range(mesh, self.cols_dim, dtype=tf.int32)

    row_emb = mtf.broadcast(
        mtf.gather(rows_table, row_positions, self.pos_dim), grid_shape)
    col_emb = mtf.broadcast(
        mtf.gather(cols_table, col_positions, self.pos_dim), grid_shape)
    return row_emb + col_emb
def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
    """Build the wide-and-deep graph and return its output and deep table.

    The deep path gathers from a "deep_embedding" table of width `embed_dim`;
    the wide path uses a "wide_embedding" lookup of width 1. Both are passed
    through the sibling `deep` / `wide` functions with a mask derived from
    `wt_hldr`, added, reshaped to [first dim of the wide output, outdim] and
    summed over `outdim`.

    Args:
      id_hldr: index tensor used for both embedding lookups.
      wt_hldr: weight tensor; reshaped to its first two dims plus a size-1
        "expend" dimension to form the mask.
      vocab_dim: vocabulary dimension of both embedding tables.
      embed_dim: output dimension of the deep embedding table.
      outdim: dimension the summed result is reshaped to and reduced over.
      float16: when truthy it is used as the variable dtype; otherwise an
        all-float32 `mtf.VariableDType` is used.

    Returns:
      A tuple `(result, embedding_table)` where `embedding_table` is the deep
      embedding weights.
    """
    input_fmt = "[input tensor] (name,shape):({},{})"
    output_fmt = "[output tensor] (name,shape):({},{})"

    logger.debug(input_fmt.format(id_hldr.name, id_hldr.shape))
    logger.debug(input_fmt.format(wt_hldr.name, wt_hldr.shape))

    # Single dtype selection shared by the deep and wide embeddings.
    if float16:
        variable_dtype = float16
    else:
        variable_dtype = mtf.VariableDType(tf.float32, tf.float32, tf.float32)

    # Deep path: explicit table + gather so the table can be returned.
    embedding_table = mtf.layers.embedding_weights(
        id_hldr.mesh, vocab_dim, embed_dim, variable_dtype, "deep_embedding")
    deep_output = mtf.gather(embedding_table, id_hldr, vocab_dim)
    logger.debug(output_fmt.format(deep_output.name, deep_output.shape))

    expend_dim = mtf.Dimension('expend', size=1)
    embed_dim_one = mtf.Dimension('embed_dim_one', size=1)

    wt_dims = wt_hldr.shape.dims
    mask = mtf.reshape(
        wt_hldr,
        new_shape=[wt_dims[0], wt_dims[1], expend_dim],
        name='mask_reshape')
    logger.debug(output_fmt.format(mask.name, mask.shape))

    # Wide path: width-1 embedding lookup.
    wide_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim_one,
        variable_dtype=variable_dtype,
        name="wide_embedding")
    logger.debug(output_fmt.format(wide_output.name, wide_output.shape))

    wide_output = wide(wide_output, mask=mask, float16=float16)
    deep_output = deep(deep_output, mask=mask, float16=float16)

    result = mtf.add(wide_output, deep_output)
    result = mtf.reshape(
        result,
        new_shape=[wide_output.shape.dims[0], outdim],
        name='result_reshape')
    result = mtf.reduce_sum(result, reduced_dim=outdim)
    logger.debug(output_fmt.format(result.name, result.shape))
    return result, embedding_table
def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states.

    Closes over `self`, `hparams`, the embedding/softmax variables and the
    encoder tensors of the enclosing scope.

    Args:
      step_num: current decode position; `step_num - 1` selects the ids fed
        at this step.
      ids: tensor of ids, indexed along `self.length_dim`.
      states: states forwarded to `self._layer_stack`.

    Returns:
      A tuple `(logits, new_states)`.
    """
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    token_emb = mtf.gather(
        targets_embedding_var, ids_this_step, self.targets_vocab_dim)
    position_emb = mtf.gather(
        positional_embedding_var, step_num, self.max_length_dim)
    x = token_emb + position_emb

    with tf.variable_scope("decoder"):
        x, new_states = self._layer_stack(
            x,
            hparams.decoder_layers,
            encdec_attention_mask=encoder_attention_mask,
            step_num=step_num,
            encdec_tensors=encdec_tensors,
            states=states)

    return mtf.matmul(x, softmax_var), new_states
def forward(self, features, return_loss=True, return_logits=False):
    """Run the model on `features["tokens"]` and optionally compute the loss.

    Args:
      features: mapping that must contain the key "tokens".
      return_loss: when falsy, only the logits are returned.
      return_logits: when truthy (and the loss is computed), the logits are
        cast to `self.variable_dtype.master_dtype` and returned as well.

    Returns:
      `logits` if `return_loss` is falsy; otherwise `(loss, loss_batch)`, or
      `(loss, loss_batch, logits)` when `return_logits` is truthy.
    """
    inputs = features["tokens"]
    embedded = self.embedding(inputs, "embedding")
    tokens = self.positional_embedding(embedded, "positional_embedding")

    mask = self.get_attn_mask(
        tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"])
    out = self.transformer(tokens, mask=mask)
    logits = self.to_logits(out)

    if not return_loss:
        return logits

    # Labels: pad the inputs with the EOS id via `pad(..., [0, 1], ...)`, then
    # pick indices 1..size-1 so that each position is paired with the
    # following token.
    labels = pad(inputs, [0, 1], dim_name="total_seq_dim",
                 pad_value=self.eos_token_id)
    range_dim = mtf.Dimension("range", labels.shape[1].size - 1)
    indices = mtf.range(
        labels.mesh, range_dim, tf.int32, name="labels_indices") + 1
    labels = mtf.gather(labels, indices, dim=labels.shape[1])
    labels = mtf.rename_dimension(labels, "range", "total_seq_dim")

    loss, loss_batch = self._loss(logits, labels)

    # `return_loss` is necessarily truthy here, so only `return_logits`
    # decides the shape of the result.
    if not return_logits:
        return loss, loss_batch

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, self.variable_dtype.master_dtype)
    return loss, loss_batch, logits
def body_fn(position, ids, *states):
    """One step in the decode loop.

    Closes over the enclosing scope (`self`, `inputs`, `batch_dims`,
    `length_dim`, `temperature`, encoder tensors, ...).

    Args:
      position: current decode position.
      ids: tensor of ids generated so far, indexed along `length_dim`.
      *states: incremental states handed to the `Context`.

    Returns:
      A list `[new_position, new_ids] + new_states`.
    """
    step_context = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims,
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        autoregressive=self.autoregressive,
        position=position,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    # The id at position - 1 is the input for this step.
    inputs_this_step = mtf.gather(ids, position - 1, length_dim)
    with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(step_context, inputs_this_step)

    sampled_ids = mtf.sample_with_temperature(
        logits, self.output_vocab_dim, temperature)

    new_position = position + 1
    # Write the sampled id into slot `position` via a one-hot mask.
    position_mask = mtf.one_hot(position, length_dim, dtype=tf.int32)
    new_ids = ids + sampled_ids * position_mask
    return [new_position, new_ids] + step_context.new_states
def get_layer_timing_signal_learned_1d(self, context, channels, layer,
                                       num_layers):
    """Return the learned embedding row for one layer of the tower.

    A "layer_embedding" variable of shape [layer, channels] is created with
    a normal initializer of stddev `channels.size**-0.5`, multiplied by
    `channels.size**0.5`, and indexed at `layer`.

    Args:
      context: mtf context; supplies `mesh` and `variable_dtype`.
      channels: dimension of the timing signal.
      layer: layer num used as the gather index.
      num_layers: total number of layers (size of the "layer" dimension).

    Returns:
      a Tensor of timing signals [channels].
    """
    layer_dim = mtf.Dimension("layer", num_layers)
    table_shape = mtf.Shape([layer_dim, channels])
    init_stddev = channels.size**-0.5

    table = mtf.get_variable(
        context.mesh,
        "layer_embedding",
        table_shape,
        dtype=context.variable_dtype,
        initializer=tf.random_normal_initializer(0, init_stddev))
    scaled_table = table * (channels.size**0.5)
    return mtf.gather(scaled_table, layer, layer_dim)
def get_indices(self, keys: mtf.Tensor,
                query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
    """Generate score and indices for the query.

    Scores each query against the keys, keeps the top `self.knn` per
    sub-key set, forms all knn x knn sums of the two halves together with
    the combined index `index1 * self.n_keys + index2`, and finally keeps
    the top `self.knn` of those combinations.

    Args:
      keys: key tensor; its dimension at position 2 becomes the last score
        dimension.
      query: query tensor; all but its last dimension are kept in the score.

    Returns:
      A tuple `(scores, indices)` for the selected combinations.
    """
    score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
    sub_scores = mtf.einsum(
        [query, keys], output_shape=score_shape)  # [b, l, h, 2, n_keys]

    knn_dim = mtf.Dimension("knn", self.knn)
    sub_scores, sub_indices = mtf.top_k(
        sub_scores, score_shape.dims[-1], knn_dim)  # [b, l, h, 2, knn]

    # Computes the top cartesian products and their indices
    knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)

    first_scores, second_scores = mtf.unstack(
        sub_scores, sub_scores.shape.dims[-2])
    second_scores = mtf.rename_dimension(second_scores, "knn", "knn2")
    out_shape = mtf.Shape(
        first_scores.shape.dims + second_scores.shape.dims[-1:])
    pair_scores = mtf.add(first_scores, second_scores, output_shape=out_shape)
    pair_scores = mtf.replace_dimensions(
        pair_scores, out_shape[-2:], knn_square_dim)

    first_indices, second_indices = mtf.unstack(
        sub_indices, sub_indices.shape.dims[-2])
    first_indices = mtf.multiply(first_indices, self.n_keys)
    second_indices = mtf.rename_dimension(second_indices, "knn", "knn2")
    pair_indices = mtf.add(
        first_indices, second_indices, output_shape=out_shape)
    pair_indices = mtf.replace_dimensions(
        pair_indices, out_shape[-2:], knn_square_dim)

    best_scores, best_positions = mtf.top_k(
        pair_scores, pair_scores.shape.dims[-1], knn_dim)
    best_indices = mtf.gather(pair_indices, best_positions, knn_square_dim)
    return best_scores, best_indices
def logits_fn(step_num, ids, states):
    """logits_fn for mtf.beam_search.beam_search().

    Closes over the enclosing scope (`self`, `inputs`, `batch_dims`,
    `beam_dim`, `length_dim`, encoder tensors, ...).

    Args:
      step_num: current decode position.
      ids: tensor of ids, indexed along `length_dim`.
      states: incremental states handed to the `Context`.

    Returns:
      A tuple `(float logits, new_states)`.
    """
    step_context = Context(
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        model_dim=self.model_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        autoregressive=self.autoregressive,
        position=step_num,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        layout=self.layout,
        mesh_shape=self.mesh_shape,
        encoder_layer_outputs=encoder_layer_outputs)

    # The id at step_num - 1 is the input for this step.
    inputs_this_step = mtf.gather(ids, step_num - 1, length_dim)
    with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(step_context, inputs_this_step)

    float_logits = mtf.to_float(logits)
    return float_logits, step_context.new_states
def add_tokens_embeddings(self, tokens):
    """Embed `tokens` with a "tokens_embeddings" table.

    The table of shape [vocab_dim, model_dim] is created under the
    "embeddings" variable scope and stored on `self.embedding_table`; the
    lookup result is stored on `self.tokens_embedding_output`.

    Args:
      tokens: token ids; the original author's note gives the shape as
        [batch_size, seq_len, 1].

    Returns:
      The gathered token embeddings.
    """
    with tf.variable_scope("embeddings"):
        # Perform embedding lookup on the token ids.
        table = self.get_variable(
            "tokens_embeddings",
            [self.vocab_dim, self.model_dim],
            initializer=self.embedding_initializer,
        )
        self.embedding_table = table
        self.tokens_embedding_output = mtf.gather(
            table, tokens, self.vocab_dim)
    return self.tokens_embedding_output
def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states.

    Closes over `self`, `hparams`, the embedding/softmax variables and the
    encoder tensors of the enclosing scope.

    Args:
      step_num: current decode position.
      ids: tensor of ids, indexed along `self.length_dim`.
      states: list whose first `hparams.num_decoder_layers` entries are the
        self-attention keys and whose remainder are the values.

    Returns:
      A tuple `(logits, new_keys + new_values)`.
    """
    num_layers = hparams.num_decoder_layers
    self_attention_k = states[:num_layers]
    self_attention_v = states[num_layers:]

    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    token_emb = mtf.gather(
        targets_embedding_var, ids_this_step, self.targets_vocab_dim)
    position_emb = mtf.gather(
        positional_embedding_var, step_num, self.max_length_dim)
    x = token_emb + position_emb

    with tf.variable_scope("decoder"):
        decoder_result = self._decoder_layer_stack_incremental(
            x,
            step_num,
            encdec_tensors,
            self_attention_k,
            self_attention_v,
            encdec_attention_mask=encoder_attention_mask)
        x, new_self_attention_k, new_self_attention_v = decoder_result

    logits = mtf.matmul(x, softmax_var)
    return logits, new_self_attention_k + new_self_attention_v
def transformer_lm(x, *, dim, num_tokens, depth, max_seq_len, dim_head,
                   dim_features_head, causal=False):
    """Build a transformer language model over token ids `x`.

    Token embeddings ("wte") and learned positional embeddings ("wpe") are
    summed, passed through the sibling `transformer` and projected to
    vocabulary logits by the sibling `linear` under scope 'to_logits'.

    Fix relative to the previous version: "wpe" used to be declared over the
    sequence dimension of `x` while the lookup was done along the
    'max_seq_len' dimension, which was not a dimension of the table at all.
    The table is now declared as [max_seq_len, dim] so that the gather is a
    real positional lookup. NOTE(review): this changes the shape of "wpe"
    whenever the sequence length of `x` differs from `max_seq_len` - confirm
    against any existing checkpoints.

    Args:
      x: token id tensor with exactly two dimensions (the second one is used
        as the sequence dimension).
      dim: size of the model dimension.
      num_tokens: vocabulary size.
      depth: forwarded to `transformer`.
      max_seq_len: number of rows in the positional embedding table.
      dim_head: size of the 'dim_head' dimension forwarded to `transformer`.
      dim_features_head: size of the 'dim_features_head' dimension forwarded
        to `transformer`.
      causal: forwarded to `transformer`.

    Returns:
      The logits tensor produced by `linear`.
    """
    mesh = x.mesh
    # Same two-dimension requirement as before; the batch dim is not needed.
    _, seq_dim = x.shape.dims

    dim = mtf.Dimension('dim', dim)
    dim_head = mtf.Dimension('dim_head', dim_head)
    dim_features_head = mtf.Dimension('dim_features_head', dim_features_head)
    dim_num_tokens = mtf.Dimension('vocab_size', num_tokens)
    dim_max_seq_len = mtf.Dimension('max_seq_len', max_seq_len)

    wte = mtf.get_variable(
        mesh, name='wte', shape=mtf.Shape([dim_num_tokens, dim]),
        dtype=tf.float32)
    # The positional table is indexed by absolute position, so its leading
    # dimension must be the one the gather below reduces over.
    wpe = mtf.get_variable(
        mesh, name='wpe', shape=mtf.Shape([dim_max_seq_len, dim]),
        dtype=tf.float32)

    x = mtf.gather(wte, x, dim_num_tokens)
    positions = mtf.range(mesh, seq_dim, dtype=tf.int32)
    p = mtf.gather(wpe, positions, dim_max_seq_len)
    x = x + p

    x = transformer(
        x, depth=depth, dim_head=dim_head,
        dim_features_head=dim_features_head, causal=causal)

    logits = linear(x, dim_num_tokens, scope='to_logits')
    return logits
def logits_fn(step_num, ids, states):
    """logits_fn for mtf.beam_search.beam_search().

    Closes over the enclosing scope (`self`, `inputs`, `attributes`, `z`,
    `batch_dims`, `beam_dim`, `length_dim`, encoder tensors, ...).

    Args:
      step_num: current decode position; also used as `read_priority`.
      ids: tensor of ids, indexed along `length_dim`.
      states: incremental states handed to the `Context`.

    Returns:
      A tuple `(float logits, new_states)`.
    """
    previous_position = step_num - 1
    inputs_this_step = mtf.gather(ids, previous_position, length_dim)

    # Attributes are only looked up when attribute embedding is enabled.
    attributes_this_step = None
    if self.attribute_embedding:
        attributes_this_step = mtf.gather(
            attributes, step_num - 1, length_dim)

    step_context = Context(
        model=self,
        mesh=inputs.mesh,
        batch_dims=batch_dims + [beam_dim],
        length_dim=length_dim,
        variable_dtype=variable_dtype,
        mode="incremental",
        position=step_num,
        states=states,
        new_states=[],
        sequence_id=sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        constant_states=constant_states,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs,
        write_priority=write_priority,
        read_priority=step_num,
        inputs=inputs_this_step,
        encoder_inputs=encoder_inputs)

    with tf.variable_scope(self.name, reuse=True):
        logits = self._call_internal(
            step_context,
            inputs_this_step,
            attributes=attributes_this_step,
            z=z)

    return mtf.to_float(logits), step_context.new_states
def call(self, context, x: mtf.Tensor) -> mtf.Tensor:
    """Call the layer.

    Creates the "keys" variable and the "values" embedding table, projects
    `x` to a layer-normalized query, retrieves scores and value indices via
    `self.get_indices`, and returns the softmax-weighted sum of the selected
    values.

    Args:
      context: supplies `mesh` and `variable_dtype`.
      x: input tensor; its last dimension is both the reduced dimension of
        the query projection and the output dimension of the value table.

    Returns:
      The weighted combination of retrieved values.
    """
    mesh = context.mesh
    variable_dtype = context.variable_dtype

    # Initialize Memory Keys and Values
    n_key_dim = mtf.Dimension("n_keys", self.n_keys)
    n_value_dim = mtf.Dimension("n_values", self.n_values)
    key_dim = mtf.Dimension("key", self.key_size // 2)
    value_dim = x.shape.dims[-1]
    head_dim = mtf.Dimension("n_heads", self.n_heads)
    product_dim = mtf.Dimension("product_key", 2)

    keys = mtf.get_variable(
        mesh,
        name="keys",
        shape=mtf.Shape([head_dim, product_dim, n_key_dim, key_dim]),
        dtype=variable_dtype)
    values = mtf.layers.embedding_weights(
        mesh,
        vocab_dim=n_value_dim,
        output_dim=value_dim,
        variable_dtype=variable_dtype,
        name="values")

    # Compute query
    query = mtf.layers.dense(
        x,
        [head_dim, product_dim, key_dim],
        reduced_dims=x.shape.dims[-1:],
        activation=None,
        use_bias=True,
        variable_dtype=variable_dtype,
        name="query")  # [b, l, h, 2, k]

    # Note: We use layer norm instead of batch norm to normalize queries.
    # The main advantage is that layer norm works well with the codebase
    # whereas the implementation of batch norm requires handling of tf ops.
    query = mtf.layers.layer_norm(query, query.shape.dims[-1])

    # Retrieve indices and scores
    scores, indices = self.get_indices(keys, query)  # [b, l, h, k]
    scores = mtf.softmax(scores, reduced_dim=scores.shape.dims[-1])

    top_values = mtf.gather(values, indices, n_value_dim)  # [b, l, h, k, v]
    return mtf.einsum(
        [top_values, scores],
        reduced_dims=scores.shape.dims[-2:])  # [b, l, v]
def embedding(self, x, name):
    """Embed token ids `x` with a "wte" table created under scope `name`.

    The table has shape [final_vocab_dim, embed_dim] (both taken from
    `self.dimensions`). Dropout named "wte_dropout" is applied when the
    "embed_dropout" param is positive and `self.mode` is "train".

    Args:
      x: token id tensor; its mesh is used for the variable.
      name: variable scope name.

    Returns:
      The (optionally dropped-out) token embeddings.
    """
    embd_dim = self.dimensions["embed_dim"]
    vocab_dim = self.dimensions["final_vocab_dim"]
    dtypes = self.variable_dtype

    with tf.variable_scope(name):
        wte = mtf.get_variable(
            x.mesh,
            "wte",
            mtf.Shape([vocab_dim, embd_dim]),
            initializer=tf.random_normal_initializer(stddev=0.02),
            master_dtype=dtypes.master_dtype,
            slice_dtype=dtypes.slice_dtype,
            activation_dtype=dtypes.activation_dtype)

        # Text embedding
        embedded = mtf.gather(wte, x, vocab_dim)

        dropout_rate = self.params.get("embed_dropout", 0)
        training = self.mode == "train"
        if dropout_rate > 0 and training:
            embedded = mtf.dropout(
                embedded, rate=dropout_rate, name="wte_dropout")
    return embedded
def get_masked_lm_output(self, positions, label_ids, label_weights):
    """Get loss and logits for the masked LM.

    Args:
      positions: indices along `self.seq_dim` selecting the positions to
        predict.
      label_ids: targets passed to the softmax cross entropy.
      label_weights: per-prediction weights used to average the loss.

    Returns:
      A tuple `(loss, per_example_loss, logits)`.
    """
    sequence_output = self.get_sequence_output()
    output_weights = self.get_embedding_table()

    # [batch_size, num_position, hidden]
    hidden = mtf.gather(sequence_output, positions, self.seq_dim)

    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output
        # layer. This matrix is not used after pre-training.
        with tf.variable_scope("transform"):
            activation = get_activation(
                self.config.feedforward_intermediate_act)
            hidden = mtf.layers.dense(
                hidden,
                reduced_dims=[self.model_dim],
                new_dims=[self.model_dim],
                activation=activation,
                kernel_initializer=self.dense_initializer,
                use_bias=self.config.use_bias)
            hidden = self.normalize(hidden)

        # The output weights are the same as the input embeddings, but there
        # is an output-only bias for each token.
        output_bias = mtf.get_variable(
            hidden.mesh,
            name="output_bias",
            shape=[self.vocab_dim],
            initializer=tf.zeros_initializer())
        projected = mtf.einsum(
            [hidden, output_weights], reduced_dims=[self.model_dim])
        logits = projected + output_bias

        per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, label_ids, self.vocab_dim, z_loss=1e-4)

        # The `positions` tensor might be zero-padded (if the sequence is
        # too short to have the maximum number of predictions). The
        # `label_weights` tensor has a value of 1.0 for every real prediction
        # and 0.0 for the padding predictions.
        weighted_sum = mtf.reduce_sum(label_weights * per_example_loss)
        epsilon = mtf.constant(hidden.mesh, 1e-5, dtype=tf.float32)
        weight_total = mtf.reduce_sum(label_weights) + epsilon
        loss = weighted_sum / weight_total

    return (loss, per_example_loss, logits)
def positional_embedding(self, x, name):
    """Add learned positional embeddings to `x` under scope `name`.

    A "wpe" table of shape [embed_seq_dim, embed_dim] is created. Outside
    incremental inference it is indexed with a range over "total_seq_dim";
    during incremental inference it is indexed at `context.position - 1`.
    Dropout is applied to the positional embedding when the "embed_dropout"
    param is positive and `self.mode` is "train".

    Args:
      x: tensor the positional embedding is added to.
      name: variable scope name.

    Returns:
      `x` plus the positional embedding.
    """
    dtypes = self.variable_dtype
    with tf.variable_scope(name):
        # Positional embedding
        wpe_shape = mtf.Shape([
            self.dimensions["embed_seq_dim"],
            self.dimensions["embed_dim"],
        ])
        wpe = mtf.get_variable(
            x.mesh,
            "wpe",
            wpe_shape,
            initializer=tf.random_normal_initializer(stddev=0.01),
            master_dtype=dtypes.master_dtype,
            slice_dtype=dtypes.slice_dtype,
            activation_dtype=dtypes.activation_dtype)

        if self.is_incremental_inference:
            position_indices = self.context.position - 1
        else:
            position_indices = mtf.range(
                x.mesh, self.dimensions["total_seq_dim"], tf.int64)
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])

        dropout_rate = self.params.get("embed_dropout", 0)
        if dropout_rate > 0 and self.mode == "train":
            # The op name "wte_dropout" is kept as in the original.
            pos_emb = mtf.dropout(
                pos_emb, rate=dropout_rate, name="wte_dropout")

        x += pos_emb
        return x
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq,
         memory_length_dim, variable_dtype, context=None):
    """Apply one multi-head attention layer to `x`.

    Fix relative to the previous version: in the "global" branch
    `broadcasted_bias` was only assigned when `bias` existed, so calling the
    layer without a bias raised `UnboundLocalError`. It now defaults to
    `None`, which is forwarded to the attention call as "no bias".

    Args:
      x: input tensor, [batch, seq, n_embd].
      scope: variable scope name for the layer.
      n_state: dimension whose size must be divisible by `params["n_head"]`.
      attention_type: one of "local", "global" or "linear".
      params: dict read for "n_head", "n_embd", "num_mem_kv",
        "local_attention_radius", "causal", "attn_dropout", "mode" and
        "res_dropout".
      bias: optional attention bias (used by the "global" branch only).
      dim_seq: sequence dimension used for incremental-inference one-hots
        and gathers.
      memory_length_dim: dimension k/v are renamed to for global attention.
      variable_dtype: supplies master/slice/activation dtypes.
      context: optional inference context; when it is in incremental mode
        the cached k/v states are updated at `context.position - 1`.

    Returns:
      The attention output after the output projection, output bias and
      (in training) residual dropout.

    Raises:
      NotImplementedError: if `attention_type` is not recognised.
    """
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as
    # dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension(
            "features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            # Overwrite only the slot of the current position in the cached
            # keys/values; every other slot keeps its old state.
            one_hot = mtf.one_hot(
                context.position - 1, dim_seq,
                dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking,
                # so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)

                if is_incremental_inference(context):
                    q *= one_hot

                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )

                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":
                # TODO: pass in fake context
                # Default to "no bias" so the attention call below never
                # sees an unbound name when `bias` is absent.
                broadcasted_bias = None

                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(
                            bias,
                            [dim_batch, dim_heads,
                             bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be
                        # built that masks out all key/values that are
                        # greater than the current position
                        bias = mtf.gather(
                            bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(
                            bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(
                        k, v, num_mem_kv, dim_batch, dim_heads,
                        variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = (
                    params["attn_dropout"] if params["mode"] == "train"
                    else 0)

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = (
                    causal_linear_attention if params["causal"]
                    else linear_attention)
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError(
                    "Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(
                x.mesh, "o_b", [dim_embd],
                initializer=tf.constant_initializer(0),
                master_dtype=variable_dtype.master_dtype,
                slice_dtype=variable_dtype.slice_dtype,
                activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
def get_log_softmax_value(self, log_softmax, index):
    """Returns the entry at index of the log_softmax along the vocab dim.

    Args:
      log_softmax: tensor to read from.
      index: index (or indices) along `self._vocab_dim`.

    Returns:
      The result of gathering `log_softmax` at `index`.
    """
    vocab_dim = self._vocab_dim
    selected = mtf.gather(log_softmax, index, dim=vocab_dim)
    return selected
def _call_internal(self, context, inputs, targets=None):
    """Compute logits based on inputs (all positions in parallel).

    Also updates context if applicable: when `targets` is given and
    `context.losses` is not None, a label-smoothed cross-entropy loss is
    appended to `context.losses`.

    Args:
      context: a Context
      inputs: a Tensor
      targets: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
        (or the layer stack output itself when `self.output_vocab_dim` is
        None).
    """
    mesh = inputs.mesh
    shared = context.shared_params

    # Token embedding table: shared if provided, otherwise created here.
    if "embedding" in shared:
        embedding_weights = shared["embedding"]
    else:
        embedding_weights = mtf.layers.embedding_weights(
            mesh, self.input_vocab_dim, self.model_dim,
            context.variable_dtype, name="embedding")
    x = mtf.gather(embedding_weights, inputs, self.input_vocab_dim)

    # Positional embedding table: same sharing rule.
    if "positional_embedding" in shared:
        pos_emb_var = shared["positional_embedding"]
    else:
        pos_emb_var = mtf.layers.embedding_weights(
            mesh, self.max_length_dim, self.model_dim,
            context.variable_dtype, "positional_embedding")

    if context.position_is_default:
        # Default positions: take the first `length_dim.size` rows and
        # rename the dimension to the context's length dimension.
        max_length_name = self.max_length_dim.name
        sliced = mtf.slice(
            pos_emb_var, 0, context.length_dim.size, max_length_name)
        pos_emb = mtf.rename_dimension(
            sliced, max_length_name, context.length_dim.name)
    else:
        pos_emb = mtf.gather(
            pos_emb_var, context.position, self.max_length_dim,
            output_shape=x.shape)
    x += pos_emb

    x = self.layer_stack.call(context, x)
    if self.output_vocab_dim is None:
        return x

    if self.shared_embedding_and_softmax_weights:
        scaled = x * (self.model_dim.size ** -0.5)
        logits = mtf.einsum(
            [scaled, embedding_weights], reduced_dims=[self.model_dim])
    else:
        logits = mtf.layers.dense(
            x, self.output_vocab_dim, use_bias=False,
            variable_dtype=context.variable_dtype,
            name="logits")

    if targets is None or context.losses is None:
        return logits

    # Label-smoothed targets.
    off_value = self.label_smoothing / self.output_vocab_dim.size
    on_value = 1.0 - self.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.output_vocab_dim,
        dtype=context.activation_dtype,
        on_value=on_value,
        off_value=off_value)

    # z_loss is only applied while training.
    z_loss = self.z_loss if context.train else 0.0
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.output_vocab_dim, z_loss=z_loss)
    weights = mtf.layers.weights_nonzero(
        targets, dtype=context.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    context.losses.append(loss)
    return logits
def __init__(self, config, is_training, input_ids, input_mask=None, token_type_ids=None, scope=None, mesh_shape="", layout=""):
  """Build the model graph: embeddings, encoder blocks and pooler.

  NOTE(review): the nesting of statements below was reconstructed from a
  source whose line structure had been collapsed; the variable-scope
  membership of a few statements is marked for confirmation.

  Args:
    config: model configuration; deep-copied, so the caller's object is not
      modified.
    is_training: when falsy, the three dropout probabilities on the copied
      config are set to 0.0. Also forwarded to `mtf.dropout` and `self.moe`.
    input_ids: token id tensor with exactly two dimensions; the second one
      becomes the sequence dimension.
    input_mask: optional mask; when truthy it contributes an attention bias
      of -10000.0 wherever the mask is 0.
    token_type_ids: optional token type ids; defaults to zeros shaped like
      `input_ids`.
    scope: optional variable scope name (default name "bert").
    mesh_shape: forwarded to `self.moe`.
    layout: forwarded to `self.moe`.

  Raises:
    ValueError: if `config.block_layers` contains an unknown layer type.
  """
  # Work on a private copy so that the dropout overrides below do not leak
  # back to the caller's config.
  self.config = copy.deepcopy(config)
  del config
  if not is_training:
    self.config.layer_output_dropout_prob = 0.0
    self.config.attention_probs_dropout_prob = 0.0
    self.config.feedforward_intermediate_dropout_prob = 0.0

  input_shape = input_ids.shape
  assert input_shape.ndims == 2

  self._seq_dim = input_shape.dims[1]
  # Same size as the sequence dimension, under a distinct name.
  self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size)
  self._extra_losses = []
  mesh = input_ids.mesh

  if token_type_ids is None:
    token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32)

  with tf.variable_scope(scope, default_name="bert"):
    with tf.variable_scope("embeddings"):
      # Perform embedding lookup on the word ids.
      self.embedding_table = mtf.get_variable(
          mesh, "word_embeddings",
          mtf.Shape([self.vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      self.word_embedding_output = mtf.gather(
          self.embedding_table, input_ids, self.vocab_dim)

      # Add positional embeddings and token type embeddings, then layer
      # normalize and perform dropout.
      self.embedding_output = self.word_embedding_output

      token_type_table = mtf.get_variable(
          mesh, "token_type_embeddings",
          mtf.Shape([self.token_type_vocab_dim, self.model_dim]),
          initializer=self.embedding_initializer)
      # Always true at this point: a None value was replaced with zeros
      # above.
      if token_type_ids is not None:
        self.embedding_output += mtf.gather(
            token_type_table, token_type_ids, self.token_type_vocab_dim)
      if self.config.position_signal == "embedding":
        full_position_table = mtf.get_variable(
            mesh, "position_embeddings",
            mtf.Shape([self.max_position_embeddings_dim, self.model_dim]),
            initializer=self.embedding_initializer)
        # Keep only the first seq_dim.size rows and rename the dimension so
        # that the table lines up with the sequence dimension.
        short_position_table = mtf.rename_dimension(
            mtf.slice(full_position_table, 0, self.seq_dim.size,
                      self.max_position_embeddings_dim.name),
            self.max_position_embeddings_dim.name, self.seq_dim.name)
        self.embedding_output += short_position_table
      self.embedding_output = self.normalize(self.embedding_output)
      self.embedding_output = mtf.dropout(
          self.embedding_output, is_training,
          keep_prob=1.0 - self.config.layer_output_dropout_prob)

    with tf.variable_scope("encoder"):
      attention_biases = []
      if input_mask:
        # [batch_dim, memory_seq_dim]
        attention_biases.append(
            (1.0 - mtf.to_float(mtf.replace_dimensions(
                input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0)
      if self.config.position_signal == "relative_attention_bias":
        buckets_dim = mtf.Dimension("buckets", 32)
        # Bucketed relative position (memory position minus query position).
        rp_bucket = _relative_position_bucket(
            mtf.range(mesh, self.memory_seq_dim, tf.int32)
            - mtf.range(mesh, self.seq_dim, tf.int32),
            num_buckets=buckets_dim.size)
        bias_var = mtf.get_variable(
            mesh, "relative_attention_bias",
            [self.num_heads_dim, buckets_dim],
            initializer=tf.zeros_initializer())
        attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim))
      attention_bias = mtf.add_n(attention_biases)
      prev_layer_output = self.embedding_output
      self.all_encoder_layers = []
      for block_num in range(self.config.num_blocks):
        with tf.variable_scope("block_%d" % block_num):
          for layer_idx, layer_type in enumerate(self.config.block_layers):
            # Disambiguate repeated layer types within a block: the second
            # "attention" layer becomes "attention_1", and so on.
            layer_name = layer_type
            count = self.config.block_layers[:layer_idx].count(layer_type)
            if count:
              layer_name += "_%d" % count
            with tf.variable_scope(layer_name):
              x = prev_layer_output
              # "direct": normalize the layer input; "original": normalize
              # after the residual add (see below).
              if self.config.residual_structure == "direct":
                x = self.normalize(x)
              if layer_type == "attention":
                x = self.self_attention(x, attention_bias)
              elif layer_type == "feedforward":
                x = self.feedforward(x)
              elif layer_type == "moe":
                x = self.moe(x, layout, mesh_shape, input_mask, is_training)
              else:
                raise ValueError("unknown layer type " + layer_type)
              x = mtf.dropout(
                  x, is_training,
                  keep_prob=1.0 - self.config.layer_output_dropout_prob)
              layer_output = prev_layer_output + x
              if self.config.residual_structure == "original":
                layer_output = self.normalize(layer_output)
              prev_layer_output = layer_output
        # NOTE(review): presumably one entry per block (not per layer) -
        # confirm against the original indentation.
        self.all_encoder_layers.append(layer_output)

    # NOTE(review): placed outside the "encoder" scope - confirm, since it
    # determines the scope `self.normalize` runs under.
    self.sequence_output = prev_layer_output
    if self.config.residual_structure == "direct":
      self.sequence_output = self.normalize(self.sequence_output)

    # The "pooler" converts the encoded sequence tensor of shape
    # [batch_dim, seq_dim, hidden_size] to a tensor of shape
    # [batch_dim, hidden_size]. This is necessary for segment-level
    # (or segment-pair-level) classification tasks where we need a fixed
    # dimensional representation of the segment.
    with tf.variable_scope("pooler"):
      # We "pool" the model by simply taking the hidden state corresponding
      # to the first token. We assume that this has been pre-trained
      first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim)
      self.pooled_output = mtf.layers.dense(
          first_token_tensor,
          reduced_dims=[self.model_dim],
          new_dims=[self.model_dim],
          activation=mtf.tanh,
          kernel_initializer=self.dense_initializer,
          use_bias=self.config.use_bias)
def mtf_model_fn(self, features, mesh):
    """Build the image-transformer graph and return logits and loss.

    Fix relative to the previous version: the conditional-input path created
    its embedding table with `mtf.layers.embedding(mesh, "input_embedding",
    Shape, activation_dtype=...)`, which does not match that function's
    (indices, vocab_dim, output_dim, ...) signature. The table is now
    created with `mtf.get_variable`, exactly like "targets_embedding".

    Args:
      features: dict with key "targets" and, for conditional models, key
        "inputs".
      mesh: the mesh on which tensors and variables are created.

    Returns:
      A tuple `(logits, loss)` where `logits` is reshaped to
      [batch, rows, orig_cols, channels, outputs_vocab].

    Raises:
      ValueError: if `hparams.attention_type` is not a supported type.
    """
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len * hparams.img_len * hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(x, name):
        # Import a TF tensor as an mtf tensor of shape [batch, length].
        return mtf.import_tf_tensor(
            mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    x = mtf.gather(targets_embedding_var, shifted_targets,
                   self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(self.create_positional_emb_2d(targets),
                     [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the
    # target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(inputs, "inputs")

        # Input embeddings: a plain variable, created the same way as the
        # targets table, so that the gather below has a table to index.
        inputs_embedding_var = mtf.get_variable(
            mesh, "input_embedding",
            mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        inputs_emb = mtf.gather(
            inputs_embedding_var, inputs, self.inputs_vocab_dim)
        x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    if hparams.attention_type == "local1d_spatial":
        decoder_output = local_attention1d_spatial_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local2d_spatial":
        decoder_output = local_attention2d_spatial_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local1d":
        decoder_output = local_attention1d_masked_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    else:
        raise ValueError("Invalid attention type.")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for extra_loss in extra_losses:
        loss += extra_loss

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.rows_dim, self.orig_cols_dim,
                   self.channels_dim, self.outputs_vocab_dim]))

    return logits, loss
def _mtf_model_fn(self, features, mesh):
  """Assembles the transformer graph and returns its logits and loss.

  Args:
    features: a dict of tf.Tensors. "targets" is always read; "inputs" is
      read unless hparams.transformer_type is "decoder". The optional
      "*_segmentation" / "*_position" entries switch on packed-example
      masking.
    mesh: the mtf.Mesh on which tensors and variables are created.

  Returns:
    A (logits, loss) pair of mtf.Tensors; logits are converted to float and,
    when there are several batch dimensions, those are merged into one.
  """
  features = copy.copy(features)
  hparams = self._hparams

  targets = tf.to_int32(features["targets"])
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  def fit_to_max_length(tensor):
    """Right-pads dimension 1 to max_length and pins the static shape."""
    shortfall = hparams.max_length - tf.shape(tensor)[1]
    padded = tf.pad(tensor, [[0, 0], [0, shortfall]])
    return tf.reshape(padded, [hparams.batch_size, hparams.max_length])

  targets = fit_to_max_length(targets)
  packing_keys = ("targets_segmentation", "targets_position",
                  "inputs_segmentation", "inputs_position")
  for packing_key in packing_keys:
    if packing_key in features:
      features[packing_key] = fit_to_max_length(features[packing_key])

  shifted_targets = common_layers.shift_right_2d(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  shifted_targets = self._import_to_batch_by_length(
      shifted_targets, "shifted_targets", mesh, hparams)

  if "targets_segmentation" not in features:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
        targets_position, dtype=self.activation_dtype)
  else:
    # Packed dataset: examples sharing a row must not attend to each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position", mesh, hparams)
    causal_mask = mtf.layers.attention_mask_autoregressive(
        targets_position, dtype=self.activation_dtype)
    segment_mask = mtf.layers.attention_mask_same_segment(
        targets_segmentation, dtype=self.activation_dtype)
    decoder_self_attention_mask = causal_mask + segment_mask

  def layer_prepostprocess_dropout(x):
    keep = 1.0 - hparams.layer_prepostprocess_dropout
    return mtf.dropout(
        x, keep_prob=keep,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  extra_losses = []
  embedding_vars = self._embedding_and_softmax_vars(mesh)
  (inputs_embedding_var, targets_embedding_var,
   softmax_var, positional_embedding_var) = embedding_vars

  if hparams.transformer_type != "decoder":
    # ENCODER
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = fit_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # Packed dataset: examples sharing a row must not attend to each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position", mesh, hparams)
      encoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
          inputs_segmentation, dtype=self.activation_dtype)
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = mtf.layers.attention_mask_ignore_padding(
          inputs, dtype=self.activation_dtype)

    token_emb = mtf.gather(
        inputs_embedding_var, inputs, self.inputs_vocab_dim)
    position_emb = mtf.gather(
        positional_embedding_var, inputs_position, self.max_length_dim)
    x = layer_prepostprocess_dropout(token_emb + position_emb)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(
          x, hparams.encoder_layers,
          self_attention_mask=encoder_self_attention_mask,
          losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
  else:
    encoder_output = None
    encoder_decoder_attention_mask = None

  if hparams.transformer_type != "encoder":
    # DECODER
    token_emb = mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim)
    position_emb = mtf.gather(
        positional_embedding_var, targets_position, self.max_length_dim)
    x = layer_prepostprocess_dropout(token_emb + position_emb)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x, hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)

  logits = mtf.matmul(x, softmax_var)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)

  # Label smoothing: spread label_smoothing mass evenly over the vocabulary.
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim,
      on_value=on_value, off_value=off_value, dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  # Positions whose target id is zero get zero weight.
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for extra in extra_losses:
    loss += extra

  logits = mtf.to_float(logits)
  # combine batch dims
  if len(self.batch_dims) > 1:
    merged_batch_dim = mtf.Dimension(
        self.batch_dims[0].name, mtf.Shape(self.batch_dims).size)
    logits = mtf.reshape(
        logits, [merged_batch_dim] + logits.shape.dims[-2:])
  return logits, loss
def _sample(self, features, mesh):
  """Builds the decoding graph (greedy/sampling or beam search).

  Args:
    features: a dict of tf.Tensors. For "encdec" models "inputs" is read and
      encoded. For "decoder" models "inputs" (or, failing that, "targets")
      is used as a forced prefix for the output when present.
    mesh: the mtf.Mesh on which tensors and variables are created.

  Returns:
    The mtf.Tensor returned by mtf.beam_search.greedy_decode when
    hparams.beam_size == 1, otherwise the first beam (index 0 along the beam
    dimension) of the beam-search result.

  Raises:
    ValueError: if hparams.transformer_type is neither "encdec" nor
      "decoder".
  """
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    # Drop trailing singleton dimensions until inputs is [batch, length].
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Pad both batch and length up to the static sizes the mesh expects.
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Precompute the encoder-side keys/values once per "enc_att" layer so the
    # per-step decoder does not recompute them. Other layers get None so that
    # list positions line up with hparams.decoder_layers.
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                            [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    # The hparam being checked is transformer_type; name it correctly so the
    # message points at the setting the user has to change.
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [self.heads_dim,
                                local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [beam_dim, self.heads_dim,
                                local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  # Two zero-initialised state tensors per attention layer; layers of any
  # other type contribute no state.
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      # Bound the decode length by the longest input: count non-zero input
      # ids per example, take the max, then scale and offset it.
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention or not. When False the logits
      are the plain q.k dot product.
    synthesize_mode: which variant of synthesizer to use. One of "random",
      "factorized", "dense_minus", "random_plus", "random_plus_alpha",
      "dense_plus", "dense_plus_alpha".
    factorized_dim: factorized dim for synthesizers
    max_length: max length of input sequence. Sizes the "length" and
      "memory_length" axes of the synthesized logits in every mode. Note that
      it therefore determines the shapes of the learned variables.
    context: context since we need context mode. Required in practice: its
      mesh, variable_dtype, mode, position and train attributes are read.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim

  Raises:
    ValueError: if synthesize is True and synthesize_mode is not recognized.
  """

  def full_logits_shape(heads_size):
    """Shape [length, heads, memory_length] sized for max_length."""
    return mtf.Shape([mtf.Dimension("length", max_length),
                      mtf.Dimension("heads", heads_size),
                      mtf.Dimension("memory_length", max_length)])

  def crop_random(r):
    """Restricts a learned [length, heads, memory_length] table to this call."""
    r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
    if context.mode == "incremental":
      # One decoding step: pick the row for the current position.
      return mtf.gather(r, context.position,
                        r.shape.get_dim_by_name("length"))
    length_dim = q.shape.get_dim_by_name("length")
    return mtf.slice(r, 0, length_dim.size, "length")

  def crop_dense(r):
    """Restricts logits predicted from q to the memory/query lengths in use."""
    r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name)
    if context.mode == "incremental":
      return r
    length_dim = q.shape.get_dim_by_name("length")
    return mtf.slice(r, 0, length_dim.size, "length")

  def mix_with_dot_product(dot_logits, r):
    """Combines dot-product logits with synthesized ones (sum or alpha-blend)."""
    if "alpha" in synthesize_mode:
      alpha = mtf.get_variable(context.mesh,
                               "alpha",
                               mtf.Shape([mtf.Dimension("alpha", 1)]),
                               initializer=tf.zeros_initializer(),
                               dtype=context.variable_dtype)
      alpha = mtf.sigmoid(alpha)
      return ((1 - alpha) * dot_logits) + (alpha * r)
    return dot_logits + r

  if synthesize:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r = mtf.get_variable(context.mesh, "R",
                           full_logits_shape(num_heads.size),
                           initializer=None,
                           dtype=context.variable_dtype)
      logits = crop_random(r)
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      # Named `rank` rather than `k` so the key tensor argument is not
      # shadowed.
      rank = factorized_dim
      # NOTE(review): both factors carry "memory_length" and neither carries
      # "length", while the product is requested with a "length" axis. Kept
      # as-is because changing the factor shapes changes the variables;
      # confirm this is the intended factorization.
      r1_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", max_length)])
      r2_shape = mtf.Shape([mtf.Dimension("tmp", rank),
                            mtf.Dimension("heads", num_heads.size),
                            mtf.Dimension("memory_length", max_length)])
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], full_logits_shape(num_heads.size))
      logits = crop_random(r)
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = crop_dense(logits)
    elif synthesize_mode in ("random_plus_alpha", "random_plus"):
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      logit_heads = logits.shape.get_dim_by_name("heads")
      r = mtf.get_variable(context.mesh, "R",
                           full_logits_shape(logit_heads.size),
                           initializer=None,
                           dtype=context.variable_dtype)
      r = crop_random(r)
      logits = mix_with_dot_product(logits, r)
    elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      tmp_dim = mtf.Dimension("memory_length", max_length)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = crop_dense(r)
      logits = mix_with_dot_product(logits, r)
    else:
      # Without this branch `logits` would be unbound further down.
      raise ValueError("unknown synthesize_mode \"%s\"" % synthesize_mode)
  else:
    # Synthesizer disabled: fall back to ordinary dot-product logits so that
    # `logits` is always defined. The output-shape logic below already
    # handles this case.
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])

  if bias is not None:
    logits += bias

  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    # Pure synthesizer modes build the output shape from q's leading dims.
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim

  outputs = mtf.einsum([weights, v], outputs_shape)
  return outputs
def compute_bias(self, context, memory_position, x):
  """Builds the additive attention bias for this layer.

  The bias is the sum of every applicable term: relative-position window
  limits, read/write priority visibility, same-sequence visibility and a
  learned relative-attention bias. The result is memoized in `context.cache`
  when the layer has no relative attention or uses the "bias_shared" kind.

  Args:
    context: a transformer.Context
    memory_position: an int32 tensor containing memory_length dimension.
    x: a Tensor - the query antecedent - required for relative attention

  Returns:
    a Tensor or None

  Raises:
    ValueError: if self.relative_attention_type is set to an unknown value.
  """
  lower_bound = self.min_relative_position(context)
  upper_bound = self.max_relative_position(context)

  # The result can often be shared between similar layers.
  cacheable = (self.relative_attention_type is None or
               self.relative_attention_type == "bias_shared")
  if cacheable:
    cache_key = ("self_attention_mask",
                 lower_bound,
                 upper_bound,
                 self.relative_attention_type,
                 self.num_heads)
    if cache_key in context.cache:
      return context.cache[cache_key]

  terms = []

  def add_visibility_term(visible_mask):
    """Appends the bias corresponding to a boolean visibility mask."""
    terms.append(
        attention.visibility_mask_to_attention_bias(
            visible_mask, context.activation_dtype))

  offset = memory_position - context.position
  if lower_bound is not None:
    add_visibility_term(mtf.greater_equal(offset, lower_bound))
  if upper_bound is not None:
    add_visibility_term(mtf.less_equal(offset, upper_bound))

  if context.read_priority is not None:
    write_priority = mtf.layers.rename_length_to_memory_length(
        context.write_priority)
    add_visibility_term(
        mtf.greater_equal(context.read_priority, write_priority))

  # Subsequence id should only be set if we are in the decoder and have
  # multiple targets per input. This will allow each sub-target to only attend
  # to itself.
  if isinstance(context.subsequence_id, mtf.Tensor):
    segment_ids = context.subsequence_id
  elif isinstance(context.sequence_id, mtf.Tensor):
    segment_ids = context.sequence_id
  else:
    segment_ids = None

  if segment_ids is not None and context.length_dim in segment_ids.shape:
    memory_segment_ids = self.rename_length_to_memory_length(
        segment_ids, context)
    add_visibility_term(mtf.equal(segment_ids, memory_segment_ids))

  if self.relative_attention_type is not None:
    buckets_dim = mtf.Dimension(
        "buckets", self.relative_attention_num_buckets)
    heads_dim = mtf.Dimension("heads", self.num_heads)
    is_bidirectional = not context.model.fully_autoregressive
    bucket_ids = _relative_position_bucket(
        offset,
        bidirectional=is_bidirectional,
        num_buckets=buckets_dim.size)
    if self.relative_attention_type in ("bias", "bias_shared"):
      bucket_values = mtf.get_variable(
          context.mesh, "relative_attention_bias",
          [heads_dim, buckets_dim], dtype=context.variable_dtype)
    elif self.relative_attention_type == "contextual":
      bucket_values = layers.dense(
          x, [buckets_dim, heads_dim],
          variable_dtype=context.variable_dtype,
          name="relative_attention_contextual")
    else:
      raise ValueError(
          'unrecognized relative_attention_type "%s"' %
          self.relative_attention_type)
    terms.append(mtf.gather(bucket_values, bucket_ids, buckets_dim))

  if terms:
    total_bias = mtf.add_n(terms)
  else:
    total_bias = None
  if cacheable:
    context.cache[cache_key] = total_bias
  return total_bias
def ids_to_embedding(self, ids):
  """Maps ids to embeddings through the two factor tensors.

  Rows of `self._factor1` are gathered for `ids` along `self._vocab_dim`, and
  the result is contracted with `self._factor2` over `self._inner_dim`.

  Args:
    ids: the index tensor passed to mtf.gather.

  Returns:
    The tensor produced by the einsum of the gathered rows and
    `self._factor2`.
  """
  gathered_rows = mtf.gather(self._factor1, ids, self._vocab_dim)
  operands = [gathered_rows, self._factor2]
  return mtf.einsum(operands, reduced_dims=[self._inner_dim])
def _mtf_model_fn(self, features, mesh):
  """Assembles the transformer graph and returns its logits and loss.

  Supports autoregressive and denoising decoders (hparams.decoder_type) and
  packed datasets (the optional "*_segmentation" / "*_position" features).

  Args:
    features: a dict of tf.Tensors. "targets" is always read; "inputs" is
      read unless hparams.transformer_type is "decoder". The unmodified dict
      is also stored on `self._original_features`.
    mesh: the mtf.Mesh on which tensors and variables are created.

  Returns:
    A (logits, loss) pair of mtf.Tensors; logits are converted to float.

  Raises:
    ValueError: if hparams.decoder_type is neither "autoregressive" nor
      "denoising".
  """
  self._original_features = features
  features = copy.copy(features)
  hparams = self._hparams
  extra_losses = []
  targets = tf.to_int32(features["targets"])
  # hparams may lack `mode`; treat that as training. Every later mode check
  # goes through `is_training` so that this default is applied consistently
  # instead of reading hparams.mode directly.
  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  # pad targets to max_length
  def pad_to_max_length(x):
    """Right-pads dimension 1 to max_length and pins the static shape."""
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x

  targets = pad_to_max_length(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])

  if hparams.decoder_type == "autoregressive":
    shifted_targets = mtf.shift(
        targets, offset=1, dim=self.length_dim, wrap=False)
  elif hparams.decoder_type == "denoising":
    shifted_targets = self._noisy_targets(targets, extra_losses)
  else:
    raise ValueError(
        "unknown hparams.decoder_type = %s" % hparams.decoder_type)

  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position",
        mesh, hparams)
    decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
        targets_segmentation, dtype=self.activation_dtype)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
    else:
      decoder_self_attention_mask = None

  def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x, is_training,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "decoder":
    encoder_output = None
    encoder_decoder_attention_mask = None
  else:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position",
          mesh, hparams)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))

    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)

  if hparams.transformer_type == "encdec":
    if "inputs_segmentation" in features:
      encoder_decoder_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      encoder_decoder_attention_mask = encoder_self_attention_mask
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)

  if hparams.transformer_type != "encoder":
    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)

  use_reshape_hack = hparams.reshape_logits_hack and is_training
  if use_reshape_hack:
    # For some reason, the logits computation is extremely slow on TPU
    # in some cases where the batch size per core is 1.  Reshape the logits
    # and the targets to double the batch size and halve the length.
    # TODO(noam): file a bug.
    old_dims = self.batch_dims + [self.length_dim]
    new_dims = self.batch_dims[:-1] + [
        mtf.Dimension(self.batch_dims[-1].name,
                      self.batch_dims[-1].size * 2),
        mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
    x = mtf.reshape(x, new_dims + [self.model_dim])
    targets = mtf.reshape(targets, new_dims)

  logits = mtf.matmul(x, softmax_var)
  if is_training:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
  # Label smoothing: spread label_smoothing mass evenly over the vocabulary.
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim,
      on_value=on_value, off_value=off_value, dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  # Positions whose target id is zero get zero weight.
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  if use_reshape_hack:
    # Undo the batch/length reshape so callers see the original layout.
    logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
  logits = mtf.to_float(logits)
  return logits, loss
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
  """Builds a GPT-style language model graph in mesh tensorflow.

  Args:
    mtf_features: features handed to parse_inputs; "labels" is read from it
      when params["mode"] is "train" or "eval".
    other_features: dict handed to parse_inputs; "attn_bias" and
      "memory_length_dim" are forwarded to every block.
    params: dict of model hyperparameters and settings.
    mesh: the mtf.Mesh on which variables are created.
    variable_dtype: supplies master_dtype, slice_dtype and activation_dtype.
    context: optional context; consulted through is_incremental_inference and
      for its `position` during incremental decoding.

  Returns:
    A (logits, loss, loss_batch) triple. loss and loss_batch are None unless
    params["mode"] is "train" or "eval".
  """
  parsed = parse_inputs(mtf_features, other_features)
  x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parsed

  if is_incremental_inference(context):
    # In incremental inference only the token at the previous position is fed.
    x = mtf.gather(x, context.position - 1, sequence_dim)
    x = mtf.reshape(x, [batch_dim])

  use_axial_pos_emb = params["axial_pos_emb"] is not None
  if use_axial_pos_emb:
    wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)
  else:
    # Standard learned position table.
    wpe = mtf.get_variable(
        mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
        initializer=tf.random_normal_initializer(stddev=0.01),
        master_dtype=variable_dtype.master_dtype,
        slice_dtype=variable_dtype.slice_dtype,
        activation_dtype=variable_dtype.activation_dtype)

  # Token embedding table (also reused as the output projection when weights
  # are tied).
  wte = mtf.get_variable(
      mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
      initializer=tf.random_normal_initializer(stddev=0.02),
      master_dtype=variable_dtype.master_dtype,
      slice_dtype=variable_dtype.slice_dtype,
      activation_dtype=variable_dtype.activation_dtype)

  with tf.variable_scope("token_embd"):
    # Token embedding lookup
    h = mtf.gather(wte, x, vocab_dim)
    if params["embed_dropout"] > 0 and params["mode"] == "train":
      h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

  with tf.variable_scope("pos_embd"):
    # Positional embedding lookup
    if not is_incremental_inference(context):
      position_indices = mtf.range(mesh, sequence_dim, tf.int64)
    else:
      position_indices = context.position - 1
    pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
    if params["embed_dropout"] > 0 and params["mode"] == "train":
      pos_emb = mtf.dropout(
          pos_emb, rate=params["embed_dropout"], name="wte_dropout")
    h += pos_emb

  # Accumulator for auxiliary losses returned by the blocks (for MOE models).
  aux_losses = 0

  for layer_idx in range(params["n_layer"]):
    # attn blocks
    share_parameters = (exists(params["share_parameters"])
                        and params["share_parameters"] == True)
    # With shared parameters every block uses the same (empty) scope name.
    if share_parameters:
      block_scope = ""
    else:
      block_scope = f"h{layer_idx}"

    block_fn = block(
        params=params, scope=block_scope, layer_num=layer_idx,
        bias=other_features["attn_bias"],
        sequence_dim=sequence_dim,
        memory_length_dim=other_features["memory_length_dim"],
        variable_dtype=variable_dtype,
        context=context)

    # Gradient checkpointing is only enabled when requested and training.
    recompute_grad = (params["recompute_grad"]
                      and (params["mode"] == "train") == True)
    if recompute_grad:
      h, block_loss = mtf.recompute_grad(block_fn, [h])
    else:
      h, block_loss = block_fn(h)
    aux_losses += block_loss

  no_weight_tie_emb = params["no_weight_tie"] == True
  if not no_weight_tie_emb:
    # Layer normalize, then project with the tied embedding table.
    h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
    if not is_incremental_inference(context):
      seq_dim = sequence_dim
    else:
      seq_dim = mtf.Dimension("sequence", 1)
    with tf.variable_scope("wte_final_einsum"):
      # Equivalent to tf.matmul
      logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])
  else:
    with tf.variable_scope("wte_final_linear"):
      logits = linear(h, "linear_out", vocab_dim,
                      variable_dtype=variable_dtype, params=params)

  if params["mode"] in ("train", "eval"):
    labels = mtf_features["labels"]
    # An auxiliary loss used to stabilize mtf xentropy.
    z_loss = params.get("z_loss", 1e-4)

    # Go to full precision for the logits
    logits = mtf.cast(logits, tf.float32)

    use_entmax_loss = params.get("entmax_loss", False)
    if use_entmax_loss:
      loss_fn = entmax_cross_entropy_with_logits
    else:
      loss_fn = mtf.layers.softmax_cross_entropy_with_logits

    with tf.variable_scope("xentropy_final"):
      loss_batch = loss_fn(logits=logits, targets=labels,
                           vocab_dim=logits.shape[-1], z_loss=z_loss)

    # For non-autoregressive models (masked language modeling training)
    # make sure labels with padding tokens are not counted in the loss.
    if not params["causal"]:
      padding_id = params.get("padding_id", 0)
      loss_batch = mtf.where(mtf.not_equal(labels, padding_id),
                             loss_batch, mtf.zeros_like(loss_batch))

    with tf.variable_scope("reduce_mean_final"):
      loss = mtf.reduce_mean(loss_batch)

    # Add on auxiliary losses (currently only used for MoE)
    loss += aux_losses
    loss /= params["num_microbatches"]
    # Convert to train dtype
    loss = mtf.cast(loss, variable_dtype.slice_dtype)
  else:
    loss = None
    loss_batch = None

  # Cast back to checkpoint dtype
  logits = mtf.cast(logits, variable_dtype.master_dtype)
  return logits, loss, loss_batch
def beam_search(self,
                inputs,
                decode_length,
                variable_dtype=mtf.VariableDType(tf.float32),
                encoder_output=None,
                encoder_sequence_id=None,
                alpha=0.6,
                shared_params=None,
                encoder_layer_outputs=None):
    """Run beam-search decoding and return the first beam.

    Args:
        inputs: an int32 Tensor with shape [<batch_dims>, beam_dim, length_dim].
            Nonzero entries are counted per sequence to obtain the initial
            decoding position.
        decode_length: an int32 mtf scalar, the maximum decode length
            (forwarded unchanged to mtf.beam_search.beam_search).
        variable_dtype: a mtf.VariableDType.
        encoder_output: an optional Tensor.
        encoder_sequence_id: an optional Tensor.
        alpha: a floating point value (length bonus).
        shared_params: an optional dictionary.
        encoder_layer_outputs: optional read-only list of tensor activations
            used when decoding, one per input layer plus the embedding layer.

    Returns:
        A Tensor with shape [<batch_dims>, length_dim]: beam index 0 gathered
        from the beam-search output.

    Raises:
        ValueError: if the model is not autoregressive.
        NotImplementedError: if `inputs` does not have exactly one batch
            dimension in front of beam_dim and length_dim.
    """
    if not self.autoregressive:
        raise ValueError("must be autoregressive")

    dims = inputs.shape.dims
    batch_dims = dims[:-2]
    if len(batch_dims) != 1:
        raise NotImplementedError(
            "beam search supports exactly one batch dimension.")
    beam_dim, length_dim = dims[-2], dims[-1]

    # Number of nonzero ids in each sequence.
    nonzero = mtf.to_int32(mtf.not_equal(inputs, 0))
    initial_position = mtf.reduce_sum(nonzero, reduced_dim=length_dim)
    sequence_id = None if encoder_sequence_id is None else 1

    def make_context(mode, **mode_specific):
        """Build a Context; only `mode` and the state arguments vary."""
        return Context(
            mesh=inputs.mesh,
            # Fresh lists on every call so contexts never share mutable state.
            batch_dims=batch_dims + [beam_dim],
            length_dim=length_dim,
            model_dim=self.model_dim,
            variable_dtype=variable_dtype,
            mode=mode,
            autoregressive=self.autoregressive,
            new_states=[],
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            shared_params=shared_params,
            layout=self.layout,
            mesh_shape=self.mesh_shape,
            encoder_layer_outputs=encoder_layer_outputs,
            **mode_specific)

    first_part = make_context(
        "first_part",
        initial_position=initial_position,
        constant_states=[])

    shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False)
    with tf.variable_scope(self.name):
        # The logits of this pass are not needed (there are no partial
        # targets); only first_part.new_states / constant_states are read
        # afterwards.
        self._call_internal(first_part, shifted_inputs)

    # Replace initial states by zeros to avoid computing them.
    initial_states = [mtf.zeros_like(state) for state in first_part.new_states]
    constant_states = first_part.constant_states

    def logits_fn(step_num, ids, states):
        """Per-step callback handed to mtf.beam_search.beam_search()."""
        incremental = make_context(
            "incremental",
            position=step_num,
            states=states,
            constant_states=constant_states)
        # Feed the id written at the previous position.
        previous_ids = mtf.gather(ids, step_num - 1, length_dim)
        with tf.variable_scope(self.name, reuse=True):
            step_logits = self._call_internal(incremental, previous_ids)
        return mtf.to_float(step_logits), incremental.new_states

    beams, _ = mtf.beam_search.beam_search(
        logits_fn,
        inputs,
        alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=True,
        dtype=tf.float32,
        mesh_shape=self.mesh_shape,
        layout=self.layout)
    first_beam = mtf.constant(inputs.mesh, 0, dtype=tf.int32)
    return mtf.gather(beams, first_beam, beam_dim)
def _sample(self, features, mesh):
    """Build the decoding graph (greedy/sampled decoding or beam search).

    Args:
        features: dict of tf.Tensors. `features["inputs"]` is required when
            `hparams.transformer_type == "encdec"`. For `"decoder"`, an
            optional `"inputs"` (or, failing that, `"targets"`) entry supplies
            a prefix the outputs are forced to begin with.
        mesh: the mtf mesh the graph is built on.

    Returns:
        When `hparams.beam_size == 1`, the result of
        `mtf.beam_search.greedy_decode`. Otherwise beam index 0 gathered from
        the output of `mtf.beam_search.beam_search`.

    Raises:
        ValueError: if `hparams.transformer_type` is neither `"encdec"` nor
            `"decoder"`.
    """
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        # Squeeze axis 2 repeatedly until `inputs` is rank 2.
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        # Zero-pad up to [hparams.batch_size, hparams.max_length].
        inputs = tf.pad(
            inputs,
            [[0, hparams.batch_size - actual_batch_size],
             [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.encoder_layers,
                                  self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)

        # One entry per decoder layer: (q_var, o_var, k, v) for "enc_att"
        # layers, with k and v computed once from the encoder output, and
        # None for every other layer type.
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope(
                        "decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = (
                        mtf.layers.multihead_attention_vars(
                            mesh, self.heads_dim, self.model_dim,
                            self.kv_dim, self.master_dtype, self.slice_dtype,
                            self.activation_dtype))
                memory_kv_shape = mtf.Shape(
                    self.batch_dims +
                    [self.heads_dim, self.memory_length_dim, self.kv_dim])
                k = mtf.einsum([encoder_output, k_var], memory_kv_shape)
                v = mtf.einsum([encoder_output, v_var], memory_kv_shape)
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        # The message names the hparam that is actually checked; it
        # previously reported a "model_type" that is not read here.
        raise ValueError(
            "hparams.transformer_type = %s not yet supported"
            % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    # With beam search every id/state tensor gets an extra "beam" dimension
    # right after the batch dimensions; greedy decoding has none.
    if hparams.beam_size == 1:
        beam_dims = []
    else:
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        beam_dims = [beam_dim]
    state_prefix = self.batch_dims + beam_dims
    ids_shape = mtf.Shape(state_prefix + [self.length_dim])
    kv_shape = mtf.Shape(
        state_prefix +
        [self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        state_prefix +
        [self.heads_dim, local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    # Two zero state tensors per self-attention layer; other layer types
    # contribute no state.
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend(
                [mtf.zeros(
                    mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num,
                        self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x,
                hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(
            logits_fn,
            initial_ids,
            temperature=temperature,
            initial_states=initial_states,
            forced_ids=partial_targets,
            use_tpu=hparams.use_tpu)

    if hparams.transformer_type == "encdec":
        # Decode budget: (largest count of nonzero input ids in the batch)
        # * multiplier + constant.
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
    else:
        decode_length = None
    # NOTE(review): partial_targets is not forwarded on the beam-search path,
    # only to greedy_decode above. Confirm this is intended.
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def mtf_model_fn(self, features, mesh):
    """Build the image-transformer graph and return logits and loss.

    Targets are flattened to a 1D sequence per example, shifted right,
    embedded, combined with learned 2D positional embeddings, run through the
    decoder selected by `hparams.attention_type`, and projected to the output
    vocabulary.

    Args:
        features: dict of tf.Tensors. Reads `features["targets"]`, and
            `features["inputs"]` when the model has an input and
            `hparams.unconditional` is false. The dict is shallow-copied and
            the caller's dict is not modified.
        mesh: the mtf mesh the graph is built on.

    Returns:
        A `(logits, loss)` pair. `logits` is reshaped to
        [batch, rows, orig_cols, channels, outputs_vocab]. `loss` is the mean
        softmax cross-entropy against the one-hot targets, plus any entries
        of `extra_losses`.

    Raises:
        ValueError: if `hparams.attention_type` is not one of
            "local1d_spatial", "local2d_spatial" or "local1d".
    """
    features = copy.copy(features)
    tf.logging.info("features = %s", features)
    hparams = self._hparams
    activation_dtype = self.activation_type

    # We assume fixed vocab size for targets
    targets = tf.to_int32(features["targets"])

    # Image preprocessing, reshape into a 1D sequence and shift right.
    length = hparams.img_len * hparams.img_len * hparams.num_channels
    targets = tf.reshape(targets, [hparams.batch_size, length])
    shifted_targets = common_layers.shift_right_2d(targets)

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)

    def import_to_batch_by_length(x, name):
        """Import a tf.Tensor as an mtf tensor of shape [batch, length]."""
        return mtf.import_tf_tensor(
            mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

    targets = import_to_batch_by_length(targets, "targets")
    shifted_targets = import_to_batch_by_length(
        shifted_targets, "shifted_targets")

    # Nothing in this function appends to this list at present; it is kept
    # as the hook through which auxiliary losses are added to the total.
    extra_losses = []

    # Create targets content and position embeddings.
    # Create embedding var for targets and positions and do a gather.
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

    x = mtf.gather(targets_embedding_var,
                   shifted_targets, self.targets_vocab_dim)

    # Add positional embeddings
    x += mtf.reshape(self.create_positional_emb_2d(targets),
                     [self.length_dim, self.model_dim])

    # If conditional and input is given, add the input embedding to the target.
    # TODO(nikip): Verify conditional.
    if self.has_input and not hparams.unconditional:
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = import_to_batch_by_length(inputs, "inputs")

        # Input embeddings. The table is looked up with mtf.gather below, so
        # it must be a plain [inputs_vocab, model] variable. It is created
        # with mtf.get_variable, exactly like the targets embedding:
        # mtf.layers.embedding expects (indices, vocab_dim, output_dim,
        # variable_dtype, name) and returns embedded activations, not a
        # table, so it cannot take these arguments. The initializer mirrors
        # the targets embedding.
        inputs_embedding_var = mtf.get_variable(
            mesh, "input_embedding",
            mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        inputs_emb = mtf.gather(
            inputs_embedding_var, inputs, self.inputs_vocab_dim)
        x += inputs_emb

    # Image Transformer Decoder
    # [ self attention - ffn - residual + dropout] x n
    if hparams.attention_type == "local1d_spatial":
        decoder_output = local_attention1d_spatial_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local2d_spatial":
        decoder_output = local_attention2d_spatial_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    elif hparams.attention_type == "local1d":
        decoder_output = local_attention1d_masked_decoder(
            x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
    else:
        raise ValueError("Invalid attention type.")

    # Calculate the logits and loss.
    logits = mtf.layers.dense(
        decoder_output, self.outputs_vocab_dim, name="logits")
    # Need a reshape for logits
    logits = mtf.reshape(
        logits,
        mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
    soft_targets = mtf.one_hot(
        targets, self.outputs_vocab_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.outputs_vocab_dim)
    loss = mtf.reduce_mean(loss)
    for extra_loss in extra_losses:
        loss += extra_loss

    # Reshape logits to original target shape.
    logits = mtf.reshape(
        logits,
        mtf.Shape([
            batch_dim, self.rows_dim, self.orig_cols_dim,
            self.channels_dim, self.outputs_vocab_dim
        ]))

    return logits, loss
def _mtf_model_fn(self, features, mesh):
    """Build the transformer training/eval graph and return logits and loss.

    Depending on `hparams.transformer_type` this builds an encoder
    (anything other than "decoder"), a decoder (anything other than
    "encoder"), or both, with encoder-decoder attention wiring for "encdec".

    Args:
        features: dict of tf.Tensors. Reads "targets"; reads "inputs" unless
            `hparams.transformer_type == "decoder"`. The optional keys
            "targets_segmentation", "targets_position", "inputs_segmentation"
            and "inputs_position" switch on the packed-dataset masks. The
            original dict is stored on `self._original_features`; a shallow
            copy is modified locally.
        mesh: the mtf mesh the graph is built on.

    Returns:
        A `(logits, loss)` pair. `logits` has been passed through
        `mtf.to_float`. `loss` is the mean of the label-smoothed softmax
        cross-entropy multiplied by nonzero-target weights, plus every entry
        accumulated in `extra_losses`.

    Raises:
        ValueError: if `hparams.decoder_type` is neither "autoregressive"
            nor "denoising".
    """
    self._original_features = features
    features = copy.copy(features)
    hparams = self._hparams
    # Auxiliary losses collected from _noisy_targets and the layer stacks;
    # added to the main loss at the end.
    extra_losses = []
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
        tf.logging.info("targets = %s" % targets)
        # assumes axes 2 and 3 are size-1 dims - TODO confirm against caller
        targets = tf.squeeze(targets, [2, 3])

    # pad targets to max_length
    def pad_to_max_length(x):
        # Zero-pad axis 1 up to hparams.max_length, then fix the static shape
        # to [batch_size, max_length].
        extra_length = hparams.max_length - tf.shape(x)[1]
        x = tf.pad(x, [[0, 0], [0, extra_length]])
        x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
        return x

    targets = pad_to_max_length(targets)
    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    # Pad the optional packed-dataset features the same way; this only
    # touches the local shallow copy of `features`.
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
        if key in features:
            features[key] = pad_to_max_length(features[key])
    # Decoder input: targets shifted right by one for autoregressive
    # decoding, or whatever _noisy_targets produces for denoising.
    if hparams.decoder_type == "autoregressive":
        shifted_targets = mtf.shift(
            targets, offset=1, dim=self.length_dim, wrap=False)
    elif hparams.decoder_type == "denoising":
        shifted_targets = self._noisy_targets(targets, extra_losses)
    else:
        raise ValueError(
            "unknown hparams.decoder_type = %s" % hparams.decoder_type)

    if "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = self._import_to_batch_by_length(
            features["targets_segmentation"], "targets_segmentation",
            mesh, hparams)
        targets_position = self._import_to_batch_by_length(
            features["targets_position"], "targets_position",
            mesh, hparams)
        decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype)
        if hparams.decoder_type == "autoregressive":
            # Combine the same-segment mask with the autoregressive mask.
            decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)
    else:
        # Unpacked data: positions are simply 0..length-1.
        targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        if hparams.decoder_type == "autoregressive":
            decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
                targets_position, dtype=self.activation_dtype)
        else:
            decoder_self_attention_mask = None

    def layer_prepostprocess_dropout(x):
        # Dropout whose noise_shape is batch_dims + [model_dim].
        return mtf.dropout(
            x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "decoder":
        # Decoder-only model: no encoder, so nothing to attend to.
        encoder_output = None
        encoder_decoder_attention_mask = None
    else:
        # ENCODER (transformer_type is "encdec" or "encoder").
        # assumes axes 2 and 3 of features["inputs"] are size 1 - TODO confirm
        inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
        inputs = pad_to_max_length(inputs)
        inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
        if "inputs_segmentation" in features:
            # "Packed" dataset - keep the examples from seeing each other.
            inputs_segmentation = self._import_to_batch_by_length(
                features["inputs_segmentation"], "inputs_segmentation",
                mesh, hparams)
            inputs_position = self._import_to_batch_by_length(
                features["inputs_position"], "inputs_position",
                mesh, hparams)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    inputs_segmentation, dtype=self.activation_dtype))
        else:
            inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
            encoder_self_attention_mask = (
                mtf.layers.attention_mask_ignore_padding(
                    inputs, dtype=self.activation_dtype))

        # Token embedding plus positional embedding.
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.gather(positional_embedding_var, inputs_position,
                        self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("encoder"):
            x = self._layer_stack(x,
                                  hparams.encoder_layers,
                                  self_attention_mask=encoder_self_attention_mask,
                                  losses=extra_losses)

    if hparams.transformer_type == "encdec":
        if "inputs_segmentation" in features:
            # NOTE(review): targets_segmentation is only bound when
            # "targets_segmentation" is in features; this branch presumably
            # relies on packed inputs implying packed targets - confirm.
            encoder_decoder_attention_mask = (
                mtf.layers.attention_mask_same_segment(
                    targets_segmentation, inputs_segmentation,
                    dtype=self.activation_dtype))
        else:
            encoder_decoder_attention_mask = encoder_self_attention_mask
        # Rename the length dimension of the encoder output to the memory
        # length dimension.
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)

    if hparams.transformer_type != "encoder":
        # DECODER
        x = (mtf.gather(
            targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
             mtf.gather(
                 positional_embedding_var, targets_position, self.max_length_dim))
        x = layer_prepostprocess_dropout(x)
        with tf.variable_scope("decoder"):
            x = self._layer_stack(
                x,
                hparams.decoder_layers,
                encoder_output=encoder_output,
                self_attention_mask=decoder_self_attention_mask,
                encdec_attention_mask=encoder_decoder_attention_mask,
                losses=extra_losses)
    if (hparams.reshape_logits_hack and
            hparams.mode == tf.estimator.ModeKeys.TRAIN):
        # For some reason, the logits computation is extremely slow on TPU
        # in some cases where the batch size per core is 1.  Reshape the logits
        # and the targets to double the batch size and halve the length.
        # TODO(noam): file a bug.
        # old_dims is kept so the logits can be reshaped back after the loss.
        old_dims = self.batch_dims + [self.length_dim]
        new_dims = self.batch_dims[:-1] + [
            mtf.Dimension(self.batch_dims[-1].name,
                          self.batch_dims[-1].size * 2),
            mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
        x = mtf.reshape(x, new_dims + [self.model_dim])
        targets = mtf.reshape(targets, new_dims)

    logits = mtf.matmul(x, softmax_var)
    if hparams.mode == tf.estimator.ModeKeys.TRAIN:
        # Jitter is applied to the logits in training mode only.
        logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
    # Label smoothing: label_smoothing / vocab_size goes to every class
    # (off_value); the true class gets 1 - label_smoothing on top of that.
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim,
        on_value=on_value, off_value=off_value, dtype=self.activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    # Weight is nonzero only where the target id is nonzero; the mean is
    # still taken over all positions.
    weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
        loss += l
    if (hparams.reshape_logits_hack and
            hparams.mode == tf.estimator.ModeKeys.TRAIN):
        # Undo the batch/length reshape applied before the logits matmul.
        logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
    logits = mtf.to_float(logits)
    return logits, loss