def create_positional_emb_2d(self, targets):
    """Learned 2d positional embedding for images."""
    mesh = targets.mesh

    positional_emb_rows_var = mtf.get_variable(
        mesh, "positional_emb_rows",
        mtf.Shape([self.pos_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)
    positional_emb_cols_var = mtf.get_variable(
        mesh, "positional_emb_cols",
        mtf.Shape([self.pos_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)

    targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
    targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
    position_x = mtf.broadcast(
        mtf.gather(positional_emb_rows_var, targets_position_x, self.pos_dim),
        mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
    position_y = mtf.broadcast(
        mtf.gather(positional_emb_cols_var, targets_position_y, self.pos_dim),
        mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
    return position_x + position_y
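# Illustrative sketch (not part of the original code): the learned 2d positional
# embedding above is a per-row table plus a per-column table, broadcast over the
# image grid and summed. A plain-NumPy version, simplified so the row/column tables
# are exactly the grid size; the helper name and sizes are hypothetical.
import numpy as np

def positional_emb_2d_numpy(rows, cols, model_dim, rng=np.random):
    emb_rows = rng.normal(size=(rows, model_dim))        # one vector per row
    emb_cols = rng.normal(size=(cols, model_dim))        # one vector per column
    return emb_rows[:, None, :] + emb_cols[None, :, :]   # [rows, cols, model_dim]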
def norm(x, axis=None, epsilon=1e-5):
    axis = default(axis, x.shape[-1])

    u = mtf.reduce_mean(x, reduced_dim=axis)
    s = mtf.reduce_mean(mtf.square(x - u), reduced_dim=axis)

    u = mtf.broadcast(u, x.shape)
    s = mtf.broadcast(s, x.shape)

    return (x - u) * mtf.rsqrt(s + epsilon)
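# Illustrative sketch (not part of the original code): what `norm` computes, written
# with plain NumPy -- zero mean and unit variance along the last axis, with no learned
# scale or bias. The helper name is hypothetical.
import numpy as np

def norm_numpy(x, epsilon=1e-5):
    u = x.mean(axis=-1, keepdims=True)
    s = ((x - u) ** 2).mean(axis=-1, keepdims=True)
    return (x - u) / np.sqrt(s + epsilon)

# Each row of norm_numpy(np.random.randn(2, 8)) has ~zero mean and ~unit variance.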
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
    # Use axial position encoding
    axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

    axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
    dim_axials = [mtf.Dimension(f"axial_dim_{i}", t)
                  for i, t in enumerate((axial_dim_1, axial_dim_2))]

    axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)
    axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]),
                                   initializer=tf.random_normal_initializer(stddev=0.01),
                                   master_dtype=variable_dtype.master_dtype,
                                   slice_dtype=variable_dtype.slice_dtype,
                                   activation_dtype=variable_dtype.activation_dtype)

    axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
                                   (axial_wpe_1, axial_wpe_2))
    wpe = (axial_wpe_1 + axial_wpe_2) / 2
    wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
    return wpe
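# Illustrative sketch (not part of the original code): the axial factorization above
# replaces one [axial_dim_1 * axial_dim_2, embd_dim] position table with two smaller
# tables whose broadcast sum covers every position. Plain-NumPy version; the helper
# name and example sizes are hypothetical.
import numpy as np

def axial_positional_emb_numpy(axial_dim_1, axial_dim_2, embd_dim, rng=np.random):
    wpe_1 = rng.normal(scale=0.01, size=(axial_dim_1, embd_dim))   # [d1, e]
    wpe_2 = rng.normal(scale=0.01, size=(axial_dim_2, embd_dim))   # [d2, e]
    wpe = (wpe_1[:, None, :] + wpe_2[None, :, :]) / 2              # [d1, d2, e]
    return wpe.reshape(axial_dim_1 * axial_dim_2, embd_dim)        # [d1 * d2, e]

# e.g. axial_positional_emb_numpy(32, 64, 768) stores (32 + 64) * 768 parameters
# instead of 2048 * 768 for a full position table.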
def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh):
    """Memory key/values, from the all-attention paper."""

    dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv)
    emb_dim = k.shape[-1]
    mem_std = 1 / math.sqrt(emb_dim.size)

    mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)
    mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]),
                             initializer=tf.random_normal_initializer(stddev=mem_std),
                             master_dtype=variable_dtype.master_dtype,
                             slice_dtype=variable_dtype.slice_dtype,
                             activation_dtype=variable_dtype.activation_dtype)

    mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]),
                       (mem_k, mem_v))
    mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"),
                       (mem_k, mem_v))

    k = mtf.concat([mem_k, k], "sequence")
    v = mtf.concat([mem_v, v], "sequence")
    return k, v
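# Illustrative sketch (not part of the original code): the memory key/values above
# simply prepend `num_mem_kv` learned slots to the key and value sequences, so every
# query can also attend to them. Plain-NumPy shapes; the helper name is hypothetical.
import numpy as np

def memory_key_values_numpy(k, v, num_mem_kv, rng=np.random):
    # k, v :: [batch, seq, heads, features_per_head]
    batch, _, heads, feat = k.shape
    std = 1 / np.sqrt(feat)
    mem_k = np.broadcast_to(rng.normal(scale=std, size=(num_mem_kv, heads, feat)),
                            (batch, num_mem_kv, heads, feat))
    mem_v = np.broadcast_to(rng.normal(scale=std, size=(num_mem_kv, heads, feat)),
                            (batch, num_mem_kv, heads, feat))
    return (np.concatenate([mem_k, k], axis=1),
            np.concatenate([mem_v, v], axis=1))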
def add_step_timing_signal_func(self, context, x, step):
    """Add n-dimensional embedding as the step (vertical) timing signal.

    Args:
      context: mtf context
      x: a tensor with shape [batch, length, depth]
      step: step

    Returns:
      a Tensor with the same shape as x.
    """
    if self.recurrence_type == "act":
        num_steps = self.act_max_steps
    else:
        num_steps = self.num_rec_steps
    channels = x.shape.dims[-1]

    if self.step_timing_signal_type == "learned":
        signal = self.get_layer_timing_signal_learned_1d(
            context, channels, step, num_steps)
    elif self.step_timing_signal_type == "sinusoid":
        signal = self.get_layer_timing_signal_sinusoid_1d(
            context, channels, step, num_steps)

    if self.add_or_concat_timing_signal == "add":
        x_with_timing = x + mtf.cast(signal, x.dtype)
    elif self.add_or_concat_timing_signal == "concat":
        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + x.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        x_with_timing = mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])

    return x_with_timing
def get_attn_mask(self, mesh, nd, ns):
    if not exists(self.attn_mask):
        i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
        j = mtf.range(mesh, ns, tf.int32)
        i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
        self.attn_mask = mtf.cast(mtf.less(i, j),
                                  self.variable_dtype.activation_dtype) * -1e10
    return self.attn_mask
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    # The old mask_attn_weights applied directly to the QK; this returns a bias
    # that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information
    # flows from src to dst.
    # n_src and n_dest are both the same, i.e. equal to the sequence length.
    # We rename ns because we want the bias to have shape
    # [batch, heads, memory_length, sequence] to match up with QK^T.
    # Information flows from k and v (memory_length) to q (sequence).
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(i, j), dtype) * -1e10
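# Illustrative sketch (not part of the original code): the same causal bias in plain
# NumPy, to show which entries get masked. A query at position i may attend to keys
# j <= i; future keys (j > i) receive a large negative additive bias.
import numpy as np

def biasmask_numpy(nd, ns, neg=-1e10):
    i = np.arange(nd)[:, None] + (ns - nd)   # query positions, offset onto the key axis
    j = np.arange(ns)[None, :]               # key / memory positions
    return (i < j).astype(np.float32) * neg  # [nd, ns] additive bias

# biasmask_numpy(4, 4) is -1e10 strictly above the diagonal and 0 elsewhere.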
def add_position_timing_signal_func(self, context, x, step):
    """Add n-dimensional embedding as the position (horizontal) timing signal.

    Args:
      context: mtf context
      x: a tensor with shape [batch, length, depth]
      step: step

    Returns:
      a Tensor with the same shape as x.
    """
    if not self.position_start_index:
        index = 0
    elif self.position_start_index == "random":
        # Shift all positions randomly
        # TODO(dehghani): What would be reasonable for max number of shift?
        index = mtf.random_uniform(
            context.mesh, [], maxval=x.shape.dims[1].size, dtype=tf.int32)
    elif self.position_start_index == "step":
        # Shift positions based on the step
        if self.recurrence_type == "act":
            num_steps = self.act_max_steps
        else:
            num_steps = self.num_rec_steps
        index = mtf.cast(x.shape.dims[1].size * step / num_steps, dtype=tf.int32)

    length = context.length_dim
    channels = context.model.model_dim
    signal = self.get_timing_signal_1d(context, length, channels, start_index=index)

    if self.add_or_concat_timing_signal == "add":
        x_with_timing = x + mtf.cast(signal, x.dtype)
    # Unimplemented
    if self.add_or_concat_timing_signal == "concat":
        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        x_with_timing = mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])

    return x_with_timing
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
    if noising_spec["type"] == "mask":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
        return targets * mtf.cast(
            mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                        noising_spec["prob"]), targets.dtype)
    elif noising_spec["type"] == "random_zipfian":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens.
        # Rather than drawing the replacement tokens uniformly, we sample from
        # a distribution favoring lower token-ids, assuming that the ids have
        # been assigned in frequency order. The probability of choosing an
        # id is proportional to 1/(id+10).
        logits = mtf.log(1.0 / (mtf.range(
            targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
        logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
        r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
        use_noise = mtf.less(
            mtf.random_uniform(targets.mesh, targets.shape), noising_spec["prob"])
        return mtf.where(use_noise, r, targets)
    elif noising_spec["type"] == "transformer":
        # Train a small transformer to fill in masked out values, then
        # sample from it.
        hparams = self._hparams
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            raise NotImplementedError("Not implemented")
        noiser_hparams = copy.copy(self._hparams)
        noiser_hparams.del_hparam("mode")
        noiser_hparams.override_from_dict(noising_spec["overrides"])
        with tf.variable_scope("noiser"):
            noiser = MtfTransformer(
                noiser_hparams,
                mode=hparams.mode,
                problem_hparams=self._problem_hparams)
            logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                self._original_features, targets.mesh)
            samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
            losses.append(loss)
        return samples
    else:
        raise ValueError("unknown noising spec %s" % noising_spec)
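# Illustrative sketch (not part of the original code): the "random_zipfian" branch
# above in plain NumPy. Each position is replaced with probability `prob`, and the
# replacement id is drawn with probability proportional to 1/(id + 10). The helper
# name is hypothetical.
import numpy as np

def random_zipfian_noise_numpy(targets, vocab_size, prob, rng=np.random):
    p = 1.0 / (np.arange(vocab_size) + 10.0)
    p /= p.sum()
    replacements = rng.choice(vocab_size, size=targets.shape, p=p)
    use_noise = rng.uniform(size=targets.shape) < prob
    return np.where(use_noise, replacements, targets)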
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    with tf.variable_scope(scope):
        mesh, batch, seq, dim = x.mesh, *x.shape

        dim_heads = mtf.Dimension('dim_heads', dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = map(lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]),
                      (q, k, v))
        q, k, v = map(lambda t: mtf.transpose(t, [batch, dim_head, seq, dim_features_head]),
                      (q, k, v))

        k, v = map(lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'), (k, v))
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum([q, k], [batch, dim_head, seq, mem_len_dim])

        if causal:
            i = mtf.range(mesh, seq, tf.int32)
            j = mtf.range(mesh, mem_len_dim, tf.int32)
            i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
            mask = mtf.less(i + mem_len_dim.size - seq.size, j)
            mask = mtf.cast(mask, tf.float32) * -1e10
            dots += mask

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

        out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
        out = mtf.reshape(out, [batch, seq, dim_heads])

        combined_out = linear(out, dim, scope='combine_output')
        return combined_out
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq,
         memory_length_dim, variable_dtype, context=None):
    # x :: [batch, seq, n_embd]
    x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh

    # n_state is the same as config["n_embd"], which is also the same as dim_embd.
    assert n_state.size % params["n_head"] == 0

    dim_heads = mtf.Dimension("heads", params["n_head"])

    num_mem_kv = params.get("num_mem_kv", 0)
    use_num_mem_kv = num_mem_kv > 0

    with tf.variable_scope(scope):
        # Compute attention inputs
        dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"])
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=dim_embd,
            kv_dim=dim_kv,
            heads_dim=dim_heads,
            variable_dtype=variable_dtype
        )
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if is_incremental_inference(context):
            one_hot = mtf.one_hot(context.position - 1, dim_seq,
                                  dtype=variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(context):
            context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking,
                # so we don't need mask_attn_weights.
                radius = params.get("local_attention_radius", 256)
                if is_incremental_inference(context):
                    q *= one_hot
                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=params["causal"],
                    attention_kwargs={},
                )
                if is_incremental_inference(context):
                    a = mtf.gather(a, context.position - 1, dim_seq)

            elif attention_type == "global":
                # TODO: pass in fake context
                # Broadcast mask bias across batch and heads
                if exists(bias):
                    if not is_incremental_inference(context):
                        broadcasted_bias = mtf.broadcast(
                            bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]])
                    else:
                        # In the incremental case, a custom mask needs to be built that
                        # masks out all key/values that are greater than the current position.
                        bias = mtf.gather(bias, context.position - 1, dim_seq)
                        broadcasted_bias = mtf.broadcast(
                            bias, [dim_batch, dim_heads, bias.shape[-1]])

                # memory key / values, from all-attention paper
                if use_num_mem_kv:
                    k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads,
                                             variable_dtype, mesh)

                k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim)
                v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim)

                attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=memory_length_dim,
                    key_dim=dim_kv,
                    value_dim=dim_kv,
                    bias=broadcasted_bias,
                    dropout_rate=attn_dropout_rate
                )

            elif attention_type == "linear":
                linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention
                a = linear_attn_fn(q, k, v)

            else:
                raise NotImplementedError("Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(x.mesh, "o_b", [dim_embd],
                                 initializer=tf.constant_initializer(0),
                                 master_dtype=variable_dtype.master_dtype,
                                 slice_dtype=variable_dtype.slice_dtype,
                                 activation_dtype=variable_dtype.activation_dtype)
            a += b

        if params["mode"] == "train" and params["res_dropout"] > 0:
            a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout")
        return a
def _call_internal(self, context, inputs, targets=None, attributes=None, z=None):
    """Compute logits based on inputs (all positions in parallel).

    Also updates context if applicable.

    Args:
      context: a Context
      inputs: a Tensor
      targets: an optional Tensor
      attributes: an optional Tensor
      z: an optional Tensor

    Returns:
      logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim]
    """
    mesh = inputs.mesh
    if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims:
        # Training an ensemble where all models are trained on the same examples.
        inputs = mtf.broadcast(inputs, [self.ensemble_dim] + inputs.shape.dims)
        if self.ensemble_dim not in attributes.shape.dims:
            attributes = mtf.broadcast(attributes,
                                       [self.ensemble_dim] + attributes.shape.dims)
        if targets:
            targets = mtf.broadcast(targets,
                                    [self.ensemble_dim] + targets.shape.dims)

    if "embedding" in context.shared_params:
        vocab_embedding = context.shared_params["embedding"]
    else:
        vocab_embedding = VocabEmbedding(
            mesh,
            self.input_vocab_dim,
            self.model_dim,
            context.variable_dtype,
            name="embedding",
            ensemble_dim=self.ensemble_dim)
    x = vocab_embedding.ids_to_embedding(inputs)

    if self.positional_embedding:
        if "positional_embedding" in context.shared_params:
            pos_emb_var = context.shared_params["positional_embedding"]
        else:
            pos_emb_var = mtf.layers.embedding_weights(
                mesh, self.max_length_dim, self.model_dim,
                context.variable_dtype, "positional_embedding",
                ensemble_dim=self.ensemble_dim)
        if (context.length_dim is not None and
                context.length_dim.size > self.max_length_dim.size):
            message = (
                "Length dimension exceeds size of positional embedding table. "
                "length_dim.size > max_length_dim.size %s vs %s."
                % (context.length_dim, self.max_length_dim))
            if context.position_is_default:
                # Definitely getting overflow in this case.
                raise ValueError(message)
            else:
                tf.logging.warning(
                    message +
                    " This may be OK if there are several shorter sequences packed "
                    "together. Otherwise, the later positions will get zeros.")
        if context.position_is_default:
            pos_emb = mtf.rename_dimension(
                mtf.slice(pos_emb_var, 0, context.length_dim.size,
                          self.max_length_dim.name),
                self.max_length_dim.name, context.length_dim.name)
        else:
            pos_emb = mtf.gather(
                pos_emb_var, context.position, self.max_length_dim,
                output_shape=x.shape)
        x += pos_emb

    if self.attribute_embedding:
        if "attribute_embedding" in context.shared_params:
            att_emb_var = context.shared_params["attribute_embedding"]
        else:
            att_emb_var = mtf.layers.embedding_weights(
                mesh, self.attribute_dim, self.model_dim,
                context.variable_dtype, "attribute_embedding",
                ensemble_dim=self.ensemble_dim)
        att_emb = mtf.gather(
            att_emb_var, attributes, self.attribute_dim, output_shape=x.shape)
        # Addition of x and attribute
        # x *= LAMBDA_ATTRIBUTE * sty_emb

        # Concatenation of x and attribute
        x_attribute = mtf.concat([x, att_emb], self.model_dim.name)
        x = mtf.layers.dense(x_attribute, self.model_dim, activation=None,
                             variable_dtype=context.variable_dtype,
                             name="comb_x_attribute")

    if z:
        z = mtf.layers.dense(z, self.model_dim, activation=None,
                             variable_dtype=context.variable_dtype, name="z")
        # raise ValueError("x shape=%s , z shape=%s" % (x.shape, z.shape))
        x += z

    x = self.layer_stack.call(context, x)
    if self.output_vocab_dim is None:
        return x
    if self.shared_embedding_and_softmax_weights:
        logits = vocab_embedding.hidden_to_logits(x)
    else:
        logits = mtf.layers.dense(
            x, self.output_vocab_dim, use_bias=False,
            variable_dtype=context.variable_dtype,
            reduced_dims=x.shape.dims[-1:],
            name="logits")
    if targets is not None and context.losses is not None:
        context.losses.append(
            self._compute_loss(context, logits, targets, self.output_vocab_dim))
    if self.ensemble_dim:
        logits = reduce_ensemble_logits(
            logits, self.ensemble_dim, self.output_vocab_dim)
    return logits
def attention(self, x, n_state, mask, attention_type="global", name="attn"):
    # x :: [batch, seq, n_embd]
    batch_dim, seq_dim, embd_dim = x_shape = x.shape

    assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads"

    with tf.variable_scope(name):
        # Compute attention inputs
        mtfparams = mtf.transformer.attention.attention_params_simple(
            x.mesh,
            io_dim=self.dimensions["embed_dim"],
            kv_dim=self.dimensions["kv_dim"],
            heads_dim=self.dimensions["heads_dim"],
            variable_dtype=self.variable_dtype)
        q = mtfparams.compute_q(x)
        k = mtfparams.compute_k(x)
        v = mtfparams.compute_v(x)

        if self.is_incremental_inference:
            one_hot = mtf.one_hot(self.context.position - 1, seq_dim,
                                  dtype=self.variable_dtype.master_dtype)
            inv_one_hot = 1.0 - one_hot
            old_k, old_v = self.context.get_states(2)
            k = old_k * inv_one_hot + k * one_hot
            v = old_v * inv_one_hot + v * one_hot

        if exists(self.context):
            self.context.record_new_states([k, v])

        with tf.variable_scope("attention"):
            if attention_type == "local":
                # `local_attention_1d` has built in autoregressive masking,
                # so we don't need mask_attn_weights.
                radius = self.params.get("local_attention_radius", 256)
                if self.is_incremental_inference:
                    q *= one_hot
                a = mtf_transformer.attention.local_attention_1d(
                    q, k, v,
                    length_dim=k.shape[1],
                    key_dim=self.dimensions["kv_dim"],
                    value_dim=self.dimensions["kv_dim"],
                    radius=radius,
                    length_dim_num_splits=1,
                    fully_autoregressive=True,
                    attention_kwargs={},
                )
                if self.is_incremental_inference:
                    a = mtf.gather(a, self.context.position - 1, seq_dim)

            elif attention_type == "global":
                if exists(mask):
                    if not self.is_incremental_inference:
                        broadcasted_mask = mtf.broadcast(
                            mask,
                            [batch_dim, self.dimensions["heads_dim"],
                             mask.shape[-2], mask.shape[-1]])  # TODO: not sure this is correct
                    else:
                        # In the incremental case, a custom mask needs to be built that
                        # masks out all key/values that are greater than the current position.
                        mask = mtf.gather(mask, self.context.position - 1, seq_dim)
                        broadcasted_mask = mtf.broadcast(
                            mask,
                            [batch_dim, self.dimensions["heads_dim"], mask.shape[-1]])

                k = mtf.replace_dimensions(k, k.shape[1], self.dimensions["memory_len_dim"])
                v = mtf.replace_dimensions(v, v.shape[1], self.dimensions["memory_len_dim"])

                attn_dropout_rate = self.params.get(
                    "attention_dropout", 0) if self.mode == "train" else 0

                a = mtf_transformer.attention.attention(
                    q, k, v,
                    memory_length_dim=self.dimensions["memory_len_dim"],
                    key_dim=self.dimensions["kv_dim"],
                    value_dim=self.dimensions["kv_dim"],
                    bias=broadcasted_mask,
                    dropout_rate=attn_dropout_rate)

            else:
                raise NotImplementedError(
                    "Unknown attention type {}!".format(attention_type))

        with tf.variable_scope("compute_output"):
            a = mtfparams.compute_output(a, x_shape)

        with tf.variable_scope("compute_output_bias"):
            b = mtf.get_variable(
                x.mesh, "o_b", [embd_dim],
                initializer=tf.constant_initializer(0),
                master_dtype=self.variable_dtype.master_dtype,
                slice_dtype=self.variable_dtype.slice_dtype,
                activation_dtype=self.variable_dtype.activation_dtype)
            a += b

        residual_dropout = self.params.get("residual_dropout", 0)
        if self.mode == "train" and residual_dropout > 0:
            a = mtf.dropout(a, rate=residual_dropout, name="res_dropout")

        return a