def create_positional_emb_2d(self, targets):
  """Learned 2d positional embedding for images."""
  mesh = targets.mesh
  positional_emb_rows_var = mtf.get_variable(
      mesh, "positional_emb_rows",
      mtf.Shape([self.pos_dim, self.model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=self.activation_type)
  positional_emb_cols_var = mtf.get_variable(
      mesh, "positional_emb_cols",
      mtf.Shape([self.pos_dim, self.model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=self.activation_type)
  targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
  targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
  position_x = mtf.broadcast(
      mtf.gather(positional_emb_rows_var, targets_position_x, self.pos_dim),
      mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
  position_y = mtf.broadcast(
      mtf.gather(positional_emb_cols_var, targets_position_y, self.pos_dim),
      mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
  return position_x + position_y
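# NOTE: illustrative sketch, not part of the original code. To make the
# broadcast above explicit, here is the same additive 2D scheme in plain
# NumPy (sizes are illustrative): position (i, j) receives the sum of a
# learned row vector and a learned column vector.
import numpy as np

rows, cols, d_model = 32, 32, 8
emb_rows = np.random.randn(rows, d_model)  # one learned vector per row index
emb_cols = np.random.randn(cols, d_model)  # one learned vector per column index

# Broadcast-add so that pos_emb_2d[i, j] == emb_rows[i] + emb_cols[j].
pos_emb_2d = emb_rows[:, None, :] + emb_cols[None, :, :]
assert pos_emb_2d.shape == (rows, cols, d_model)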
def _embedding_and_softmax_vars(self, mesh):
  hparams = self._hparams
  if hparams.transformer_type == "encoder":
    targets_embedding_var = None
  else:
    targets_embedding_var = mtf.get_variable(
        mesh, "targets_embedding",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(),
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype,
        activation_dtype=self.activation_dtype)
  if hparams.transformer_type == "decoder":
    inputs_embedding_var = None
  else:
    if hparams.shared_embedding and targets_embedding_var:
      inputs_embedding_var = targets_embedding_var
    else:
      inputs_embedding_var = mtf.get_variable(
          mesh, "inputs_embedding",
          mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
          initializer=tf.random_normal_initializer(),
          master_dtype=self.master_dtype,
          slice_dtype=self.slice_dtype,
          activation_dtype=self.activation_dtype)
  if hparams.shared_embedding_and_softmax_weights:
    softmax_var = (targets_embedding_var or inputs_embedding_var) * (
        self.model_dim.size ** -0.5)
  else:
    softmax_var = mtf.get_variable(
        mesh, "softmax",
        mtf.Shape([self.targets_vocab_dim, self.model_dim]),
        initializer=tf.random_normal_initializer(
            stddev=self.model_dim.size**-0.5),
        master_dtype=self.master_dtype,
        slice_dtype=self.slice_dtype,
        activation_dtype=self.activation_dtype)
  positional_embedding_var = mtf.get_variable(
      mesh, "positional_embedding",
      mtf.Shape([self.max_length_dim, self.model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=self.activation_dtype)
  return (inputs_embedding_var, targets_embedding_var, softmax_var,
          positional_embedding_var)
def synthetic_attention(q, k, v, memory_length_dim, key_dim, value_dim, bias=None, dropout_rate=0.0, dropout_broadcast_dims=None, extra_logit=None, synthesize=True, synthesize_mode="random_plus_alpha", factorized_dim=16, max_length=512, context=None): """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743). key_dim is a Dimension representing the channels in the queries and keys value_dim is a Dimension representing the channels in values memory_length_dim is a Dimension representing the different key/value pairs. Dimensions of q: other_query_dims + {key_dim} Dimensions of k: other_memory_dims + {memory_length_dim, key_dim} Dimensions of v: other_memory_dims + {memory_length_dim, value_dim} other_memory_dims is a subset of other_query_dims Typically, other_query_dims={batch, heads, length} Typically, other_memory_dims={batch, heads} Args: q: a Tensor k: a Tensor v: a Tensor memory_length_dim: a Dimension key_dim: a Dimension value_dim: a Dimension bias: a Tensor to be added into the attention logits. dropout_rate: a float. dropout_broadcast_dims: an optional list of mtf.Dimension extra_logit: an optional scalar or tensor synthesize: flag to use synthetic attention or not synthesize_mode: which variant of synthesizer to use factorized_dim: factorized dim for synthesizers max_length: max length of input sequence context: context since we need context mode Returns: Tensor with shape q.shape - key_dim + value_dim """ if synthesize: num_heads = v.shape.get_dim_by_name("heads") tf.logging.info("Using synthesizer") if synthesize_mode == "random": tf.logging.info("Using Random Synthesizers") r_shape = mtf.Shape([mtf.Dimension("length", max_length), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", max_length)]) r = mtf.get_variable(context.mesh, "R", r_shape, initializer=None, dtype=context.variable_dtype) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") logits = r r_shape = logits.shape elif synthesize_mode == "factorized": tf.logging.info("Using Factorized Random Synthesizers") k = factorized_dim r1_shape = mtf.Shape([mtf.Dimension("tmp", k), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r2_shape = mtf.Shape([mtf.Dimension("tmp", k), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r_shape = mtf.Shape([mtf.Dimension("length", 512), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r1 = mtf.get_variable(context.mesh, "R1", r1_shape, initializer=None, dtype=context.variable_dtype) r2 = mtf.get_variable(context.mesh, "R2", r2_shape, initializer=None, dtype=context.variable_dtype) r = mtf.einsum([r1, r2], r_shape) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") logits = r elif synthesize_mode == "dense_minus": # Dense Synthesizer Model tmp_dim = mtf.Dimension("memory_length", max_length) logits = mtf.layers.dense(mtf.relu(q), [tmp_dim], use_bias=False, name="pi", reduced_dims=[key_dim], variable_dtype=None) logits = mtf.slice(logits, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": pass else: 
length_dim = q.shape.get_dim_by_name("length") logits = mtf.slice(logits, 0, length_dim.size, "length") elif synthesize_mode == "random_plus_alpha" or \ synthesize_mode == "random_plus": # Mixture Random Synthesizer with learnable Alpha tf.logging.info("Using Random Plus Alpha") logits = mtf.einsum([q, k], reduced_dims=[key_dim]) num_heads = logits.shape.get_dim_by_name("heads") r_shape = mtf.Shape([mtf.Dimension("length", 512), mtf.Dimension("heads", num_heads.size), mtf.Dimension("memory_length", 512)]) r = mtf.get_variable(context.mesh, "R", r_shape, initializer=None, dtype=context.variable_dtype) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": r = mtf.gather(r, context.position, r.shape.get_dim_by_name("length")) else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, length_dim.name) if "alpha" in synthesize_mode: alpha = mtf.get_variable(context.mesh, "alpha", mtf.Shape([mtf.Dimension("alpha", 1)]), initializer=tf.zeros_initializer(), dtype=context.variable_dtype) alpha = mtf.sigmoid(alpha) logits = ((1-alpha) * logits) + (alpha * r) else: logits = logits + r elif synthesize_mode == "dense_plus_alpha" or \ synthesize_mode == "dense_plus": # Mixture Dense Synthesizer with learnable alpha tf.logging.info("Using Dense Plus Alpha Scaling") logits = mtf.einsum([q, k], reduced_dims=[key_dim]) tmp_dim = mtf.Dimension("memory_length", 512) r = mtf.layers.dense(mtf.relu(q), [tmp_dim], use_bias=False, name="pi", reduced_dims=[key_dim], variable_dtype=None) r = mtf.slice(r, 0, memory_length_dim.size, memory_length_dim.name) if context.mode == "incremental": pass else: length_dim = q.shape.get_dim_by_name("length") r = mtf.slice(r, 0, length_dim.size, "length") if "alpha" in synthesize_mode: alpha = mtf.get_variable(context.mesh, "alpha", mtf.Shape([mtf.Dimension("alpha", 1)]), initializer=tf.zeros_initializer(), dtype=context.variable_dtype) alpha = mtf.sigmoid(alpha) logits = ((1-alpha) * logits) + (alpha * r) else: logits = logits + r if bias is not None: logits += bias weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit) weights = mtf.dropout( weights, context.train, 1.0 - dropout_rate, noise_shape=weights.shape - dropout_broadcast_dims) if synthesize and "plus" not in synthesize_mode: if synthesize_mode == "dense_minus": outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim]) else: outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim]) else: outputs_shape = q.shape - [key_dim] + value_dim outputs = mtf.einsum([weights, v], outputs_shape) return outputs
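# NOTE: illustrative sketch, not part of the original code. A single-head,
# unmasked NumPy rendering of the "random_plus_alpha" branch above: a learned
# content-free logit matrix R is blended with ordinary dot-product logits
# through a sigmoid-gated alpha before the softmax over the memory axis.
# Sizes are illustrative; in the real model R and alpha are trainable.
import numpy as np

def _softmax(x, axis=-1):
  x = x - x.max(axis=axis, keepdims=True)
  e = np.exp(x)
  return e / e.sum(axis=axis, keepdims=True)

length, d_k, d_v = 6, 4, 4
q = np.random.randn(length, d_k)
k = np.random.randn(length, d_k)
v = np.random.randn(length, d_v)

dot_logits = q @ k.T                  # content-based logits [length, memory_length]
r = np.random.randn(length, length)   # synthetic (content-free) logits

alpha = 1.0 / (1.0 + np.exp(-0.0))    # sigmoid of a zero-initialized alpha -> 0.5
mixed_logits = (1.0 - alpha) * dot_logits + alpha * r

out = _softmax(mixed_logits) @ v      # [length, d_v]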
def attention(self, x, n_state, mask, attention_type="global", name="attn"): # x :: [batch, seq, n_embd] batch_dim, seq_dim, embd_dim = x_shape = x.shape assert n_state.size % self.n_heads == 0, "n_state must be divisible by n_heads" with tf.variable_scope(name): # Compute attention inputs mtfparams = mtf.transformer.attention.attention_params_simple( x.mesh, io_dim=self.dimensions["embed_dim"], kv_dim=self.dimensions["kv_dim"], heads_dim=self.dimensions["heads_dim"], variable_dtype=self.variable_dtype) q = mtfparams.compute_q(x) k = mtfparams.compute_k(x) v = mtfparams.compute_v(x) if self.is_incremental_inference: one_hot = mtf.one_hot(self.context.position - 1, seq_dim, dtype=self.variable_dtype.master_dtype) inv_one_hot = 1.0 - one_hot old_k, old_v = self.context.get_states(2) k = old_k * inv_one_hot + k * one_hot v = old_v * inv_one_hot + v * one_hot if exists(self.context): self.context.record_new_states([k, v]) with tf.variable_scope("attention"): if attention_type == "local": # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. radius = self.params.get("local_attention_radius", 256) if self.is_incremental_inference: q *= one_hot a = mtf_transformer.attention.local_attention_1d( q, k, v, length_dim=k.shape[1], key_dim=self.dimensions["kv_dim"], value_dim=self.dimensions["kv_dim"], radius=radius, length_dim_num_splits=1, fully_autoregressive=True, attention_kwargs={}, ) if self.is_incremental_inference: a = mtf.gather(a, self.context.position - 1, seq_dim) elif attention_type == "global": if exists(mask): if not self.is_incremental_inference: broadcasted_mask = mtf.broadcast( mask, [ batch_dim, self.dimensions["heads_dim"], mask.shape[-2], mask.shape[-1] ]) # TODO: not sure this is correct else: # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position mask = mtf.gather(mask, self.context.position - 1, seq_dim) broadcasted_mask = mtf.broadcast( mask, [ batch_dim, self.dimensions["heads_dim"], mask.shape[-1] ]) k = mtf.replace_dimensions( k, k.shape[1], self.dimensions["memory_len_dim"]) v = mtf.replace_dimensions( v, v.shape[1], self.dimensions["memory_len_dim"]) attn_dropout_rate = self.params.get( "attention_dropout", 0) if self.mode == "train" else 0 a = mtf_transformer.attention.attention( q, k, v, memory_length_dim=self.dimensions["memory_len_dim"], key_dim=self.dimensions["kv_dim"], value_dim=self.dimensions["kv_dim"], bias=broadcasted_mask, dropout_rate=attn_dropout_rate) else: raise NotImplementedError( "Unknown attention type {}!".format(attention_type)) with tf.variable_scope("compute_output"): a = mtfparams.compute_output(a, x_shape) with tf.variable_scope("compute_output_bias"): b = mtf.get_variable( x.mesh, "o_b", [embd_dim], initializer=tf.constant_initializer(0), master_dtype=self.variable_dtype.master_dtype, slice_dtype=self.variable_dtype.slice_dtype, activation_dtype=self.variable_dtype.activation_dtype) a += b residual_dropout = self.params.get("residual_dropout", 0) if self.mode == "train" and residual_dropout > 0: a = mtf.dropout(a, rate=residual_dropout, name="res_dropout") return a
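# NOTE: illustrative sketch, not part of the original code. The
# one_hot / inv_one_hot arithmetic in the incremental-inference branch above
# is simply a masked write into the key/value cache; the same idea in NumPy
# with illustrative shapes (the model indexes with context.position - 1).
import numpy as np

seq_len, d = 8, 4
position = 3                                  # current decode step

old_k = np.random.randn(seq_len, d)           # cached keys for all positions
new_k = np.random.randn(seq_len, d)           # fresh keys; only row `position` matters

one_hot = np.eye(seq_len)[position][:, None]  # [seq_len, 1], 1.0 at the current slot
inv_one_hot = 1.0 - one_hot

# Overwrite only the current position's slot in the cache, keep everything else.
k = old_k * inv_one_hot + new_k * one_hot
assert np.allclose(k[position], new_k[position])
assert np.allclose(np.delete(k, position, 0), np.delete(old_k, position, 0))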
def compute_bias(self, context, memory_position, x): """Compute attention bias. Args: context: a transformer.Context memory_position: an int32 tensor containing memory_length dimension. x: a Tensor - the query antecedent - required for relative attention Returns: a Tensor or None """ min_relative_position = self.min_relative_position(context) max_relative_position = self.max_relative_position(context) # we can often cache the result of this function between similar layers can_cache = (self.relative_attention_type is None or self.relative_attention_type == "bias_shared") if can_cache: cache_key = ("self_attention_mask", min_relative_position, max_relative_position, self.relative_attention_type, self.num_heads) if cache_key in context.cache: return context.cache[cache_key] biases = [] relative_position = memory_position - context.position if min_relative_position is not None: visible = mtf.greater_equal(relative_position, min_relative_position) biases.append( attention.visibility_mask_to_attention_bias( visible, context.activation_dtype)) if max_relative_position is not None: visible = mtf.less_equal(relative_position, max_relative_position) biases.append( attention.visibility_mask_to_attention_bias( visible, context.activation_dtype)) if context.read_priority is not None: visible = mtf.greater_equal( context.read_priority, mtf.layers.rename_length_to_memory_length( context.write_priority)) biases.append( attention.visibility_mask_to_attention_bias( visible, context.activation_dtype)) sequence_id = None # Subsequence id should only be set if we are in the decoder and have # multiple targets per input. This will allow each sub-target to only attend # to itself. if isinstance(context.subsequence_id, mtf.Tensor): sequence_id = context.subsequence_id elif isinstance(context.sequence_id, mtf.Tensor): sequence_id = context.sequence_id if (sequence_id is not None and context.length_dim in sequence_id.shape): visible = mtf.equal( sequence_id, self.rename_length_to_memory_length(sequence_id, context)) biases.append( attention.visibility_mask_to_attention_bias( visible, context.activation_dtype)) if self.relative_attention_type is not None: buckets_dim = mtf.Dimension("buckets", self.relative_attention_num_buckets) heads_dim = mtf.Dimension("heads", self.num_heads) bidirectional = not context.model.fully_autoregressive rp_bucket = _relative_position_bucket(relative_position, bidirectional=bidirectional, num_buckets=buckets_dim.size) if (self.relative_attention_type == "bias" or self.relative_attention_type == "bias_shared"): bias_shape = [heads_dim, buckets_dim] if context.model.ensemble_dim: bias_shape = [context.model.ensemble_dim] + bias_shape values = mtf.get_variable(context.mesh, "relative_attention_bias", bias_shape, dtype=context.variable_dtype) elif self.relative_attention_type == "contextual": if context.model.ensemble_dim: expert_dims = [context.model.ensemble_dim] else: expert_dims = None values = layers.dense(x, [buckets_dim, heads_dim], variable_dtype=context.variable_dtype, name="relative_attention_contextual", expert_dims=expert_dims) else: raise ValueError( "unrecognized relative_attention_type \"%s\"" % self.relative_attention_type) biases.append(mtf.gather(values, rp_bucket, buckets_dim)) ret = mtf.add_n(biases) if biases else None if can_cache: context.cache[cache_key] = ret return ret
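# NOTE: illustrative sketch, not part of the original code. For intuition,
# this is what a visibility mask converted to an additive attention bias looks
# like in NumPy; the exact negative constant used by
# attention.visibility_mask_to_attention_bias may differ, -1e9 is a stand-in.
# With max_relative_position = 0 the result is an ordinary causal mask.
import numpy as np

length = 5
query_pos = np.arange(length)[:, None]    # analogue of context.position
memory_pos = np.arange(length)[None, :]   # analogue of memory_position
relative_position = memory_pos - query_pos

visible = relative_position <= 0          # max_relative_position = 0

# Visible pairs add 0 to the logits; hidden pairs add a large negative value,
# so they vanish after the softmax.
bias = np.where(visible, 0.0, -1e9)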
def __init__(self, mesh: mtf.Mesh, vocab_dim: mtf.Dimension, output_dim: mtf.Dimension, variable_dtype: mtf.VariableDType, name: str, ensemble_dim: mtf.Dimension, extra_ids: int = 0, dropout_rate: float = 0.0, gate_embedding_size: int = gin.REQUIRED, frequent_token_fraction: float = 0.1, noise_std_dev: float = 0.0): """Configurable embedding for the vocabulary. Most of the arguments get passed to `mtf.layers.embedding_weights`. Mixtape shares gates for low frequency tokens to improve efficiency. Since our vocabs are sorted in decreasing order of frequency with sentinels appended to the end, we need to do a little trick to ensure that the sentinels are treated as high frequency. If you want to treat the sentinels as low frequency tokens, then pass in zero for `extra_ids`. Args: mesh: the mesh used to layout the tensors. vocab_dim: the dimension corresponding to vocabulary. output_dim: the dimension corresponding to the model hidden states. variable_dtype: the datatype information for the variables used in the embedding tensors. name: a name to base variable names off of. ensemble_dim: the dimension used for ensembling. Absolutely no guarantees that this code will work with ensembling. extra_ids: a non-negative integer, the number of sentinels at the end of the vocab. dropout_rate: a float between 0 and 1, the rate to use for dropout. gate_embedding_size: a positive integer, the size to use for embedding for the gates. It is usually chosen to be much smaller than d_model. frequent_token_fraction: a float between 0 and 1, what fraction of tokens to consider as high frequency and not share gates for. noise_std_dev: a non-negative float, the standard deviation of the Gaussian noise to add to the pre-activation priors. """ self._extra_ids = extra_ids self._dropout_rate = dropout_rate self._noise_std_dev = noise_std_dev self._mesh = mesh self._vocab_dim = vocab_dim self._frequent_vocab_dim = mtf.Dimension( vocab_dim.name, int(frequent_token_fraction * vocab_dim.size)) self._rare_vocab_dim = mtf.Dimension( vocab_dim.name, vocab_dim.size - self._frequent_vocab_dim.size) self._output_dim = output_dim self._copy_output_dim = mtf.Dimension("_{}_copy".format(output_dim.name), output_dim.size) self._pre_gates_dim = mtf.Dimension("gates", 3) self._gates_dim = mtf.Dimension("gates", 4) self._gate_embedding_dim = mtf.Dimension("gate_embedding", gate_embedding_size) self._embedding_weights = mtf.layers.embedding_weights( mesh=mesh, vocab_dim=vocab_dim, output_dim=output_dim, variable_dtype=variable_dtype, name="{}_embedding_weights".format(name), ensemble_dim=ensemble_dim) ensemble_dims = [ensemble_dim] if ensemble_dim else [] self._context_weights = mtf.layers.embedding_weights( mesh=mesh, vocab_dim=self._copy_output_dim, output_dim=output_dim, variable_dtype=variable_dtype, name="{}_context_weights".format(name), ensemble_dim=ensemble_dims + [self._gates_dim]) self._context_weights_bias = mtf.get_variable( mesh, name="{}_context_weights_bias".format(name), shape=mtf.Shape(ensemble_dims + [self._gates_dim, output_dim]), dtype=variable_dtype, initializer=tf.zeros_initializer()) self._prior_weights = mtf.layers.embedding_weights( mesh=mesh, vocab_dim=self._gate_embedding_dim, output_dim=output_dim, variable_dtype=variable_dtype, name="{}_prior_weights".format(name), ensemble_dim=ensemble_dims + [self._pre_gates_dim]) self._prior_weights_bias = mtf.get_variable( mesh, name="{}_prior_weights_bias".format(name), shape=mtf.Shape(ensemble_dims + [self._pre_gates_dim, self._gate_embedding_dim]), 
dtype=variable_dtype, initializer=tf.zeros_initializer()) self._prior_vocab_vector = mtf.get_variable( mesh, name="{}_prior_vocab_vector".format(name), shape=mtf.Shape(ensemble_dims + [self._frequent_vocab_dim, self._gate_embedding_dim]), dtype=variable_dtype, initializer=tf.random_normal_initializer()) self._prior_gates_vector = mtf.get_variable( mesh, name="{}_prior_gates_vector".format(name), shape=mtf.Shape(ensemble_dims + [self._pre_gates_dim, output_dim]), dtype=variable_dtype, initializer=tf.random_normal_initializer()) self._prior_bias = mtf.get_variable( mesh, name="{}_prior_bias".format(name), shape=mtf.Shape(ensemble_dims + [self._frequent_vocab_dim, self._pre_gates_dim]), dtype=variable_dtype, initializer=tf.random_normal_initializer())
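# NOTE: illustrative sketch, not part of the original code. The mismatch
# between _pre_gates_dim (3) and _gates_dim (4) comes from Mixtape's
# sigmoid-tree gating, in which three sigmoids parameterize four gate
# probabilities that sum to one. The exact tree wiring is not shown in the
# constructor above, so treat this NumPy version as an assumed illustration.
import numpy as np

def _sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

pre_gates = np.random.randn(3)            # three pre-activation gate logits
s1, s2, s3 = _sigmoid(pre_gates)

gates = np.array([                        # four gate probabilities
    s1 * s2,
    s1 * (1.0 - s2),
    (1.0 - s1) * s3,
    (1.0 - s1) * (1.0 - s3),
])
assert np.isclose(gates.sum(), 1.0)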
def add_var(self, name, *shape):
  # NOTE: mtf.get_variable requires a mesh as its first argument; the original
  # snippet omitted it, and `self.mesh` is assumed to hold it here.
  return mtf.get_variable(
      self.mesh,
      name,
      shape=shape,
      initializer=create_initializer(),
  )
def _layer_stack(self, x, layers, encoder_output=None, self_attention_mask=None, encdec_attention_mask=None, losses=None, step_num=None, encdec_tensors=None, self_attention_k=None, self_attention_v=None): """Encoder or decoder stack. Args: x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] layers: an list of strings encoder_output: an optional mtf.Tensor with shape [<batch_dims>, encoder_length_dim, model_dim] self_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, memory_length_dim] containing values 0 or -inf. encdec_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, encoder_length_dim] containing values 0 or -inf. losses: a list to be appended-to step_num: an optional mtf integer Scalar (used in incrmenental mode) encdec_tensors: an optional list of num_layers tuples, each of the form (q_var, o_var, k, v), (used in incremental mode) self_attention_k: an optional list of num_layers Tensors each with shape [batch, heads, memory_length, kv_channels] (incremental mode) self_attention_v: an optional list of num_layers Tensors each with shape [batch, heads, memory_length, kv_channels] (incremental mode) Returns: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] Raises: ValueError: if hparams make no sense """ hparams = self._hparams is_incremental = (step_num is not None) def layer_prepostprocess_dropout(x): if is_incremental: return x return mtf.dropout( x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) num_layers = len(layers) num_layer_norms = num_layers + 1 layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms) layer_norm_combined_var = mtf.get_variable( x.mesh, "layer_norm_scale", mtf.Shape([layer_norms_dim, self.model_dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype) layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim) def normalize(x): scale = layer_norm_vars.pop(0) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale if is_incremental: new_self_attention_k = [] new_self_attention_v = [] for lnum, layer_type in enumerate(layers): with tf.variable_scope("%s_%d" % (layer_type, lnum)): if layer_type == "att": # Self attention layer if is_incremental: self_att_num = len(new_self_attention_k) y, new_k, new_v = mtf.layers.multihead_self_attention_incremental( normalize(x), prev_k=self_attention_k[self_att_num], prev_v=self_attention_v[self_att_num], step_num=step_num, master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="att") new_self_attention_k.append(new_k) new_self_attention_v.append(new_v) x += y else: x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), None, self_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="att")) elif layer_type == "enc_att": # Encoder-Decoder attention layer if is_incremental: # Encoder-Decoder attention layer q_var, o_var, k, v = encdec_tensors[lnum] x += mtf.layers.multihead_encdec_attention_incremental( normalize(x), q_var, o_var, k, v, encdec_attention_mask, name="enc_att") else: x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), encoder_output, encdec_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], 
master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="enc_att")) else: if is_incremental: # insert length dimension. x_shape = x.shape shape_with_length = mtf.Shape( x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:]) x = mtf.reshape(x, shape_with_length) # ffn layer x += layer_prepostprocess_dropout( self._feedforward_layer(normalize(x), layer_type, losses=losses)) if is_incremental: # remove length dimension x = mtf.reshape(x, x_shape) x = layer_prepostprocess_dropout(normalize(x)) assert not layer_norm_vars if is_incremental: return x, new_self_attention_k, new_self_attention_v else: return x
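# NOTE: illustrative sketch, not part of the original code. The `normalize`
# closure in the layer stack above is an RMS-style norm (no mean subtraction)
# with one learned scale row per norm taken from `layer_norm_scale`. The same
# computation in NumPy, with an illustrative epsilon standing in for
# hparams.norm_epsilon:
import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
  variance = np.mean(np.square(x), axis=-1, keepdims=True)
  return x / np.sqrt(variance + epsilon) * scale

d_model = 8
x = np.random.randn(2, d_model)
scale = np.ones(d_model)                  # one row of layer_norm_scale
y = rms_norm(x, scale)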
def __init__(self, config, is_training, input_ids, input_mask=None, token_type_ids=None, scope=None, mesh_shape="", layout=""): self.config = copy.deepcopy(config) del config if not is_training: self.config.layer_output_dropout_prob = 0.0 self.config.attention_probs_dropout_prob = 0.0 self.config.feedforward_intermediate_dropout_prob = 0.0 input_shape = input_ids.shape assert input_shape.ndims == 2 self._seq_dim = input_shape.dims[1] self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size) self._extra_losses = [] mesh = input_ids.mesh if token_type_ids is None: token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32) with tf.variable_scope(scope, default_name="bert"): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. self.embedding_table = mtf.get_variable( mesh, "word_embeddings", mtf.Shape([self.vocab_dim, self.model_dim]), initializer=self.embedding_initializer) self.word_embedding_output = mtf.gather( self.embedding_table, input_ids, self.vocab_dim) # Add positional embeddings and token type embeddings, then layer # normalize and perform dropout. self.embedding_output = self.word_embedding_output token_type_table = mtf.get_variable( mesh, "token_type_embeddings", mtf.Shape([self.token_type_vocab_dim, self.model_dim]), initializer=self.embedding_initializer) if token_type_ids is not None: self.embedding_output += mtf.gather( token_type_table, token_type_ids, self.token_type_vocab_dim) if self.config.position_signal == "embedding": full_position_table = mtf.get_variable( mesh, "position_embeddings", mtf.Shape( [self.max_position_embeddings_dim, self.model_dim]), initializer=self.embedding_initializer) short_position_table = mtf.rename_dimension( mtf.slice(full_position_table, 0, self.seq_dim.size, self.max_position_embeddings_dim.name), self.max_position_embeddings_dim.name, self.seq_dim.name) self.embedding_output += short_position_table self.embedding_output = self.normalize(self.embedding_output) self.embedding_output = mtf.dropout( self.embedding_output, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) with tf.variable_scope("encoder"): attention_biases = [] if input_mask: # [batch_dim, memory_seq_dim] attention_biases.append((1.0 - mtf.to_float( mtf.replace_dimensions(input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0) if self.config.position_signal == "relative_attention_bias": buckets_dim = mtf.Dimension("buckets", 32) rp_bucket = _relative_position_bucket( mtf.range(mesh, self.memory_seq_dim, tf.int32) - mtf.range(mesh, self.seq_dim, tf.int32), num_buckets=buckets_dim.size) bias_var = mtf.get_variable( mesh, "relative_attention_bias", [self.num_heads_dim, buckets_dim], initializer=tf.zeros_initializer()) attention_biases.append( mtf.gather(bias_var, rp_bucket, buckets_dim)) attention_bias = mtf.add_n(attention_biases) prev_layer_output = self.embedding_output self.all_encoder_layers = [] for block_num in range(self.config.num_blocks): with tf.variable_scope("block_%d" % block_num): for layer_idx, layer_type in enumerate( self.config.block_layers): layer_name = layer_type count = self.config.block_layers[:layer_idx].count( layer_type) if count: layer_name += "_%d" % count with tf.variable_scope(layer_name): x = prev_layer_output if self.config.residual_structure == "direct": x = self.normalize(x) if layer_type == "attention": x = self.self_attention(x, attention_bias) elif layer_type == "feedforward": x = self.feedforward(x) elif layer_type == "moe": x = self.moe(x, layout, mesh_shape, input_mask, 
is_training) else: raise ValueError("unknown layer type " + layer_type) x = mtf.dropout( x, is_training, keep_prob=1.0 - self.config.layer_output_dropout_prob) layer_output = prev_layer_output + x if self.config.residual_structure == "original": layer_output = self.normalize(layer_output) prev_layer_output = layer_output self.all_encoder_layers.append(layer_output) self.sequence_output = prev_layer_output if self.config.residual_structure == "direct": self.sequence_output = self.normalize(self.sequence_output) # The "pooler" converts the encoded sequence tensor of shape # [batch_dim, seq_dim, hidden_size] to a tensor of shape # [batch_dim, hidden_size]. This is necessary for segment-level # (or segment-pair-level) classification tasks where we need a fixed # dimensional representation of the segment. with tf.variable_scope("pooler"): # We "pool" the model by simply taking the hidden state corresponding # to the first token. We assume that this has been pre-trained first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim) self.pooled_output = mtf.layers.dense( first_token_tensor, reduced_dims=[self.model_dim], new_dims=[self.model_dim], activation=mtf.tanh, kernel_initializer=self.dense_initializer, use_bias=self.config.use_bias)
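# NOTE: illustrative sketch, not part of the original code. The encoder above
# converts the padding mask into an additive bias with the usual
# (1 - mask) * -10000 trick; this NumPy fragment shows why that effectively
# zeroes out padded keys after the softmax.
import numpy as np

input_mask = np.array([[1, 1, 1, 0, 0]], dtype=np.float32)  # 1 = real token, 0 = padding
attention_bias = (1.0 - input_mask) * -10000.0              # [batch, memory_seq]

logits = np.random.randn(1, 5)                              # one query row, illustrative
weights = np.exp(logits + attention_bias)
weights /= weights.sum(axis=-1, keepdims=True)
assert weights[0, 3] < 1e-3 and weights[0, 4] < 1e-3        # padded keys get ~0 weight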
def bottleneck_block(inputs,
                     filters,
                     is_training,
                     strides,
                     projection_shortcut=None,
                     row_blocks_dim=None,
                     col_blocks_dim=None):
  """Bottleneck block variant for residual networks with BN after convolutions.

  Args:
    inputs: a `mtf.Tensor` of shape
      `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`.
    filters: `int` number of filters for the first and third (1x1)
      convolutions. Note that the second (3x3) convolution uses 4 times as
      many filters.
    is_training: `bool` for whether the model is in training mode.
    strides: `int` block stride. If greater than 1, this block will ultimately
      downsample the input.
    projection_shortcut: `function` to use for projection shortcuts (typically
      a 1x1 convolution to match the filter dimensions). If None, no
      projection is used and the input is passed unchanged through the
      shortcut connection.
    row_blocks_dim: a mtf.Dimension, row dimension which is spatially
      partitioned along mesh axis
    col_blocks_dim: a mtf.Dimension, column dimension which is spatially
      partitioned along mesh axis

  Returns:
    The output `Tensor` of the block.
  """
  shortcut = inputs

  filter_h_dim = mtf.Dimension("filter_height", 3)
  filter_w_dim = mtf.Dimension("filter_width", 3)
  one_h_dim = mtf.Dimension("filter_height", 1)
  one_w_dim = mtf.Dimension("filter_width", 1)

  if projection_shortcut is not None:
    filters_dim = mtf.Dimension("filtersp", filters)
    kernel = mtf.get_variable(
        inputs.mesh, "kernel", mtf.Shape(
            [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim]))
    shortcut = projection_shortcut(inputs, kernel)

  # First conv block
  filters1_dim = mtf.Dimension("filters1", filters)
  kernel1 = mtf.get_variable(
      inputs.mesh, "kernel1", mtf.Shape(
          [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs, kernel1, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Add Dropout?
  inputs = batch_norm_relu(inputs, is_training)

  # Second conv block
  filters2_dim = mtf.Dimension("filters2", 4 * filters)
  kernel2 = mtf.get_variable(
      inputs.mesh, "kernel2", mtf.Shape(
          [filter_h_dim, filter_w_dim, filters1_dim, filters2_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs, kernel2, strides=[1, 1, 1, 1], padding="SAME",
      h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)
  inputs = batch_norm_relu(inputs, is_training)

  # Third wide conv filter block
  filters3_dim = mtf.Dimension("filters3", filters)
  filters3_kernel = mtf.get_variable(
      inputs.mesh, "wide_kernel", mtf.Shape(
          [one_h_dim, one_w_dim, filters2_dim, filters3_dim]))
  inputs = mtf.conv2d_with_blocks(
      inputs, filters3_kernel, strides, padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  # TODO(nikip): Although the original resnet code has this batch norm, in our
  # setup this is causing no gradients to be passed. Investigate further.
  # inputs = batch_norm_relu(inputs, is_training, relu=True)

  # TODO(nikip): Maybe add residual with a projection?
  return mtf.relu(
      shortcut + mtf.rename_dimension(
          inputs, inputs.shape.dims[-1].name, shortcut.shape.dims[-1].name))
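# NOTE: illustrative sketch, not part of the original code. bottleneck_block
# calls projection_shortcut(inputs, kernel) with a 1x1 kernel it creates
# itself, so the callable only has to apply the convolution. A hypothetical
# helper that would satisfy that contract, mirroring the conv2d_with_blocks
# calls used inside the block:
def make_projection_shortcut(strides, col_blocks_dim=None):
  """Returns a `projection_shortcut` callable for bottleneck_block (assumed helper)."""
  def projection_shortcut(inputs, kernel):
    return mtf.conv2d_with_blocks(
        inputs, kernel, strides=strides, padding="SAME",
        h_blocks_dim=None, w_blocks_dim=col_blocks_dim)
  return projection_shortcut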
def mtf_model_fn(self, features, mesh): features = copy.copy(features) tf.logging.info("features = %s" % features) hparams = self._hparams activation_dtype = self.activation_type # We assume fixed vocab size for targets targets = tf.to_int32(features["targets"]) # Image preprocessing, reshape into a 1D sequence and shift right. length = hparams.img_len * hparams.img_len * hparams.num_channels targets = tf.reshape(targets, [hparams.batch_size, length]) shifted_targets = common_layers.shift_right_2d(targets) # Declare all the dimensions batch_dim = mtf.Dimension("batch", hparams.batch_size) def import_to_batch_by_length(x, name): return mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name) def layer_prepostprocess_dropout(x): return mtf.dropout( x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, noise_shape=mtf.Shape([batch_dim, self.model_dim])) targets = import_to_batch_by_length(targets, "targets") shifted_targets = import_to_batch_by_length(shifted_targets, "shifted_targets") extra_losses = [] # Create targets content and position embeddings. # Create embedding var for targets and positions and do a gather. targets_embedding_var = mtf.get_variable( mesh, "targets_embedding", mtf.Shape([self.targets_vocab_dim, self.model_dim]), initializer=tf.random_normal_initializer(), activation_dtype=activation_dtype) x = mtf.gather(targets_embedding_var, shifted_targets, self.targets_vocab_dim) # Add positional embeddings x += mtf.reshape(self.create_positional_emb_2d(targets), [self.length_dim, self.model_dim]) # If conditional and input is given, add the input embedding to the target. # TODO(nikip): Verify conditional. if self.has_input and not hparams.unconditional: inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3]) inputs = import_to_batch_by_length(inputs, "inputs") # Input embeddings inputs_embedding_var = mtf.layers.embedding( mesh, "input_embedding", mtf.Shape([self.inputs_vocab_dim, self.model_dim]), activation_dtype=activation_dtype) inputs_emb = mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) x += inputs_emb # Image Transformer Decoder # [ self attention - ffn - residual + dropout] x n for layer in range(hparams.num_decoder_layers): layer_name = "decoder_layer_%d" % layer with tf.variable_scope(layer_name): # Self attention layer x += layer_prepostprocess_dropout( mtf.layers.masked_local_attention_1d( mtf.layers.layer_norm(x, self.model_dim, name="layer_norm_att"), None, self.kv_dim, self.heads_dim, block_length=hparams.block_length, name="self_att")) # ffn layer x += layer_prepostprocess_dropout( mtf.layers.dense_relu_dense( mtf.layers.layer_norm(x, self.model_dim, name="layer_norm_ffn"), self.feedforward_dim, hparams.dropout, dropout_broadcast_dims=[self.length_dim])) x = mtf.layers.layer_norm(x, self.model_dim, name="final_layer_norm") # Calculate the logits and loss. logits = mtf.layers.dense(x, self.outputs_vocab_dim, name="logits") soft_targets = mtf.one_hot(targets, self.outputs_vocab_dim, dtype=activation_dtype) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, self.outputs_vocab_dim) loss = mtf.reduce_mean(loss) for l in extra_losses: loss += l # Reshape logits to original target shape. logits = mtf.reshape( logits, mtf.Shape([ batch_dim, self.rows_dim, self.orig_cols_dim, self.channels_dim, self.outputs_vocab_dim ])) return logits, loss
def mnist_model(image, labels, mesh):
  """The model.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows_size", 28)
  cols_dim = mtf.Dimension("cols_size", 28)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  x = mtf.import_tf_tensor(
      mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28, 1]),
      mtf.Shape([batch_dim, rows_dim, cols_dim, one_channel_dim]))

  fh_dim = mtf.Dimension("fh", 3)
  fw_dim = mtf.Dimension("fw", 3)
  filters1_dim = mtf.Dimension("filters1", FLAGS.num_filters)
  filters2_dim = mtf.Dimension("filters2", FLAGS.num_filters)
  filters3_dim = mtf.Dimension("filters3", FLAGS.num_filters)
  filters4_dim = mtf.Dimension("filters4", FLAGS.num_filters)
  filters5_dim = mtf.Dimension("filters5", FLAGS.num_filters)
  filters6_dim = mtf.Dimension("filters6", FLAGS.num_filters)

  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])
  kernel3 = mtf.get_variable(
      mesh, "kernel3", [fh_dim, fw_dim, filters2_dim, filters3_dim])
  kernel4 = mtf.get_variable(
      mesh, "kernel4", [fh_dim, fw_dim, filters3_dim, filters4_dim])
  kernel5 = mtf.get_variable(
      mesh, "kernel5", [fh_dim, fw_dim, filters4_dim, filters5_dim])
  kernel6 = mtf.get_variable(
      mesh, "kernel6", [fh_dim, fw_dim, filters5_dim, filters6_dim])

  x = mtf.relu(mtf.conv2d(x, kernel1, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.relu(mtf.conv2d(x, kernel2, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.relu(mtf.conv2d(x, kernel3, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.relu(mtf.conv2d(x, kernel4, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.relu(mtf.conv2d(x, kernel5, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.relu(mtf.conv2d(x, kernel6, strides=[1, 1, 1, 1], padding="SAME"))
  x = mtf.reduce_mean(x, reduced_dim=filters6_dim)

  # add some fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  h1 = mtf.layers.dense(
      x, hidden_dim1,
      reduced_dims=x.shape.dims[-2:],
      activation=mtf.relu, name="hidden1")
  h2 = mtf.layers.dense(
      h1, hidden_dim2,
      activation=mtf.relu, name="hidden2")
  logits = mtf.layers.dense(h2, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
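# NOTE: illustrative sketch, not part of the original code. Model functions
# like mnist_model only build a Mesh TensorFlow graph; to execute it, the
# graph has to be lowered onto a concrete mesh implementation. A minimal,
# self-contained example (toy one-layer model, single-process placement mesh,
# TF1-style session API) showing that lowering step:
import numpy as np
import tensorflow.compat.v1 as tf
import mesh_tensorflow as mtf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")
batch_dim = mtf.Dimension("batch", 8)
io_dim = mtf.Dimension("io", 4)
hidden_dim = mtf.Dimension("hidden", 16)

x_tf = tf.constant(np.random.randn(8, 4), dtype=tf.float32)
x = mtf.import_tf_tensor(mesh, x_tf, mtf.Shape([batch_dim, io_dim]))
h = mtf.layers.dense(x, hidden_dim, activation=mtf.relu, name="hidden")
loss = mtf.reduce_mean(mtf.square(h))

# Lower onto a trivial single-processor placement mesh (no real splitting).
mesh_shape = mtf.convert_to_shape("all_processors:1")
layout_rules = mtf.convert_to_layout_rules("batch:all_processors")
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
    mesh_shape, layout_rules, [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(lowering.copy_masters_to_slices())
  print(sess.run(tf_loss))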
def mtf_model_fn(self, features, mesh): features = copy.copy(features) tf.logging.info("features = %s" % features) hparams = self._hparams activation_dtype = self.set_activation_type() is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN # Declare all the dimensions batch_dim = mtf.Dimension("batch", hparams.batch_size) hidden_dim = mtf.Dimension("hidden", hparams.hidden_size) filter_h_dim = mtf.Dimension("filter_height", 7) filter_w_dim = mtf.Dimension("filter_width", 7) filters = mtf.Dimension("filters", hparams.filter_sizes[0]) rows_dim = mtf.Dimension("rows_size", hparams.rows_size) cols_dim = mtf.Dimension("cols_size", hparams.cols_size) row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks) col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks) classes_dim = mtf.Dimension("classes", 10) channels_dim = mtf.Dimension("channels", 3) one_channel_dim = mtf.Dimension("one_channel", 1) inputs = features["inputs"] x = mtf.import_tf_tensor( mesh, tf.reshape(inputs, [ hparams.batch_size, hparams.row_blocks, hparams.rows_size // hparams.row_blocks, hparams.col_blocks, hparams.num_channels*hparams.cols_size // hparams.col_blocks, hparams.num_channels]), mtf.Shape( [batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim, channels_dim])) x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim, channels_dim]) x = mtf.to_float(x) initial_filters = mtf.get_variable( mesh, "init_filters", mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters])) x = mtf.conv2d_with_blocks( x, initial_filters, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=None, w_blocks_dim=col_blocks_dim) x = batch_norm_relu(x, is_training) # Conv blocks # [block - strided block layer - strided block layer] x n for layer in range(hparams.num_layers): layer_name = "block_layer_%d" % layer with tf.variable_scope(layer_name): # Residual block layer x = block_layer( inputs=x, filters=hparams.filter_sizes[0], blocks=hparams.layer_sizes[0], strides=[1, 1, 1, 1], is_training=is_training, name="block_layer1", row_blocks_dim=None, col_blocks_dim=None) x = block_layer( inputs=x, filters=hparams.filter_sizes[1], blocks=hparams.layer_sizes[1], strides=[1, 1, 1, 1], is_training=is_training, name="block_layer2", row_blocks_dim=None, col_blocks_dim=None) x = block_layer( inputs=x, filters=hparams.filter_sizes[2], blocks=hparams.layer_sizes[2], strides=[1, 1, 1, 1], is_training=is_training, name="block_layer3", row_blocks_dim=None, col_blocks_dim=None) # Calculate the logits and loss. out = x outputs = mtf.layers.dense( out, hidden_dim, reduced_dims=out.shape.dims[-5:], activation=mtf.relu, name="dense") # We assume fixed vocab size for targets labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3]) labels = mtf.import_tf_tensor( mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim])) logits = mtf.layers.dense(outputs, classes_dim, name="logits") soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, classes_dim) # Reshape logits so it doesn't break inside t2t. logits = mtf.reshape( logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim])) loss = mtf.reduce_mean(loss) return logits, loss
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print(dtype, npdtype) # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition downsampling_factor = 2 lnc = nc // 2**downsampling_factor fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) # Dimensions of the low resolution grid x_dim = mtf.Dimension("nx_lr", lnc) y_dim = mtf.Dimension("ny_lr", lnc) z_dim = mtf.Dimension("nz_lr", lnc) tx_dim = mtf.Dimension("tx_lr", lnc) ty_dim = mtf.Dimension("ty_lr", lnc) tz_dim = mtf.Dimension("tz_lr", lnc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False, dtype=npdtype) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype(npdtype), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype(npdtype), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype(npdtype), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([lnc, lnc, lnc], symmetric=False, dtype=npdtype) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype(npdtype) / 2**downsampling_factor, shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] # kvec for high resolution blocks padded_sx_dim = mtf.Dimension('padded_sx_block', nc // n_block_x + 2 * halo_size) padded_sy_dim = mtf.Dimension('padded_sy_block', nc // n_block_y + 2 * halo_size) padded_sz_dim = mtf.Dimension('padded_sz_block', nc // n_block_z + 2 * halo_size) kvec_hr = flowpm.kernels.fftk([ nc // n_block_x + 2 * halo_size, nc // n_block_y + 2 * halo_size, nc // n_block_z + 2 * halo_size ], symmetric=False, dtype=npdtype) kx_hr = mtf.import_tf_tensor(mesh, 
kvec_hr[0].squeeze().astype(npdtype), shape=[padded_sx_dim]) ky_hr = mtf.import_tf_tensor(mesh, kvec_hr[1].squeeze().astype(npdtype), shape=[padded_sy_dim]) kz_hr = mtf.import_tf_tensor(mesh, kvec_hr[2].squeeze().astype(npdtype), shape=[padded_sz_dim]) kv_hr = [ky_hr, kz_hr, kx_hr] # kvec for prior blocks prior_sx_dim = mtf.Dimension('prior_sx_block', nc // n_block_x) prior_sy_dim = mtf.Dimension('prior_sy_block', nc // n_block_y) prior_sz_dim = mtf.Dimension('prior_sz_block', nc // n_block_z) kvec_pr = flowpm.kernels.fftk( [nc // n_block_x, nc // n_block_y, nc // n_block_z], symmetric=False, dtype=npdtype) kx_pr = mtf.import_tf_tensor(mesh, kvec_pr[0].squeeze().astype(npdtype), shape=[prior_sx_dim]) ky_pr = mtf.import_tf_tensor(mesh, kvec_pr[1].squeeze().astype(npdtype), shape=[prior_sy_dim]) kz_pr = mtf.import_tf_tensor(mesh, kvec_pr[2].squeeze().astype(npdtype), shape=[prior_sz_dim]) kv_pr = [ky_pr, kz_pr, kx_pr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, x_dim, y_dim, z_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] ## Compute initial initial conditions distributed fieldvar = mtf.get_variable(mesh, 'linear', hr_shape) input_field = tf.placeholder(data.dtype, [ batch_size, n_block_x, n_block_y, n_block_z, nc // n_block_x, nc // n_block_y, nc // n_block_z ]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=hr_shape) linearop = mtf.assign(fieldvar, mtfinp) # field = fieldvar initc = mtf.slicewise(lambda x: x[:, 0, 0, 0], [field], output_dtype=tf.float32, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) # for block_size_dim in hr_shape[-3:]: field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name) for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]): field = mpm.halo_reduce(field, blocks_dim, block_size_dim, halo_size) field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)]) high = field low = mesh_utils.downsample(field, downsampling_factor, antialias=True) low = mtf.reshape(low, low.shape[:-1]) high = mtf.reshape(high, high.shape[:-1]) for block_size_dim in hr_shape[-3:]: low = mtf.slice(low, halo_size // 2**downsampling_factor, block_size_dim.size // 2**downsampling_factor, block_size_dim.name) # Hack usisng custom reshape because mesh is pretty dumb low = mtf.slicewise(lambda x: x[:, 0, 0, 0], [low], output_dtype=initc.dtype, output_shape=lr_shape, name='my_dumb_reshape', splittable_dims=lr_shape[:-1] + hr_shape[:4]) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init(low, high, 0.1, kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) final_state = mtfpm.nbody(state, stages, lr_shape, hr_shape, kv_lr, kv_hr, halo_size, downsampling_factor=downsampling_factor) else: final_state = mtfpm.lpt_init(low, high, stages[-1], kv_lr, kv_hr, halo_size, hr_shape, lr_shape, part_shape[1:], downsampling_factor=downsampling_factor, antialias=True) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for 
block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv_pr] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv_pr, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 # Total loss diff = (final_field - mtfdata) R0 = tf.placeholder(tf.float32, shape=()) def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv_pr, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv_pr, output_dtype=cdtype) var_grads = [ mesh_utils.c2r3d(cgrads, var_grads[0].shape[-3:], dtype=dtype) ] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, final_field, loss, var_grads, update_op, linearop, input_field, lr, R0
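# NOTE: illustrative sketch, not part of the original code. The R0 placeholder
# above controls a Gaussian low-pass filter applied to the residual before the
# chi-square term; the same weights as _cwise_smooth, written out in 1D NumPy
# with illustrative numbers, show how larger R0 damps high-k modes (the driver
# code is presumably expected to anneal R0 towards zero).
import numpy as np

nc, bs = 64, 200.0                       # grid size and box size, illustrative
R0 = 4.0                                 # smoothing scale fed via the placeholder

kx = np.fft.fftfreq(nc) * 2 * np.pi      # illustrative 1D wavenumbers
kk = (kx / bs * nc) ** 2                 # same scaling as in _cwise_smooth
wts = np.exp(-kk * (R0 * bs / nc) ** 2)  # Gaussian weights; larger R0 suppresses high k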
def cifar_model(features, labels, mesh): """The model. Args: image: tf.Tensor with shape [batch, 32*32] labels: a tf.Tensor with shape [batch] and dtype tf.int32 mesh: a mtf.Mesh Returns: logits: a mtf.Tensor with shape [batch, 10] loss: a mtf.Tensor with shape [] """ features = copy.copy(features) batch_dim = mtf.Dimension("batch", FLAGS.batch_size) row_blocks_dim = mtf.Dimension("row_blocks", 4) col_blocks_dim = mtf.Dimension("col_blocks", 4) rows_dim = mtf.Dimension("rows_size", 8) cols_dim = mtf.Dimension("cols_size", 8) classes_dim = mtf.Dimension("classes", 10) one_channel_dim = mtf.Dimension("one_channel", 3) # image = features['input'] # with tf.device('/cpu:0'): image = features['image'] labels = features['label'] image = bnorm(image) x = mtf.import_tf_tensor( mesh, tf.reshape(image, [FLAGS.batch_size, 4, 8, 4, 8, 3]), mtf.Shape( [batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim, one_channel_dim])) x = mtf.transpose(x, [ batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim, one_channel_dim]) # add some convolutional layers to demonstrate that convolution works. fh_dim = mtf.Dimension("fh", 7) fw_dim = mtf.Dimension("fw", 7) filters1_dim = mtf.Dimension("filters1", 32) filters2_dim = mtf.Dimension("filters2", 32) kernel1 = mtf.get_variable( mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim]) kernel2 = mtf.get_variable( mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim]) f1 = mtf.relu(mtf.conv2d_with_blocks( x, kernel1, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) f2 = mtf.relu(mtf.conv2d_with_blocks( f1, kernel2, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters3_dim = mtf.Dimension("filters3", 64) kernel3 = mtf.get_variable( mesh, "kernel3", [fh_dim, fw_dim, filters2_dim, filters3_dim]) f3 = mtf.relu(mtf.conv2d_with_blocks( f2, kernel3, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters4_dim = mtf.Dimension("filters4", 64) kernel4 = mtf.get_variable( mesh, "kernel4", [fh_dim, fw_dim, filters3_dim, filters4_dim]) f4 = mtf.relu(mtf.conv2d_with_blocks( f3, kernel4, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters5_dim = mtf.Dimension("filters5", 128) kernel5 = mtf.get_variable( mesh, "kernel5", [fh_dim, fw_dim, filters4_dim, filters5_dim]) f5 = mtf.relu(mtf.conv2d_with_blocks( f4, kernel5, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters6_dim = mtf.Dimension("filters6", 128) kernel6 = mtf.get_variable( mesh, "kernel6", [fh_dim, fw_dim, filters5_dim, filters6_dim]) f6 = mtf.relu(mtf.conv2d_with_blocks( f5, kernel6, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters7_dim = mtf.Dimension("filters7", 128) kernel7 = mtf.get_variable( mesh, "kernel7", [fh_dim, fw_dim, filters6_dim, filters7_dim]) f7 = mtf.relu(mtf.conv2d_with_blocks( f6, kernel7, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters8_dim = mtf.Dimension("filters8", 128) kernel8 = mtf.get_variable( mesh, "kernel8", [fh_dim, fw_dim, filters7_dim, filters8_dim]) f8 = mtf.relu(mtf.conv2d_with_blocks( f7, kernel8, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters9_dim = mtf.Dimension("filters9", 128) kernel9 = mtf.get_variable( mesh, "kernel9", [fh_dim, 
fw_dim, filters8_dim, filters9_dim]) f9 = mtf.relu(mtf.conv2d_with_blocks( f8, kernel9, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters10_dim = mtf.Dimension("filters10", 128) kernel10 = mtf.get_variable( mesh, "kernel10", [fh_dim, fw_dim, filters9_dim, filters10_dim]) f10 = mtf.relu(mtf.conv2d_with_blocks( f9, kernel10, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters11_dim = mtf.Dimension("filters11", 256) kernel11 = mtf.get_variable( mesh, "kernel11", [fh_dim, fw_dim, filters10_dim, filters11_dim]) f11 = mtf.relu(mtf.conv2d_with_blocks( f10, kernel11, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters12_dim = mtf.Dimension("filters12", 256) kernel12 = mtf.get_variable( mesh, "kernel12", [fh_dim, fw_dim, filters11_dim, filters12_dim]) f12 = mtf.relu(mtf.conv2d_with_blocks( f11, kernel12, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters13_dim = mtf.Dimension("filters13", 256) kernel13 = mtf.get_variable( mesh, "kernel13", [fh_dim, fw_dim, filters12_dim, filters13_dim]) f13 = mtf.relu(mtf.conv2d_with_blocks( f12, kernel13, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters14_dim = mtf.Dimension("filters14", 256) kernel14 = mtf.get_variable( mesh, "kernel14", [fh_dim, fw_dim, filters13_dim, filters14_dim]) f14 = mtf.relu(mtf.conv2d_with_blocks( f13, kernel14, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters15_dim = mtf.Dimension("filters15", 256) kernel15 = mtf.get_variable( mesh, "kernel15", [fh_dim, fw_dim, filters14_dim, filters15_dim]) f15 = mtf.relu(mtf.conv2d_with_blocks( f14, kernel15, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters16_dim = mtf.Dimension("filters16", 256) kernel16 = mtf.get_variable( mesh, "kernel16", [fh_dim, fw_dim, filters15_dim, filters16_dim]) f16 = mtf.relu(mtf.conv2d_with_blocks( f15, kernel16, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters17_dim = mtf.Dimension("filters17", 256) kernel17 = mtf.get_variable( mesh, "kernel17", [fh_dim, fw_dim, filters16_dim, filters17_dim]) f17 = mtf.relu(mtf.conv2d_with_blocks( f16, kernel17, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) filters18_dim = mtf.Dimension("filters18", 256) kernel18 = mtf.get_variable( mesh, "kernel18", [fh_dim, fw_dim, filters17_dim, filters18_dim]) f18 = mtf.relu(mtf.conv2d_with_blocks( f17, kernel18, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) x = mtf.reduce_mean(f18, reduced_dim=filters18_dim) # add some fully-connected dense layers. 
hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size) hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size) h1 = mtf.layers.dense( x, hidden_dim1, reduced_dims=x.shape.dims[-4:], activation=mtf.relu, name="hidden1") h2 = mtf.layers.dense( h1, hidden_dim2, activation=mtf.relu, name="hidden2") hidden_dim3 = mtf.Dimension("hidden3", FLAGS.hidden_size) hidden_dim4 = mtf.Dimension("hidden4", FLAGS.hidden_size) hidden_dim5 = mtf.Dimension("hidden5", FLAGS.hidden_size) hidden_dim6 = mtf.Dimension("hidden6", FLAGS.hidden_size) hidden_dim7 = mtf.Dimension("hidden7", FLAGS.hidden_size) hidden_dim8 = mtf.Dimension("hidden8", FLAGS.hidden_size) h3 = mtf.layers.dense( h2, hidden_dim3, activation=mtf.relu, name="hidden3") h4 = mtf.layers.dense( h3, hidden_dim4, activation=mtf.relu, name="hidden4") h5 = mtf.layers.dense( h4, hidden_dim5, activation=mtf.relu, name="hidden5") h6 = mtf.layers.dense( h5, hidden_dim6, activation=mtf.relu, name="hidden6") h7 = mtf.layers.dense( h6, hidden_dim7, activation=mtf.relu, name="hidden7") h8 = mtf.layers.dense( h7, hidden_dim8, activation=mtf.relu, name="hidden8") logits = mtf.layers.dense(h8, classes_dim, name="logits") if labels is None: loss = None else: labels = mtf.import_tf_tensor( mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim])) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, mtf.one_hot(labels, classes_dim), classes_dim) loss = mtf.reduce_mean(loss) return logits, loss
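# Sketch (not from the source; the batch of 2 is hypothetical): how a [batch, 32, 32, 3]
# CIFAR image is laid out into the blocked shape that cifar_model imports into mesh
# tensorflow. The reshape gives [batch, row_blocks, rows, col_blocks, cols, channels]
# and the transpose reorders it to [batch, row_blocks, col_blocks, rows, cols, channels],
# matching the mtf.Shape used above; conv2d_with_blocks then treats each 8x8 block as a
# spatially partitioned tile.
import numpy as np

batch = 2
image = np.arange(batch * 32 * 32 * 3, dtype=np.float32).reshape(batch, 32, 32, 3)

# 32 rows = 4 row blocks of 8 rows each, and likewise for columns.
blocked = image.reshape(batch, 4, 8, 4, 8, 3)        # [batch, rb, rows, cb, cols, c]
blocked = blocked.transpose(0, 1, 3, 2, 4, 5)        # [batch, rb, cb, rows, cols, c]
assert blocked.shape == (batch, 4, 4, 8, 8, 3)

# Block (i, j) holds the 8x8 patch starting at pixel (8*i, 8*j) of the original image.
assert np.array_equal(blocked[:, 1, 2], image[:, 8:16, 16:24, :])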
def recon_prototype(mesh, data, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print("Dtype : ", dtype, npdtype) # Compute a few things first, using simple tensorflow kny = 1 * np.pi * nc / bs R1, R2 = 3., 3 * 1.2 stages = np.linspace(a0, a, nsteps, endpoint=True) #graph = mtf.Graph() #mesh = mtf.Mesh(graph, "my_mesh") # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition scalar = mtf.Dimension("scalar", 1) fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) #k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../flowpm/data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # # Begin simulation ## Compute initial initial conditions distributed #initc = mtfpm.linear_field(mesh, shape, bs, nc, pk, kv) fieldvar = mtf.get_variable(mesh, 'linear', part_shape) input_field = tf.placeholder(data.dtype, [batch_size, nc, nc, nc]) mtfinp = mtf.import_tf_tensor(mesh, input_field, shape=part_shape) linearop = mtf.assign(fieldvar, mtfinp) #field = fieldvar initc = fieldvar print("initc : ", initc) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( 
fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( initc, stages[-1], kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) ## x = final_field ppars, mpars, kernel = setupfnn() pwts, pbias, pmx, psx = ppars mwts, mbias, mmx, msx, mmy, msy = mpars msy, mmy = msy[0], mmy[0] print("mmy : ", mmy) size = 3 k_dims = [d.shape[0] for d in kv] k_dims = [k_dims[2], k_dims[0], k_dims[1]] tfnc, tfbs = float_to_mtf(nc * 1., mesh, scalar), float_to_mtf(bs, mesh, scalar) x1f = mesh_utils.r2c3d(x, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_decic, [x1f] + kv + [tfnc, tfbs], output_dtype=cdtype) x1d = mesh_utils.c2r3d(x1f, x.shape[-3:], dtype=dtype) x1d = mtf.add(x1d, -1.) x1f0 = mesh_utils.r2c3d(x1d, k_dims, dtype=cdtype) x1f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x1 = mesh_utils.c2r3d(x1f, x1d.shape[-3:], dtype=dtype) x2f = mtf.cwise(cwise_fingauss, [x1f0, float_to_mtf(R2, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) x2 = mesh_utils.c2r3d(x2f, x1d.shape[-3:], dtype=dtype) x12 = x1 - x2 width = tf.placeholder(tf.float32, shape=()) def apply_pwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) y = tf.nn.conv3d(tf.expand_dims(x, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y1 = tf.nn.conv3d(tf.expand_dims(x1, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') y2 = tf.nn.conv3d(tf.expand_dims(x2, axis=-1), kernel, [1, 1, 1, 1, 1], 'SAME') #y = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y1 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x1), -1), kernel, [1, 1, 1, 1, 1], 'VALID') #y2 = tf.nn.conv3d(tf.expand_dims(tfwrap3D(x12), -1), kernel, [1, 1, 1, 1, 1], 'VALID') yy = tf.concat([y, y1, y2], axis=-1) yy = yy - pmx yy = yy / psx yy1 = tf.nn.relu(tf.matmul(yy, pwts[0]) + pbias[0]) yy2 = tf.nn.relu(tf.matmul(yy1, pwts[1]) + pbias[1]) yy3 = tf.matmul(yy2, pwts[2]) + pbias[2] pmodel = tf.nn.sigmoid(width * yy3) return pmodel[..., 0] pmodel = mtf.slicewise( apply_pwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_pwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) def apply_mwts(x, x1, x2): #y = tf.expand_dims(x, axis=-1) zz = tf.concat([ tf.expand_dims(x, -1), tf.expand_dims(x1, -1), tf.expand_dims(x2, -1) ], axis=-1) zz = zz - mmx zz = zz / msx zz1 = tf.nn.elu(tf.matmul(zz, mwts[0]) + mbias[0]) zz2 = tf.nn.elu(tf.matmul(zz1, mwts[1]) + mbias[1]) zz3 = tf.matmul(zz2, mwts[2]) + mbias[2] mmodel = zz3 * msy + mmy return mmodel[..., 0] mmodel = mtf.slicewise( 
apply_mwts, [x, x1, x12], output_dtype=tf.float32, output_shape=part_shape, # + [mtf.Dimension('c_dim', 81)], name='apply_mwts', splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]) model = pmodel * mmodel mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior #k_dims = [d.shape[0] for d in kv] #k_dims = [k_dims[2], k_dims[0], k_dims[1]] k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 * nc**3 # Total loss #diff = (model - mtfdata) modelf = mesh_utils.r2c3d(model, k_dims, dtype=cdtype) modelsmf = mtf.cwise(cwise_fingauss, [modelf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) modelsm = mesh_utils.c2r3d(modelsmf, x1d.shape[-3:], dtype=dtype) #dataf = mesh_utils.r2c3d(mtfdata, k_dims, dtype=cdtype) #datasmf = mtf.cwise(cwise_fingauss, [dataf, float_to_mtf(R1, mesh, scalar)] + kv + [tfnc, tfbs], output_dtype=cdtype) #datasm = mesh_utils.c2r3d(datasmf, x1d.shape[-3:], dtype=dtype) ##Anneal R0 = tf.placeholder(tf.float32, shape=()) M0 = tf.placeholder(tf.float32, shape=()) off, istd = tf.placeholder(tf.float32, shape=data.shape), tf.placeholder( tf.float32, shape=data.shape) mtfoff = mtf.import_tf_tensor(mesh, off, shape=shape) mtfistd = mtf.import_tf_tensor(mesh, istd, shape=shape) diff = mtf.log(modelsm + M0) - mtf.log(mtfdata + M0) #diff = diff / 0.25 #diff = (diff + mtfoff)*mtfistd #For some reason, doing things wrong this one diff = (diff + mtfoff) / 0.25 def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype) cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype) diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype) chisq = mtf.reduce_sum(mtf.square(diff)) loss = chisq + prior #return initc, final_field, loss, linearop, input_field nyq = np.pi * nc / bs def _cwise_highpass(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc + 1 / nyq)**2), kfield.dtype) return kfield * (1 - wts) var_grads = mtf.gradients([loss], [fieldvar]) cgrads = mesh_utils.r2c3d(var_grads[0], k_dims_pr, dtype=cdtype) cgrads = mtf.cwise(_cwise_highpass, [cgrads] + kv, output_dtype=cdtype) var_grads = [mesh_utils.c2r3d(cgrads, diff.shape[-3:], dtype=dtype)] lr = tf.placeholder(tf.float32, shape=()) update_op = mtf.assign(fieldvar, fieldvar - var_grads[0] * lr) return initc, model, loss, var_grads, update_op, linearop, input_field, lr, R0, M0, width, chisq, prior, off, istd
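# Sketch (assumptions: nc, bs and R0 values are hypothetical, np.fft stands in for
# mesh_utils.r2c3d / c2r3d, and the wavenumber units are chosen for this sketch rather
# than following flowpm.kernels.fftk's grid-unit convention): the _cwise_smooth and
# _cwise_highpass helpers above are element-wise Fourier filters -- multiply the modes
# by a Gaussian window and transform back.
import numpy as np

nc, bs, R0 = 32, 100.0, 4.0
field = np.random.randn(nc, nc, nc).astype(np.float32)

k = 2 * np.pi * np.fft.fftfreq(nc, d=bs / nc)        # angular wavenumbers per axis
kx, ky, kz = np.meshgrid(k, k, k, indexing="ij")
kk = kx**2 + ky**2 + kz**2

kfield = np.fft.fftn(field)
smoothed = np.real(np.fft.ifftn(kfield * np.exp(-kk * R0**2)))   # low-pass
highpassed = field - smoothed                                    # complementary high-pass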
def _layer_stack(self, x, layers, encoder_output=None, self_attention_mask=None, encdec_attention_mask=None, losses=None, step_num=None, encdec_tensors=None, states=None): """Encoder or decoder stack. Args: x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] layers: an list of strings encoder_output: an optional mtf.Tensor with shape [<batch_dims>, encoder_length_dim, model_dim] self_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, memory_length_dim] containing values 0 or -inf. encdec_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, encoder_length_dim] containing values 0 or -inf. losses: a list to be appended-to step_num: an optional mtf integer Scalar (used in incrmenental mode) encdec_tensors: an optional list of num_layers tuples, each of the form (q_var, o_var, k, v), (used in incremental mode) states: an optional list of Tensors (used in incremental mode) Returns: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] Raises: ValueError: if hparams make no sense """ hparams = self._hparams is_incremental = (step_num is not None) def layer_prepostprocess_dropout(x): if is_incremental: return x return mtf.dropout( x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) num_layers = len(layers) num_layer_norms = num_layers + 1 layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms) layer_norm_combined_var = mtf.get_variable( x.mesh, "layer_norm_scale", mtf.Shape([layer_norms_dim, self.model_dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype) layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim) def normalize(x): scale = layer_norm_vars.pop(0) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale if is_incremental: states = list(states) new_states = [] tf.logging.info("states = %s" % (states,)) for lnum, layer_type in enumerate(layers): with tf.variable_scope("%s_%d" % (layer_type, lnum)): if layer_type == "att": # Self attention layer if is_incremental: y, new_k, new_v = mtf.layers.multihead_self_attention_incremental( normalize(x), prev_k=states.pop(0), prev_v=states.pop(0), step_num=step_num, master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="att") new_states.append(new_k) new_states.append(new_v) x += y else: x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), None, self_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="att")) elif layer_type == "enc_att": # Encoder-Decoder attention layer if is_incremental: # Encoder-Decoder attention layer q_var, o_var, k, v = encdec_tensors[lnum] x += mtf.layers.multihead_encdec_attention_incremental( normalize(x), q_var, o_var, k, v, encdec_attention_mask, name="enc_att") else: x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), encoder_output, encdec_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="enc_att")) elif layer_type == "local_att": if is_incremental: y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental( normalize(x), prev_k=states.pop(0), prev_v=states.pop(0), step_num=step_num, master_dtype=self.master_dtype, 
slice_dtype=self.slice_dtype, name="local_att") new_states.append(new_k) new_states.append(new_v) x += y else: x += layer_prepostprocess_dropout( mtf.layers.masked_local_attention_1d( normalize(x), self.kv_dim, self.heads_dim, window_size=hparams.local_attention_window_size, master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, length_per_split=mtf.tensor_dim_to_size_per_split( hparams.layout, hparams.mesh_shape, self.max_length_dim), name="local_att")) elif layer_type == "compressed_att": if is_incremental: raise ValueError("compressed_att incremental not implemented") else: x += layer_prepostprocess_dropout( mtf.layers.multihead_self_attention_memory_compressed( normalize(x), mask_right=True, compression_factor=hparams.compression_factor, kv_channels=self.kv_dim, heads=self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="compressed_att")) else: if is_incremental: # insert length dimension. x_shape = x.shape shape_with_length = mtf.Shape( x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:]) x = mtf.reshape(x, shape_with_length) # ffn layer x += layer_prepostprocess_dropout( self._feedforward_layer(normalize(x), layer_type, losses=losses)) if is_incremental: # remove length dimension x = mtf.reshape(x, x_shape) x = layer_prepostprocess_dropout(normalize(x)) assert not layer_norm_vars if is_incremental: return x, new_states else: return x
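# Sketch (standalone numpy illustration, not the class method above): the normalize()
# closure used by the layer stack is an RMS-style norm -- no mean subtraction, just
# x * rsqrt(mean(x**2) + epsilon) * scale -- where each normalization site pops its own
# learned scale vector from the shared "layer_norm_scale" variable.
import numpy as np

def rms_norm(x, scale, epsilon=1e-6):
    # x: [..., model_dim], scale: [model_dim]
    variance = np.mean(np.square(x), axis=-1, keepdims=True)
    return x * (1.0 / np.sqrt(variance + epsilon)) * scale

x = np.random.randn(2, 5, 16).astype(np.float32)     # [batch, length, model_dim]
scale = np.ones(16, dtype=np.float32)                # initializer=tf.ones_initializer()
y = rms_norm(x, scale)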
def _var(x, init): return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [], initializer=tf.constant_initializer(init), dtype=x.dtype)
def bottleneck_block(inputs, filters, is_training, strides, projection_shortcut=None, row_blocks_dim=None, col_blocks_dim=None): """Bottleneck block variant for residual networks with BN after convolutions. Args: inputs: a `mtf.Tensor` of shape `[batch_dim, row_blocks, col_blocks, rows, cols, in_channels]`. filters: `int` number of filters for the first two convolutions. Note that the third and final convolution will use 4 times as many filters. is_training: `bool` for whether the model is in training mode. strides: `int` block stride. If greater than 1, this block will ultimately downsample the input. projection_shortcut: `function` to use for projection shortcuts (typically a 1x1 convolution to match the filter dimensions). If None, no projection is used and the input is passed as unchanged through the shortcut connection. row_blocks_dim: a mtf.Dimension, row dimension which is spatially partitioned along mesh axis col_blocks_dim: a mtf.Dimension, row dimension which is spatially partitioned along mesh axis Returns: The output `Tensor` of the block. """ shortcut = inputs filter_h_dim = mtf.Dimension("filter_height", 3) filter_w_dim = mtf.Dimension("filter_width", 3) one_h_dim = mtf.Dimension("filter_height", 1) one_w_dim = mtf.Dimension("filter_width", 1) if projection_shortcut is not None: filters_dim = mtf.Dimension("filtersp", filters) kernel = mtf.get_variable( inputs.mesh, "kernel", mtf.Shape( [one_h_dim, one_w_dim, inputs.shape.dims[-1], filters_dim])) shortcut = projection_shortcut(inputs, kernel) # First conv block filters1_dim = mtf.Dimension("filters1", filters) kernel1 = mtf.get_variable( inputs.mesh, "kernel1", mtf.Shape([one_h_dim, one_w_dim, inputs.shape.dims[-1], filters1_dim])) inputs = mtf.conv2d_with_blocks(inputs, kernel1, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=None, w_blocks_dim=col_blocks_dim) # TODO(nikip): Add Dropout? inputs = batch_norm_relu(inputs, is_training) # Second conv block filters2_dim = mtf.Dimension("filters2", filters) kernel2 = mtf.get_variable( inputs.mesh, "kernel2", mtf.Shape([filter_h_dim, filter_w_dim, filters1_dim, filters2_dim])) inputs = mtf.conv2d_with_blocks(inputs, kernel2, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim) inputs = batch_norm_relu(inputs, is_training) # Third wide conv filter block filters3_dim = mtf.Dimension("filters3", filters) filters3_kernel = mtf.get_variable( inputs.mesh, "wide_kernel", mtf.Shape([one_h_dim, one_w_dim, filters2_dim, filters3_dim])) inputs = mtf.conv2d_with_blocks(inputs, filters3_kernel, strides, padding="SAME", h_blocks_dim=None, w_blocks_dim=col_blocks_dim) inputs = batch_norm_relu(inputs, is_training, relu=False) # TODO(nikip): Maybe add residual with a projection? return mtf.relu(inputs + mtf.rename_dimension( shortcut, shortcut.shape.dims[-1].name, inputs.shape.dims[-1].name))
def _layer_stack(self, x, num_layers, encoder_output=None, self_attention_mask=None, encdec_attention_mask=None, losses=None): """Encoder or decoder stack. Args: x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] num_layers: an integer encoder_output: an optional mtf.Tensor with shape [<batch_dims>, encoder_length_dim, model_dim] self_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, memory_length_dim] containing values 0 or -inf. encdec_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, encoder_length_dim] containing values 0 or -inf. losses: a list to be appended-to Returns: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim] Raises: ValueError: if hparams make no sense """ hparams = self._hparams def layer_prepostprocess_dropout(x): return mtf.dropout( x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout, noise_shape=mtf.Shape(self.batch_dims + [self.model_dim])) num_layer_norms = num_layers * (2 if encoder_output is None else 3) + 1 layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms) layer_norm_combined_var = mtf.get_variable( x.mesh, "layer_norm_scale", mtf.Shape([layer_norms_dim, self.model_dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype) layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim) def normalize(x): scale = layer_norm_vars.pop(0) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale for layer in range(num_layers): with tf.variable_scope("layer_%d" % layer): # Self attention layer x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), None, self_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], name="self_attention")) if encoder_output is not None: # Encoder-Decoder attention layer x += layer_prepostprocess_dropout( mtf.layers.multihead_attention( normalize(x), encoder_output, encdec_attention_mask, self.kv_dim, self.heads_dim, dropout=hparams.attention_dropout, dropout_broadcast_dims=[self.length_dim], name="encdec_attention")) # ffn layer x += layer_prepostprocess_dropout( self._feedforward_layer(normalize(x), losses=losses)) x = layer_prepostprocess_dropout(normalize(x)) assert not layer_norm_vars return x
def _decoder_layer_stack_incremental(self, x, step_num, encdec_tensors, self_attention_k, self_attention_v, encdec_attention_mask=None): """Decoder layer stack during inference. We are processing only one position at a time. The self-attention keys and values have already been computed for previous positions. In addition to the decoder output, we need to produce the updated self-attention keys and values. If there is an encoder, then additional Tensors are supplied in encdec_tensors, which give us the keys and values for encoder-decoder attention as well as the weight matrices q_var and o_var. Args: x: a mtf.Tensor with shape [<batch_dims>, model_dim] step_num: an mtf integer Scalar encdec_tensors: an optional list of num_layers tuples, each of the form (q_var, o_var, k, v) self_attention_k: an optional list of num_layers Tensors each with shape [batch, heads, memory_length, kv_channels] self_attention_v: an optional list of num_layers Tensors each with shape [batch, heads, memory_length, kv_channels] encdec_attention_mask: an optional mtf.Tensor with shape [batch, length_dim, encoder_length_dim] containing values 0 or -inf. Returns: y: a mtf.Tensor with shape [<batch_dims>, model_dim] new_self_attention_k: a list of num_layers mtf.Tensors, with the same shapes as the elements of self_attention_k new_self_attention_v: a list of num_layers mtf.Tensors, with the same shapes as the elements of self_attention_v Raises: ValueError: if hparams make no sense """ hparams = self._hparams num_layers = hparams.num_decoder_layers num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1 layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms) layer_norm_combined_var = mtf.get_variable( x.mesh, "layer_norm_scale", mtf.Shape([layer_norms_dim, self.model_dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype) layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim) def normalize(x): scale = layer_norm_vars.pop(0) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim) return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale new_self_attention_k = [] new_self_attention_v = [] for layer in xrange(num_layers): with tf.variable_scope("layer_%d" % layer): # Self attention layer y, new_k, new_v = mtf.layers.multihead_self_attention_incremental( normalize(x), prev_k=self_attention_k[layer], prev_v=self_attention_v[layer], step_num=step_num, master_dtype=self.master_dtype, slice_dtype=self.slice_dtype, name="att") new_self_attention_k.append(new_k) new_self_attention_v.append(new_v) x += y if encdec_tensors is not None: # Encoder-Decoder attention layer q_var, o_var, k, v = encdec_tensors[layer] x += mtf.layers.multihead_encdec_attention_incremental( normalize(x), q_var, o_var, k, v, encdec_attention_mask, name="enc_att") # ffn layer x += self._feedforward_layer(normalize(x), layer) x = normalize(x) assert not layer_norm_vars return x, new_self_attention_k, new_self_attention_v
def recon_model(mesh, data, R0, x0, nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size, a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32): """ Prototype of function computing LPT deplacement. Returns output tensorflow and mesh tensorflow tensors """ if dtype == tf.float32: npdtype = "float32" cdtype = tf.complex64 elif dtype == tf.float64: npdtype = "float64" cdtype = tf.complex128 print(dtype, npdtype) # Compute a few things first, using simple tensorflow stages = np.linspace(a0, a, nsteps, endpoint=True) # Define the named dimensions # Parameters of the small scales decomposition n_block_x = FLAGS.nx n_block_y = FLAGS.ny n_block_z = 1 halo_size = FLAGS.hsize if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z): new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z)) print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size)) halo_size = new_size # Parameters of the large scales decomposition fx_dim = mtf.Dimension("nx", nc) fy_dim = mtf.Dimension("ny", nc) fz_dim = mtf.Dimension("nz", nc) tfx_dim = mtf.Dimension("tx", nc) tfy_dim = mtf.Dimension("ty", nc) tfz_dim = mtf.Dimension("tz", nc) tx_dim = mtf.Dimension("tx_lr", nc) ty_dim = mtf.Dimension("ty_lr", nc) tz_dim = mtf.Dimension("tz_lr", nc) nx_dim = mtf.Dimension('nx_block', n_block_x) ny_dim = mtf.Dimension('ny_block', n_block_y) nz_dim = mtf.Dimension('nz_block', n_block_z) sx_dim = mtf.Dimension('sx_block', nc // n_block_x) sy_dim = mtf.Dimension('sy_block', nc // n_block_y) sz_dim = mtf.Dimension('sz_block', nc // n_block_z) k_dims = [tx_dim, ty_dim, tz_dim] batch_dim = mtf.Dimension("batch", batch_size) klin = np.loadtxt('../data/Planck15_a1p00.txt').T[0] plin = np.loadtxt('../data/Planck15_a1p00.txt').T[1] ipklin = iuspline(klin, plin) pk_dim = mtf.Dimension("npk", len(plin)) pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim]) # Compute necessary Fourier kernels kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False) kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim]) ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim]) kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim]) kv = [ky, kz, kx] # kvec for low resolution grid kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False) kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim]) ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim]) kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim]) kv_lr = [ky_lr, kz_lr, kx_lr] shape = [batch_dim, fx_dim, fy_dim, fz_dim] lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim] hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim] part_shape = [batch_dim, fx_dim, fy_dim, fz_dim] # Begin simulation if x0 is None: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.random_normal_initializer( mean=0.0, stddev=1, seed=None)) else: fieldvar = mtf.get_variable(mesh, 'linear', part_shape, initializer=tf.constant_initializer(x0)) print("\nfieldvar : \n", fieldvar) # Here we can run our nbody if FLAGS.nbody: state = mtfpm.lpt_init_single( fieldvar, a0, kv_lr, halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # Here we can run our nbody final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size) else: final_state = mtfpm.lpt_init_single( fieldvar, stages[-1], kv_lr, 
halo_size, lr_shape, hr_shape, part_shape[1:], antialias=True, ) # paint the field final_field = mtf.zeros(mesh, shape=hr_shape) for block_size_dim in hr_shape[-3:]: final_field = mtf.pad(final_field, [halo_size, halo_size], block_size_dim.name) final_field = mesh_utils.cic_paint(final_field, final_state[0], halo_size) # Halo exchange for blocks_dim, block_size_dim in zip(hr_shape[1:4], final_field.shape[-3:]): final_field = mpm.halo_reduce(final_field, blocks_dim, block_size_dim, halo_size) # Remove borders for block_size_dim in hr_shape[-3:]: final_field = mtf.slice(final_field, halo_size, block_size_dim.size, block_size_dim.name) final_field = mtf.slicewise( lambda x: x[:, 0, 0, 0], [final_field], output_dtype=dtype, output_shape=[batch_dim, fx_dim, fy_dim, fz_dim], name='my_dumb_reshape', splittable_dims=part_shape[:-1] + hr_shape[:4]) mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape) # Get prior k_dims_pr = [d.shape[0] for d in kv] k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]] cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype) def _cwise_prior(kfield, pk, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2) kshape = kk.shape kk = tf.reshape(kk, [-1]) pkmesh = tfp.math.interp_regular_1d_grid( x=kk, x_ref_min=1e-05, x_ref_max=1000.0, y_ref=pk, grid_regularizing_transform=tf.log) priormesh = tf.reshape(pkmesh, kshape) return tf.abs(kfield) / priormesh**0.5 cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32) prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3 #*nc**3 # Total loss diff = (final_field - mtfdata) R0 = tf.constant(R0) print("R0 in the recon_model : ", R0) def _cwise_smooth(kfield, kx, ky, kz): kx = tf.reshape(kx, [-1, 1, 1]) ky = tf.reshape(ky, [1, -1, 1]) kz = tf.reshape(kz, [1, 1, -1]) kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2 wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype) return kfield * wts # Element-wise function that applies a Fourier kernel plambda = FLAGS.plambda def _cwise_logprob(finalfield, data): galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield)) logprob = galmean.log_prob(data) return -1 * logprob cfield = mesh_utils.r2c3d(final_field, k_dims_pr, dtype=cdtype) cfield = mtf.cwise(_cwise_smooth, [cfield] + kv, output_dtype=cdtype) final_fieldsm = mesh_utils.c2r3d(cfield, diff.shape[-3:], dtype=dtype) chisq = mtf.cwise(_cwise_logprob, [final_fieldsm, mtfdata], output_dtype=tf.float32) # chisq = mtf.reduce_sum(chisq) ## # loss = chisq + prior def _cwise_sample(finalfield, data): galmean = tfp.distributions.Poisson(rate=plambda * (1 + finalfield)) sample = galmean.sample() return sample sample = mtf.cwise(_cwise_sample, [final_fieldsm, mtfdata], output_dtype=tf.float32) # fields = [fieldvar, sample] metrics = [chisq, prior, loss] return fields, metrics, kv
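# Sketch (hypothetical values; scipy only supplies the log-factorial term): the data term
# above is a Poisson negative log-likelihood, chisq = -sum log Poisson(data | rate) with
# rate = plambda * (1 + smoothed_field), which is what tfp.distributions.Poisson.log_prob
# computes element-wise inside _cwise_logprob; _cwise_sample draws a mock observation
# from the same rate.
import numpy as np
from scipy.special import gammaln

def poisson_neg_log_prob(data, rate):
    # -log P(data | rate) = rate - data * log(rate) + log(data!)
    return rate - data * np.log(rate) + gammaln(data + 1.0)

plambda = 0.1                                   # hypothetical galaxy density parameter
field = np.random.rand(4, 4, 4)
rate = plambda * (1.0 + field)
data = np.random.poisson(rate)                  # mock observation, as in _cwise_sample
chisq = np.sum(poisson_neg_log_prob(data, rate))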
def gradient_based_subword_tokenization(x, length_dim, max_subword_length=4, downsample=None, use_offsets=False, consider_chars_as_blocks=False, use_block_pos_embedding=False, share_block_kernel=False, memory_embeddings=0, context=None, block_mixing_mode=None, activation="softmax", downsample_function="mean"): """Implements GBSWT from Charformer. Args: x: a Tensor containing length_dim length_dim: a Dimension max_subword_length: integer downsample: integer. use_offsets: boolean. consider_chars_as_blocks: boolean. use_block_pos_embedding: boolean. share_block_kernel: boolean. memory_embeddings: integer. context: Context. block_mixing_mode: Str for block mixing. activation: Str for block ranking. downsample_function: Str, supports mean/linformer for now. Returns: a Tensor with the same shape as x. Raises: ValueError: if channels or depth don't match. """ # don't use this for now. del max_subword_length del memory_embeddings all_blocks = [] all_scores = [] tf.logging.info("GSW block layer") def _tile(x, n, tile_dim): # Simple tile function in MTF. return mtf.concat([x] * n, tile_dim.name) def _repeat(x, n, repeat_dim): # repeat function in MTF tmp_dim = mtf.Dimension("tmp", 1) expand_shape = mtf.Shape(x.shape.dims + [tmp_dim]) x = mtf.reshape(x, expand_shape) x = _tile(x, n, tmp_dim) output_shape = [] for dim in x.shape.dims: if dim.name == "tmp": continue if dim.name == repeat_dim.name: dim = mtf.Dimension(dim.name, dim.size * n) output_shape.append(dim) output_shape = mtf.Shape(output_shape) x = mtf.reshape(x, output_shape) return x def _combined_dim(dims): return mtf.Dimension(dims[0].name, mtf.Shape(dims).size) # compute all subword blocks # TODO(yitay): handle offsets to get all blocks if activation == "sigtanh": # one score for sigmoid tmp_dim = mtf.Dimension("block_score", 2) else: tmp_dim = mtf.Dimension("block_score", 1) model_dim = x.shape[-1] subword_blocks_width = [2, 3, 4] if consider_chars_as_blocks: subword_blocks_width += [1] if share_block_kernel: block_kernel_shape = mtf.Shape([model_dim, tmp_dim]) block_kernel = mtf.get_variable(x.mesh, "block_kernel", block_kernel_shape, initializer=None, dtype=context.variable_dtype) else: block_kernel = None for subword_len in subword_blocks_width: if use_block_pos_embedding: # this is turn off by default. It is meant to support cases like # parameterized pooling or other features. block_len_dim = mtf.Dimension(length_dim.name, subword_len) # TODO(vqtran): Consider other positional embeddings. 
block_pos_emb = sinusoid_positional_embedding_weights( context.mesh, block_len_dim, x.shape[-1], context.variable_dtype.activation_dtype) block_pos_emb = _repeat( block_pos_emb, math.ceil(length_dim.size / float(subword_len)), block_len_dim) if use_offsets: offset_space = subword_len else: offset_space = 1 for offsets in range(offset_space): if offsets > 0: xoff = mtf.shift(x, offsets, length_dim, wrap=False) if use_block_pos_embedding: block_pos_emb = mtf.shift(block_pos_emb, offsets, block_pos_emb.shape[-2], wrap=False) else: xoff = x tf.logging.info("SW len=%d offset=%d", subword_len, offsets) if length_dim.size % subword_len != 0: tf.logging.info("Not divisible by length") # add extra padding tokens pad_amt = int(subword_len) - int(length_dim.size % subword_len) kp = mtf.pad(xoff, [0, pad_amt], length_dim.name) else: kp = xoff if use_block_pos_embedding: kp += block_pos_emb bx = mtf.pool_tensor_1d( kp, pool_dim=kp.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(subword_len)) block_score = mtf.layers.dense(bx, [tmp_dim], use_bias=False, name="bx", reduced_dims=[model_dim], variable_dtype=None, kernel_weights=block_kernel) expand_bx = _repeat(bx, subword_len, length_dim) expand_scores = _repeat(block_score, subword_len, length_dim) if offsets > 0: # add offset. expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name) expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name) new_len = expand_bx.shape.get_dim_by_name(length_dim.name) if new_len.size < length_dim.size: pad_amt = new_len.size - length_dim.size expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name) expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name) elif new_len.size > length_dim.size: expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name) expand_scores = mtf.slice(expand_scores, 0, length_dim.size, length_dim.name) new_tmp_dim = mtf.Dimension("extra_dim", 1) expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim]) expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim]) expand_bx = mtf.reshape(expand_bx, expand_shape) expand_scores = mtf.reshape(expand_scores, expand_scores_shape) all_blocks.append(expand_bx) all_scores.append(expand_scores) all_blocks = mtf.concat(all_blocks, new_tmp_dim.name) all_scores = mtf.concat(all_scores, new_tmp_dim.name) tf.logging.info(all_blocks) new_tmp_dim = all_blocks.shape.get_dim_by_name("extra_dim") combined_dim = _combined_dim([new_tmp_dim, tmp_dim]) block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim block_net = mtf.reshape(all_scores, block_net_shape) if block_mixing_mode == "score_attention": tf.logging.info("Using score attention") att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim]) tf.logging.info(block_net) att = mtf.softmax(att, reduced_dim=att.shape[-1]) block_net = mtf.einsum([att, block_net], output_shape=block_net.shape) tf.logging.info(block_net) if activation == "softmax": block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim) elif activation == "tanh": tf.logging.info("Using tanh") block_net = mtf.tanh(block_net) all_blocks = block_net * all_blocks all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim) output = all_blocks if downsample: output_length = output.shape.get_dim_by_name("length") if output_length.size % int(downsample) != 0: pad_amt = int(downsample) - int( output_length.size % int(downsample)) output = mtf.pad(output, [0, pad_amt], output_length.name) if downsample_function == "mean": output = 
mtf.pool_tensor_1d( output, pool_dim=output.shape.get_dim_by_name("length"), reduce_fn=mtf.reduce_mean, pool_size=int(downsample)) else: raise ValueError("Downsampling function not implemented.") return output
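# Sketch (plain-numpy stand-in for mtf.pool_tensor_1d; the pool size is hypothetical):
# the downsample branch above pads the length axis up to a multiple of the pool size and
# then mean-pools non-overlapping windows, which is what reduce_fn=mtf.reduce_mean does.
import numpy as np

def mean_pool_1d(x, pool_size):
    # x: [batch, length, channels]
    batch, length, channels = x.shape
    pad = (-length) % pool_size
    x = np.pad(x, ((0, 0), (0, pad), (0, 0)))
    return x.reshape(batch, (length + pad) // pool_size, pool_size, channels).mean(axis=2)

x = np.random.randn(2, 10, 8).astype(np.float32)
y = mean_pool_1d(x, pool_size=4)                # -> shape (2, 3, 8)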
def __init__(self, mesh, query_input_dim, memory_input_dim, output_dim, key_dim, value_dim, query_heads_dims, memory_heads_dims, variable_dtype, shared_kv=False, combine_dims=True, ensemble_dim=None): """Create attention parameters. combine_dims is a hack for faster execution. The heads and key/value dimensions are combined in the variables and the computation. The hack would not be necessary if XLA optimized einsum properly. Args: mesh: a Mesh query_input_dim: a Dimension memory_input_dim: a Dimension output_dim: a Dimension key_dim: a Dimension value_dim: a Dimension query_heads_dims: a list of Dimension memory_heads_dims: a list of Dimension variable_dtype: a mtf.VariableDType shared_kv: a boolean combine_dims: a boolean ensemble_dim: an optional Dimension """ if shared_kv and key_dim != value_dim: raise ValueError("shared_kv requires key_dim == value_dim") self.query_input_dim = query_input_dim self.memory_input_dim = memory_input_dim self.output_dim = output_dim self.key_dim = key_dim self.value_dim = value_dim self.query_heads_dims = query_heads_dims or [] self.memory_heads_dims = memory_heads_dims or [] self.shared_kv = shared_kv self.combine_dims = combine_dims if combine_dims: q_shape = [query_input_dim, _combined_dim(self.q_dims)] k_shape = [memory_input_dim, _combined_dim(self.k_dims)] v_shape = [memory_input_dim, _combined_dim(self.v_dims)] o_shape = [_combined_dim(self.o_dims), output_dim] else: q_shape = [query_input_dim] + self.q_dims k_shape = [memory_input_dim] + self.k_dims v_shape = [memory_input_dim] + self.v_dims o_shape = self.o_dims + [output_dim] q_init = tf.random_normal_initializer(stddev=(query_input_dim.size * key_dim.size)**-0.5) kv_init = tf.random_normal_initializer( stddev=memory_input_dim.size**-0.5) o_init = tf.random_normal_initializer( stddev=mtf.Shape(self.query_heads_dims + [value_dim]).size**-0.5) if ensemble_dim: q_shape = [ensemble_dim] + q_shape k_shape = [ensemble_dim] + k_shape v_shape = [ensemble_dim] + v_shape o_shape = [ensemble_dim] + o_shape self.wq = mtf.get_variable(mesh, "q", q_shape, initializer=q_init, dtype=variable_dtype) if shared_kv: self.wkv = mtf.get_variable(mesh, "kv", k_shape, initializer=kv_init, dtype=variable_dtype) else: self.wk = mtf.get_variable(mesh, "k", k_shape, initializer=kv_init, dtype=variable_dtype) self.wv = mtf.get_variable(mesh, "v", v_shape, initializer=kv_init, dtype=variable_dtype) self.wo = mtf.get_variable(mesh, "o", o_shape, initializer=o_init, dtype=variable_dtype)
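# Sketch (numpy analogue of the combine_dims hack described in the docstring above; all
# sizes are hypothetical): storing the query projection as [d_model, heads * d_head] and
# reshaping afterwards is numerically identical to keeping separate [d_model, heads,
# d_head] factors -- combining just lets XLA run one large matmul.
import numpy as np

d_model, heads, d_head = 16, 4, 8
x = np.random.randn(2, d_model)
w_combined = np.random.randn(d_model, heads * d_head)
w_split = w_combined.reshape(d_model, heads, d_head)

q_combined = (x @ w_combined).reshape(2, heads, d_head)
q_split = np.einsum("bd,dhk->bhk", x, w_split)
assert np.allclose(q_combined, q_split)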
def mtf_model_fn(self, features, mesh): features = copy.copy(features) tf.logging.info("features = %s" % features) hparams = self._hparams activation_dtype = self.activation_type # We assume fixed vocab size for targets targets = tf.to_int32(features["targets"]) # Image preprocessing, reshape into a 1D sequence and shift right. length = hparams.img_len * hparams.img_len * hparams.num_channels targets = tf.reshape(targets, [hparams.batch_size, length]) shifted_targets = common_layers.shift_right_2d(targets) # Declare all the dimensions batch_dim = mtf.Dimension("batch", hparams.batch_size) def import_to_batch_by_length(x, name): return mtf.import_tf_tensor(mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name) targets = import_to_batch_by_length(targets, "targets") shifted_targets = import_to_batch_by_length(shifted_targets, "shifted_targets") extra_losses = [] # Create targets content and position embeddings. # Create embedding var for targets and positions and do a gather. targets_embedding_var = mtf.get_variable( mesh, "targets_embedding", mtf.Shape([self.targets_vocab_dim, self.model_dim]), initializer=tf.random_normal_initializer(), activation_dtype=activation_dtype) x = mtf.gather(targets_embedding_var, shifted_targets, self.targets_vocab_dim) # Add positional embeddings x += mtf.reshape(self.create_positional_emb_2d(targets), [self.length_dim, self.model_dim]) # If conditional and input is given, add the input embedding to the target. # TODO(nikip): Verify conditional. if self.has_input and not hparams.unconditional: inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3]) inputs = import_to_batch_by_length(inputs, "inputs") # Input embeddings inputs_embedding_var = mtf.layers.embedding( mesh, "input_embedding", mtf.Shape([self.inputs_vocab_dim, self.model_dim]), activation_dtype=activation_dtype) inputs_emb = mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) x += inputs_emb # Image Transformer Decoder # [ self attention - ffn - residual + dropout] x n if hparams.attention_type == "local1d_spatial": decoder_output = local_attention1d_spatial_decoder( x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams) elif hparams.attention_type == "local2d_spatial": decoder_output = local_attention2d_spatial_decoder( x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams) elif hparams.attention_type == "local1d": decoder_output = local_attention1d_masked_decoder( x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams) else: raise ValueError("Invalid attention type.") # Calculate the logits and loss. logits = mtf.layers.dense(decoder_output, self.outputs_vocab_dim, name="logits") # Need a reshape for logits logits = mtf.reshape( logits, mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim])) soft_targets = mtf.one_hot(targets, self.outputs_vocab_dim, dtype=activation_dtype) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, self.outputs_vocab_dim) loss = mtf.reduce_mean(loss) for l in extra_losses: loss += l # Reshape logits to original target shape. logits = mtf.reshape( logits, mtf.Shape([ batch_dim, self.rows_dim, self.orig_cols_dim, self.channels_dim, self.outputs_vocab_dim ])) return logits, loss
def rezero(x, scope, dtype): with tf.variable_scope(scope): g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype) return x * g
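# Sketch (hypothetical usage of the helper above): ReZero gates each residual branch with
# a scalar that is initialized to zero, so every block starts as the identity map and the
# gate is learned during training. `sublayer` is assumed to preserve the tensor's shape.
def rezero_residual(x, sublayer, scope, dtype):
    # At initialization g == 0, so rezero_residual(x, ...) == x.
    return x + rezero(sublayer(x), scope, dtype)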
def mnist_model(image, labels, mesh): """The model. Args: image: tf.Tensor with shape [batch, 28*28] labels: a tf.Tensor with shape [batch] and dtype tf.int32 mesh: a mtf.Mesh Returns: logits: a tf.Tensor with shape [batch, 10] loss: a mtf.Tensor with shape [] """ batch_dim = mtf.Dimension("batch", FLAGS.batch_size) row_blocks_dim = mtf.Dimension("row_blocks", 4) col_blocks_dim = mtf.Dimension("col_blocks", 4) rows_dim = mtf.Dimension("rows_size", 7) cols_dim = mtf.Dimension("cols_size", 7) classes_dim = mtf.Dimension("classes", 10) one_channel_dim = mtf.Dimension("one_channel", 1) x = mtf.import_tf_tensor( mesh, tf.reshape(image, [FLAGS.batch_size, 4, 7, 4, 7, 1]), mtf.Shape([ batch_dim, row_blocks_dim, rows_dim, col_blocks_dim, cols_dim, one_channel_dim ])) x = mtf.transpose(x, [ batch_dim, row_blocks_dim, col_blocks_dim, rows_dim, cols_dim, one_channel_dim ]) # add some convolutional layers to demonstrate that convolution works. fh_dim = mtf.Dimension("fh", 9) fw_dim = mtf.Dimension("fw", 9) filters1_dim = mtf.Dimension("filters1", 16) filters2_dim = mtf.Dimension("filters2", 16) kernel1 = mtf.get_variable(mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim]) kernel2 = mtf.get_variable(mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim]) f1 = mtf.relu( mtf.conv2d_with_blocks(x, kernel1, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) f2 = mtf.relu( mtf.conv2d_with_blocks(f1, kernel2, strides=[1, 1, 1, 1], padding="SAME", h_blocks_dim=row_blocks_dim, w_blocks_dim=col_blocks_dim)) x = mtf.reduce_mean(f2, reduced_dim=filters2_dim) # add some fully-connected dense layers. hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size) hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size) h1 = mtf.layers.dense(x, hidden_dim1, reduced_dims=x.shape.dims[-4:], activation=mtf.relu, name="hidden1") h2 = mtf.layers.dense(h1, hidden_dim2, activation=mtf.relu, name="hidden2") logits = mtf.layers.dense(h2, classes_dim, name="logits") if labels is None: loss = None else: labels = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim])) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, mtf.one_hot(labels, classes_dim), classes_dim) loss = mtf.reduce_mean(loss) return logits, loss
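# Sketch (not from the source): lowering a graph built with mnist_model to plain
# TensorFlow and evaluating the loss on a single device. Assumptions: FLAGS.batch_size
# and FLAGS.hidden_size are defined as in the surrounding script, the TF1-style API is
# reached through tensorflow.compat.v1, and a single mesh axis of size 1 is used, so
# nothing is actually split.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "my_mesh")

images = tf.random_uniform([FLAGS.batch_size, 28 * 28])
labels = tf.random_uniform([FLAGS.batch_size], maxval=10, dtype=tf.int32)
logits, loss = mnist_model(images, labels, mesh)

mesh_shape = mtf.Shape([mtf.Dimension("all", 1)])
layout_rules = mtf.convert_to_layout_rules([("batch", "all")])
mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(mesh_shape, layout_rules, [""])
lowering = mtf.Lowering(graph, {mesh: mesh_impl})
tf_loss = lowering.export_to_tf_tensor(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(lowering.copy_masters_to_slices())
    print(sess.run(tf_loss))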
def talking_heads(self, context, inp, name, input_heads_dims, output_heads_dims, dynamic_projections_from=None): shared_dims = [d for d in input_heads_dims if d in output_heads_dims] reduced_dims = [ d for d in input_heads_dims if d not in output_heads_dims ] new_dims = [d for d in output_heads_dims if d not in input_heads_dims] if not (reduced_dims or new_dims): # Output dimensions are same as input dimensions. Return the input return inp elif dynamic_projections_from: # There are one or more dynamic talking-heads-projections with tf.variable_scope(name): # static projection - this is the same as the static projection in the # "else" case below. We create the weight matrix with get_variable # instead of calling mtf.layers.dense() so that we can fold the # static projection into one of the dynamic projections. static_p_initializer = mtf.layers.VarianceScalingInitializer()( reduced_dims, new_dims) static_p_shape = (context.model.ensemble_dims + shared_dims + reduced_dims + new_dims) static_p = mtf.get_variable(inp.mesh, "kernel", static_p_shape, initializer=static_p_initializer, dtype=context.variable_dtype) ps = [] for i, dp_from in enumerate(dynamic_projections_from): kernel_initializer = mtf.layers.VarianceScalingInitializer( self.dynamic_projections_init_scale / mtf.Shape(reduced_dims).size) ps.append( mtf.layers.dense( dp_from, reduced_dims=[context.model.model_dim], new_dims=shared_dims + reduced_dims + new_dims, use_bias=False, activation=None, variable_dtype=context.variable_dtype, name="%s_dynamic_%d" % (name, i), expert_dims=context.model.ensemble_dims, kernel_initializer=kernel_initializer)) # Fold the static projection into one of the static projections. # Mathematically, we could add all the dynamic projections together # here, but it would create a very large tensor which contained # both the query-length and memory-length dimensions, and would # probably be slower in practice. ps[0] += static_p return mtf.add_n([ mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps ]) else: # No dynamic projections. Static talking-heads projection only return mtf.layers.dense(inp, reduced_dims=reduced_dims, new_dims=new_dims, use_bias=False, activation=None, variable_dtype=context.variable_dtype, name=name, expert_dims=context.model.ensemble_dims + shared_dims)
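# Sketch (plain-numpy analogue, not the method above; sizes are hypothetical): the static
# talking-heads projection is a learned linear map across the heads axis of the attention
# logits, leaving the query/memory positions untouched; the dynamic variants add an
# input-dependent kernel on top of the static one.
import numpy as np

batch, heads_in, heads_out, q_len, m_len = 2, 4, 4, 8, 8
logits = np.random.randn(batch, heads_in, q_len, m_len)
kernel = np.random.randn(heads_in, heads_out) / np.sqrt(heads_in)

# einsum over the reduced "input heads" dimension, keeping query/memory positions.
mixed = np.einsum("bhqm,hk->bkqm", logits, kernel)
assert mixed.shape == (batch, heads_out, q_len, m_len)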
def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None): # x :: [batch, seq, n_embd] x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh # n_state is the same as config["n_embd"], which is also the same as dim_embd. assert n_state.size % params["n_head"] == 0 dim_heads = mtf.Dimension("heads", params["n_head"]) num_mem_kv = params.get("num_mem_kv", 0) use_num_mem_kv = num_mem_kv > 0 with tf.variable_scope(scope): # Compute attention inputs dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"]) mtfparams = mtf.transformer.attention.attention_params_simple( x.mesh, io_dim=dim_embd, kv_dim=dim_kv, heads_dim=dim_heads, variable_dtype=variable_dtype ) q = mtfparams.compute_q(x) k = mtfparams.compute_k(x) v = mtfparams.compute_v(x) if is_incremental_inference(context): one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype) inv_one_hot = 1.0 - one_hot old_k, old_v = context.get_states(2) k = old_k * inv_one_hot + k * one_hot v = old_v * inv_one_hot + v * one_hot if exists(context): context.record_new_states([k, v]) with tf.variable_scope("attention"): if attention_type == "local": # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. radius = params.get("local_attention_radius", 256) if is_incremental_inference(context): q *= one_hot a = mtf_transformer.attention.local_attention_1d( q, k, v, length_dim=k.shape[1], key_dim=dim_kv, value_dim=dim_kv, radius=radius, length_dim_num_splits=1, fully_autoregressive=params["causal"], attention_kwargs={}, ) if is_incremental_inference(context): a = mtf.gather(a, context.position - 1, dim_seq) elif attention_type == "global": # TODO: pass in fake context # Broadcast mask bias across batch and heads if exists(bias): if not is_incremental_inference(context): broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]]) else: # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position bias = mtf.gather(bias, context.position - 1, dim_seq) broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]]) # memory key / values, from all-attention paper if use_num_mem_kv: k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh) k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim) v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim) attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0 a = mtf_transformer.attention.attention( q, k, v, memory_length_dim=memory_length_dim, key_dim=dim_kv, value_dim=dim_kv, bias=broadcasted_bias, dropout_rate=attn_dropout_rate ) elif attention_type == "linear": linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention a = linear_attn_fn(q, k, v) else: raise NotImplementedError("Unknown attention type {}!".format(attention_type)) with tf.variable_scope("compute_output"): a = mtfparams.compute_output(a, x_shape) with tf.variable_scope("compute_output_bias"): b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0), master_dtype=variable_dtype.master_dtype, slice_dtype=variable_dtype.slice_dtype, activation_dtype=variable_dtype.activation_dtype) a += b if params["mode"] == "train" and params["res_dropout"] > 0: a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout") return a
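# Sketch (numpy illustration of the incremental-inference branch above; sizes and the
# position are hypothetical): during decoding, the cached keys/values are updated at the
# current position by blending with a one-hot mask, k = old_k * (1 - one_hot) + k * one_hot,
# so every other cached entry is left untouched.
import numpy as np

seq_len, d = 6, 4
old_k = np.zeros((seq_len, d), dtype=np.float32)          # cached keys for one head
new_k = np.random.randn(seq_len, d).astype(np.float32)    # only the current step matters
position = 3

one_hot = np.eye(seq_len, dtype=np.float32)[position][:, None]   # [seq_len, 1]
k = old_k * (1.0 - one_hot) + new_k * one_hot
assert np.allclose(k[position], new_k[position]) and np.allclose(k[:position], 0.0)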
def model(mtf_features, other_features, params, mesh, variable_dtype, context=None):
    """A GPT-style model implemented in Mesh TensorFlow."""
    x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim = parse_inputs(mtf_features, other_features)

    if is_incremental_inference(context):
        # reshape inputs if in inference mode
        x = mtf.gather(x, context.position - 1, sequence_dim)
        x = mtf.reshape(x, [batch_dim])

    use_axial_pos_emb = params["axial_pos_emb"] is not None

    if not use_axial_pos_emb:
        # Use standard position encoding
        wpe = mtf.get_variable(mesh, "wpe", mtf.Shape([embed_sequence_dim, embd_dim]),
                               initializer=tf.random_normal_initializer(stddev=0.01),
                               master_dtype=variable_dtype.master_dtype,
                               slice_dtype=variable_dtype.slice_dtype,
                               activation_dtype=variable_dtype.activation_dtype)
    else:
        wpe = axial_positional_emb(embd_dim, mesh, params, variable_dtype)

    # Text encoding
    wte = mtf.get_variable(mesh, "wte", mtf.Shape([vocab_dim, embd_dim]),
                           initializer=tf.random_normal_initializer(stddev=0.02),
                           master_dtype=variable_dtype.master_dtype,
                           slice_dtype=variable_dtype.slice_dtype,
                           activation_dtype=variable_dtype.activation_dtype)

    with tf.variable_scope("token_embd"):
        # Text embedding
        h = mtf.gather(wte, x, vocab_dim)
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            h = mtf.dropout(h, rate=params["embed_dropout"], name="wte_dropout")

    with tf.variable_scope("pos_embd"):
        # Positional embedding
        position_indices = (mtf.range(mesh, sequence_dim, tf.int64)
                            if not is_incremental_inference(context)
                            else (context.position - 1))
        pos_emb = mtf.gather(wpe, position_indices, wpe.shape[0])
        if params["embed_dropout"] > 0 and params["mode"] == "train":
            pos_emb = mtf.dropout(pos_emb, rate=params["embed_dropout"], name="wte_dropout")
        h += pos_emb

    aux_losses = 0  # instantiate auxiliary losses (for MoE models)

    for layer in range(params["n_layer"]):
        # attn blocks
        share_parameters = exists(params["share_parameters"]) and params["share_parameters"] == True
        block_scope = f"h{layer}" if not share_parameters else ""

        block_fn = block(params=params, scope=block_scope, layer_num=layer,
                         bias=other_features["attn_bias"],
                         sequence_dim=sequence_dim,
                         memory_length_dim=other_features["memory_length_dim"],
                         variable_dtype=variable_dtype,
                         context=context)

        # If true and in train mode, enable gradient checkpointing
        recompute_grad = params["recompute_grad"] and params["mode"] == "train"
        h, loss = block_fn(h) if not recompute_grad else mtf.recompute_grad(block_fn, [h])
        aux_losses += loss

    no_weight_tie_emb = params["no_weight_tie"] == True
    if no_weight_tie_emb:
        with tf.variable_scope("wte_final_linear"):
            logits = linear(h, "linear_out", vocab_dim, variable_dtype=variable_dtype, params=params)
    else:
        # Layer normalize & affine transform
        h = layer_norm(h, "ln_f", variable_dtype=variable_dtype)
        seq_dim = sequence_dim if not is_incremental_inference(context) else mtf.Dimension("sequence", 1)

        with tf.variable_scope("wte_final_einsum"):
            # Equivalent to tf.matmul
            logits = mtf.einsum([h, wte], output_shape=[batch_dim, seq_dim, vocab_dim])

    if params["mode"] in ["train", "eval"]:
        labels = mtf_features["labels"]
        z_loss = params.get("z_loss", 1e-4)  # an auxiliary loss used to stabilize mtf xentropy

        # Go to full precision for the logits
        logits = mtf.cast(logits, tf.float32)

        use_entmax_loss = params.get("entmax_loss", False)
        loss_fn = (mtf.layers.softmax_cross_entropy_with_logits if not use_entmax_loss
                   else entmax_cross_entropy_with_logits)

        with tf.variable_scope("xentropy_final"):
            loss_batch = loss_fn(logits=logits, targets=labels,
                                 vocab_dim=logits.shape[-1], z_loss=z_loss)

        # For non-autoregressive models (masked language modeling training),
        # make sure labels with padding tokens are not counted in the loss.
        if not params["causal"]:
            padding_id = params.get("padding_id", 0)
            loss_batch = mtf.where(mtf.not_equal(labels, padding_id),
                                   loss_batch, mtf.zeros_like(loss_batch))

        with tf.variable_scope("reduce_mean_final"):
            loss = mtf.reduce_mean(loss_batch)

        loss += aux_losses  # Add on auxiliary losses (currently only used for MoE)
        loss /= params["num_microbatches"]

        # Convert to train dtype
        loss = mtf.cast(loss, variable_dtype.slice_dtype)
    else:
        loss = None
        loss_batch = None

    # Cast back to checkpoint dtype
    logits = mtf.cast(logits, variable_dtype.master_dtype)
    return logits, loss, loss_batch
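
# The weight-tied readout in `model` above contracts the final hidden states with the
# token embedding table wte. A minimal NumPy sketch of that contraction; sizes and
# names are toy values for illustration only, not part of the model above.
import numpy as np


def _sketch_weight_tied_readout():
    batch, seq, d_model, vocab = 2, 5, 8, 11          # toy sizes (illustrative)
    h = np.random.randn(batch, seq, d_model)          # stand-in for final hidden states
    wte = np.random.randn(vocab, d_model)             # stand-in for the embedding table

    # Contract the model dimension against the embedding table, which is what
    # mtf.einsum([h, wte], output_shape=[batch, seq, vocab]) does above.
    logits = np.einsum("bsd,vd->bsv", h, wte)
    print(logits.shape)  # (2, 5, 11)
    return logits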
def recon_model(mesh, data, bparams, ipkerror, R0, x0,
                nc=FLAGS.nc, bs=FLAGS.box_size, batch_size=FLAGS.batch_size,
                a0=FLAGS.a0, a=FLAGS.af, nsteps=FLAGS.nsteps, dtype=tf.float32):
    """Prototype of a function computing the LPT displacement.

    Returns output TensorFlow and Mesh TensorFlow tensors.
    """
    b1, b2, bs2 = bparams
    kerror, perror = ipkerror[0].astype(np.float32), ipkerror[1].astype(np.float32)

    if dtype == tf.float32:
        npdtype = "float32"
        cdtype = tf.complex64
    elif dtype == tf.float64:
        npdtype = "float64"
        cdtype = tf.complex128
    print("Dtype : ", dtype, npdtype)

    # Compute a few things first, using simple tensorflow
    kny = 1 * np.pi * nc / bs
    R1, R2 = 3., 3 * 1.2
    stages = np.linspace(a0, a, nsteps, endpoint=True)

    # Define the named dimensions
    # Parameters of the small scales decomposition
    n_block_x = FLAGS.nx
    n_block_y = FLAGS.ny
    n_block_z = 1
    halo_size = FLAGS.hsize

    if halo_size >= 0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z):
        new_size = int(0.5 * min(nc // n_block_x, nc // n_block_y, nc // n_block_z))
        print('WARNING: REDUCING HALO SIZE from %d to %d' % (halo_size, new_size))
        halo_size = new_size

    # Parameters of the large scales decomposition
    scalar = mtf.Dimension("scalar", 1)
    fx_dim = mtf.Dimension("nx", nc)
    fy_dim = mtf.Dimension("ny", nc)
    fz_dim = mtf.Dimension("nz", nc)

    tfx_dim = mtf.Dimension("tx", nc)
    tfy_dim = mtf.Dimension("ty", nc)
    tfz_dim = mtf.Dimension("tz", nc)

    tx_dim = mtf.Dimension("tx_lr", nc)
    ty_dim = mtf.Dimension("ty_lr", nc)
    tz_dim = mtf.Dimension("tz_lr", nc)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc // n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc // n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc // n_block_z)

    # k_dims = [tx_dim, ty_dim, tz_dim]
    batch_dim = mtf.Dimension("batch", batch_size)

    klin = np.loadtxt('..//data/Planck15_a1p00.txt').T[0]
    plin = np.loadtxt('..//data/Planck15_a1p00.txt').T[1]
    ipklin = iuspline(klin, plin)
    pk_dim = mtf.Dimension("npk", len(plin))
    pk = mtf.import_tf_tensor(mesh, plin.astype(npdtype), shape=[pk_dim])

    pke_dim = mtf.Dimension("epk", len(perror))
    pkerror = mtf.import_tf_tensor(mesh, perror.astype(npdtype), shape=[pke_dim])

    # Compute necessary Fourier kernels
    kvec = flowpm.kernels.fftk((nc, nc, nc), symmetric=False)
    kx = mtf.import_tf_tensor(mesh, kvec[0].squeeze().astype('float32'), shape=[tfx_dim])
    ky = mtf.import_tf_tensor(mesh, kvec[1].squeeze().astype('float32'), shape=[tfy_dim])
    kz = mtf.import_tf_tensor(mesh, kvec[2].squeeze().astype('float32'), shape=[tfz_dim])
    kv = [ky, kz, kx]

    # kvec for the low resolution grid
    kvec_lr = flowpm.kernels.fftk([nc, nc, nc], symmetric=False)
    kx_lr = mtf.import_tf_tensor(mesh, kvec_lr[0].squeeze().astype('float32'), shape=[tx_dim])
    ky_lr = mtf.import_tf_tensor(mesh, kvec_lr[1].squeeze().astype('float32'), shape=[ty_dim])
    kz_lr = mtf.import_tf_tensor(mesh, kvec_lr[2].squeeze().astype('float32'), shape=[tz_dim])
    kv_lr = [ky_lr, kz_lr, kx_lr]

    shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    lr_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    hr_shape = [batch_dim, nx_dim, ny_dim, nz_dim, sx_dim, sy_dim, sz_dim]
    part_shape = [batch_dim, fx_dim, fy_dim, fz_dim]
    splittables = lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3]

    # Begin simulation
    if x0 is None:
        fieldvar = mtf.get_variable(mesh, 'linear', part_shape,
                                    initializer=tf.random_normal_initializer(
                                        mean=0.0, stddev=1, seed=None))
    else:
        fieldvar = mtf.get_variable(mesh, 'linear', part_shape,
                                    initializer=tf.constant_initializer(x0))

    state = mtfpm.lpt_init_single(
        fieldvar,
        a0,
        kv_lr,
        halo_size,
        lr_shape,
        hr_shape,
        part_shape[1:],
        antialias=True,
    )
    final_state = mtfpm.nbody_single(state, stages, lr_shape, hr_shape, kv_lr, halo_size)

    # Paint the field
    final_field = mtf.zeros(mesh, shape=part_shape)
    final_field = mcomp.cic_paint_fr(final_field, final_state, part_shape,
                                     hr_shape, halo_size, splittables, mesh)

    # Get the fields for bias
    hr_field = mcomp.fr_to_hr(final_field, hr_shape, halo_size, splittables, mesh)
    mstate = mpm.mtf_indices(hr_field.mesh, shape=part_shape[1:], dtype=tf.float32)
    X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                   output_shape=[batch_dim] + mstate.shape[:])

    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    tfnc, tfbs = cswisef.float_to_mtf(nc * 1., mesh, scalar), cswisef.float_to_mtf(bs, mesh, scalar)

    # Bias fields: density, squared density, and tidal shear
    initc = fieldvar
    d0 = initc - mtf.reduce_mean(initc)

    d2 = initc * initc
    d2 = d2 - mtf.reduce_mean(d2)

    cfield = mesh_utils.r2c3d(d0, k_dims_pr, dtype=cdtype)
    shearfield = mtf.zeros(mesh, shape=part_shape)
    shearfield = shear(shearfield, cfield, kv, tfnc, tfbs)
    s2 = shearfield - mtf.reduce_mean(shearfield)

    dread = mcomp.cic_readout_fr(d0, [X], hr_shape=hr_shape, halo_size=halo_size,
                                 splittables=splittables, mesh=mesh)
    d2read = mcomp.cic_readout_fr(d2, [X], hr_shape=hr_shape, halo_size=halo_size,
                                  splittables=splittables, mesh=mesh)
    s2read = mcomp.cic_readout_fr(s2, [X], hr_shape=hr_shape, halo_size=halo_size,
                                  splittables=splittables, mesh=mesh)

    ed, ed2, es2 = (mtf.zeros(mesh, shape=part_shape),
                    mtf.zeros(mesh, shape=part_shape),
                    mtf.zeros(mesh, shape=part_shape))
    ed = mcomp.cic_paint_fr(ed, final_state, output_shape=part_shape, hr_shape=hr_shape,
                            halo_size=halo_size, splittables=splittables, mesh=mesh,
                            weights=dread)
    ed2 = mcomp.cic_paint_fr(ed2, final_state, output_shape=part_shape, hr_shape=hr_shape,
                             halo_size=halo_size, splittables=splittables, mesh=mesh,
                             weights=d2read)
    es2 = mcomp.cic_paint_fr(es2, final_state, output_shape=part_shape, hr_shape=hr_shape,
                             halo_size=halo_size, splittables=splittables, mesh=mesh,
                             weights=s2read)

    model = ed * b1 + ed2 * b2 + es2 * bs2
    mtfdata = mtf.import_tf_tensor(mesh, tf.convert_to_tensor(data), shape=shape)
    diff = model - mtfdata

    # Get prior
    k_dims_pr = [d.shape[0] for d in kv]
    k_dims_pr = [k_dims_pr[2], k_dims_pr[0], k_dims_pr[1]]
    cfield = mesh_utils.r2c3d(fieldvar, k_dims_pr, dtype=cdtype)

    def _cwise_prior(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk,
            x_ref_min=1e-05,
            x_ref_max=1000.0,
            y_ref=pk,
            grid_regularizing_transform=tf.log)
        priormesh = tf.reshape(pkmesh, kshape)
        return tf.abs(kfield) / priormesh**0.5

    cpfield = mtf.cwise(_cwise_prior, [cfield, pk] + kv, output_dtype=tf.float32)
    prior = mtf.reduce_sum(mtf.square(cpfield)) * bs**3  # * nc**3

    # Total loss
    cdiff = mesh_utils.r2c3d(diff, k_dims_pr, dtype=cdtype)

    def _cwise_diff(kfield, pk, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = tf.sqrt((kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2)
        kshape = kk.shape
        kk = tf.reshape(kk, [-1])
        pkmesh = tfp.math.interp_regular_1d_grid(
            x=kk, x_ref_min=kerror.min(), x_ref_max=kerror.max(), y_ref=pk)
        priormesh = tf.reshape(pkmesh, kshape)
        priormesh = tf.cast(priormesh**0.5, kfield.dtype)
        return kfield / priormesh

    cdiff = mtf.cwise(_cwise_diff, [cdiff, pkerror] + kv, output_dtype=cdtype)

    def _cwise_smooth(kfield, kx, ky, kz):
        kx = tf.reshape(kx, [-1, 1, 1])
        ky = tf.reshape(ky, [1, -1, 1])
        kz = tf.reshape(kz, [1, 1, -1])
        kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2
        wts = tf.cast(tf.exp(-kk * (R0 * bs / nc)**2), kfield.dtype)
        return kfield * wts

    cdiff = mtf.cwise(_cwise_smooth, [cdiff] + kv, output_dtype=cdtype)
    diff = mesh_utils.c2r3d(cdiff, diff.shape[-3:], dtype=dtype)

    chisq = mtf.reduce_sum(mtf.square(diff))
    loss = chisq + prior

    fields = [fieldvar, final_field, model]
    metrics = [chisq, prior, loss]

    return fields, metrics, kv
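
# The `_cwise_smooth` helper in `recon_model` weights each Fourier mode of the residual
# by exp(-kk * (R0 * bs / nc)**2). A standalone NumPy sketch of that weighting; grid size,
# box size, and smoothing scale are made-up values, and the 2*pi*fftfreq convention for the
# per-axis wavenumbers is an assumption made here for illustration.
import numpy as np


def _sketch_kspace_smoothing():
    nc, bs, R0 = 16, 100.0, 4.0                 # toy grid size, box size, smoothing scale
    k = np.fft.fftfreq(nc) * 2 * np.pi          # assumed per-axis wavenumber convention
    kx, ky, kz = np.meshgrid(k, k, k, indexing="ij")
    kk = (kx / bs * nc)**2 + (ky / bs * nc)**2 + (kz / bs * nc)**2

    # Gaussian cut in k-space, same functional form as in _cwise_smooth above.
    wts = np.exp(-kk * (R0 * bs / nc)**2)

    field = np.random.randn(nc, nc, nc)
    smoothed = np.real(np.fft.ifftn(np.fft.fftn(field) * wts))
    print(smoothed.shape)  # (16, 16, 16)
    return smoothed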
def _layer_stack(self,
                 x,
                 layers,
                 encoder_output=None,
                 self_attention_mask=None,
                 encdec_attention_mask=None,
                 losses=None,
                 step_num=None,
                 encdec_tensors=None,
                 states=None):
    """Encoder or decoder stack.

    Args:
      x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
      layers: a list of strings
      encoder_output: an optional mtf.Tensor with shape
        [<batch_dims>, encoder_length_dim, model_dim]
      self_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, memory_length_dim] containing values 0 or -inf.
      encdec_attention_mask: an optional mtf.Tensor with shape
        [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
      losses: a list to be appended to
      step_num: an optional mtf integer Scalar (used in incremental mode)
      encdec_tensors: an optional list of num_layers tuples, each of the form
        (q_var, o_var, k, v) (used in incremental mode)
      states: an optional list of Tensors (used in incremental mode)

    Returns:
      a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]

    Raises:
      ValueError: if hparams make no sense
    """
    hparams = self._hparams
    is_incremental = (step_num is not None)
    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    def layer_prepostprocess_dropout(x):
        if is_incremental:
            return x
        return mtf.dropout(
            x, is_training,
            keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
            noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

    num_layers = len(layers)
    num_layer_norms = num_layers + 1
    layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
    layer_norm_combined_var = mtf.get_variable(
        x.mesh,
        "layer_norm_scale",
        mtf.Shape([layer_norms_dim, self.model_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

    def normalize(x):
        scale = layer_norm_vars.pop(0)
        variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
        return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

    if is_incremental:
        states = list(states)
        new_states = []
    tf.logging.info("states = %s" % (states,))

    for lnum, layer_type in enumerate(layers):
        with tf.variable_scope("%s_%d" % (layer_type, lnum)):
            if layer_type == "att":
                # Self-attention layer
                if is_incremental:
                    y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
                        normalize(x),
                        prev_k=states.pop(0),
                        prev_v=states.pop(0),
                        step_num=step_num,
                        master_dtype=self.master_dtype,
                        slice_dtype=self.slice_dtype,
                        name="att")
                    new_states.append(new_k)
                    new_states.append(new_v)
                    x += y
                else:
                    x += layer_prepostprocess_dropout(
                        mtf.layers.multihead_attention(
                            normalize(x), None, self_attention_mask,
                            self.kv_dim, self.heads_dim, is_training,
                            dropout=hparams.attention_dropout,
                            dropout_broadcast_dims=[self.length_dim],
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="att"))
            elif layer_type == "enc_att":
                # Encoder-decoder attention layer
                if is_incremental:
                    q_var, o_var, k, v = encdec_tensors[lnum]
                    x += mtf.layers.multihead_encdec_attention_incremental(
                        normalize(x), q_var, o_var, k, v,
                        encdec_attention_mask, name="enc_att")
                else:
                    x += layer_prepostprocess_dropout(
                        mtf.layers.multihead_attention(
                            normalize(x), encoder_output, encdec_attention_mask,
                            self.kv_dim, self.heads_dim, is_training,
                            dropout=hparams.attention_dropout,
                            dropout_broadcast_dims=[self.length_dim],
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="enc_att"))
            elif layer_type == "local_att":
                if is_incremental:
                    y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
                        normalize(x),
                        prev_k=states.pop(0),
                        prev_v=states.pop(0),
                        step_num=step_num,
                        master_dtype=self.master_dtype,
                        slice_dtype=self.slice_dtype,
                        name="local_att")
                    new_states.append(new_k)
                    new_states.append(new_v)
                    x += y
                else:
                    x += layer_prepostprocess_dropout(
                        mtf.layers.masked_local_attention_1d(
                            normalize(x),
                            self.kv_dim, self.heads_dim, is_training,
                            window_size=hparams.local_attention_window_size,
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            length_per_split=mtf.tensor_dim_to_size_per_split(
                                hparams.layout, hparams.mesh_shape,
                                self.max_length_dim),
                            name="local_att"))
            elif layer_type == "compressed_att":
                if is_incremental:
                    raise ValueError("compressed_att incremental not implemented")
                else:
                    x += layer_prepostprocess_dropout(
                        mtf.layers.multihead_self_attention_memory_compressed(
                            normalize(x),
                            mask_right=True,
                            compression_factor=hparams.compression_factor,
                            kv_channels=self.kv_dim,
                            heads=self.heads_dim,
                            is_training=is_training,
                            dropout=hparams.attention_dropout,
                            dropout_broadcast_dims=[self.length_dim],
                            master_dtype=self.master_dtype,
                            slice_dtype=self.slice_dtype,
                            name="compressed_att"))
            else:
                if is_incremental:
                    # Insert a length dimension of size 1.
                    x_shape = x.shape
                    shape_with_length = mtf.Shape(
                        x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
                        + x_shape.dims[-1:])
                    x = mtf.reshape(x, shape_with_length)
                # FFN layer
                x += layer_prepostprocess_dropout(
                    self._feedforward_layer(normalize(x), layer_type, losses=losses))
                if is_incremental:
                    # Remove the length dimension.
                    x = mtf.reshape(x, x_shape)

    x = layer_prepostprocess_dropout(normalize(x))
    assert not layer_norm_vars
    if is_incremental:
        return x, new_states
    else:
        return x
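
# The `normalize` helper inside `_layer_stack` rescales activations by their root-mean-square
# over the model dimension (no mean subtraction), with one learned scale per layer-norm slot.
# A minimal NumPy sketch of that normalization; shapes are toy values for illustration only.
import numpy as np


def _sketch_rms_normalize():
    def rms_normalize(x, scale, eps=1e-6):
        # Matches normalize(x): x * rsqrt(mean(x**2) + eps) * scale, reduced over the last dim.
        variance = np.mean(np.square(x), axis=-1, keepdims=True)
        return x * (1.0 / np.sqrt(variance + eps)) * scale

    batch, length, d_model = 2, 4, 8               # toy sizes (illustrative)
    x = np.random.randn(batch, length, d_model)
    scale = np.ones(d_model)                       # plays the role of one row of layer_norm_scale
    y = rms_normalize(x, scale)
    print(np.mean(np.square(y), axis=-1))          # ~1.0 along the model dimension
    return y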
def __init__(self,
             mesh,
             query_input_dim,
             memory_input_dim,
             output_dim,
             key_dim,
             value_dim,
             query_heads_dims,
             memory_heads_dims,
             variable_dtype,
             shared_kv=False,
             no_query=False,
             combine_dims=True,
             ensemble_dim=None,
             keep_query_heads_dims=False,
             fold_scaling_into_initializer=True,
             context=None,
             experts_hparams=None,
             expert_computation="qkv"):
    super(ExpertsAttentionParams, self).__init__(
        mesh=mesh,
        query_input_dim=query_input_dim,
        memory_input_dim=memory_input_dim,
        output_dim=output_dim,
        key_dim=key_dim,
        value_dim=value_dim,
        query_heads_dims=query_heads_dims,
        memory_heads_dims=memory_heads_dims,
        variable_dtype=variable_dtype,
        shared_kv=shared_kv,
        no_query=no_query,
        combine_dims=combine_dims,
        ensemble_dim=ensemble_dim,
        keep_query_heads_dims=keep_query_heads_dims,
        fold_scaling_into_initializer=fold_scaling_into_initializer,
        make_attention_vars=False)

    self.context = context
    self.expert_computation = expert_computation

    # Unless we want to compute both q and kv, we can use the normal MoE settings.
    if expert_computation == "qkv":
        experts_attention_compute_qkv = True
    elif expert_computation in ["q", "kv"]:
        experts_attention_compute_qkv = False
        if expert_computation == "q":
            # Experts compute q; kv comes from a regular variable.
            # Always assume shared_kv.
            self.wkv = mtf.get_variable(
                self.mesh,
                "kv",
                self.k_shape,
                initializer=tf.random_normal_initializer(
                    stddev=self.memory_input_dim.size ** -0.5),
                dtype=self.variable_dtype)
        else:
            # Experts compute kv; q comes from a regular variable.
            self.wq = mtf.get_variable(
                self.mesh,
                "q",
                self.q_shape,
                initializer=tf.random_normal_initializer(
                    stddev=self.query_input_dim.size ** -0.5),
                dtype=self.variable_dtype)
    else:
        raise ValueError("Invalid expert computation mode: {}".format(
            expert_computation))

    # ExpertsAttention, for simplicity, asserts that combine_dims is True, and
    # for efficiency, that shared_kv is True.
    if not self.combine_dims:
        raise ValueError("combine_dims must be True for ExpertsAttention.")
    if not self.shared_kv:
        raise ValueError("shared_kv must be True for ExpertsAttention.")
    if mtf.layers.unit_scaling_convention():
        raise NotImplementedError

    # Now replace the "heads" dim with the "d_model" name to avoid conflicts when
    # we want to partition both "experts_hidden" and "heads".
    moe_output_dims = mtf.Dimension("d_model", self.q_shape[-1].size)

    tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
        experts_hparams.hidden_size))
    tf.logging.info("moe_output_dims: {}".format(moe_output_dims))
    self.moe_layer = mtf.transformer.moe.MoE1D(
        moe_gating=experts_hparams.moe_gating,
        num_experts=experts_hparams.num_experts,
        loss_coef=experts_hparams.loss_coef,
        group_size=experts_hparams.group_size,
        min_expert_capacity=experts_hparams.min_expert_capacity,
        capacity_factor_train=experts_hparams.capacity_factor_train,
        capacity_factor_eval=experts_hparams.capacity_factor_eval,
        switch_policy_train=experts_hparams.switch_policy_train,
        switch_policy_eval=experts_hparams.switch_policy_eval,
        switch_dropout=experts_hparams.switch_dropout,
        switch_temperature=experts_hparams.switch_temperature,
        switch_jitter=experts_hparams.switch_jitter,
        ntlb_top_k=experts_hparams.ntlb_top_k,
        hidden_size=experts_hparams.hidden_size,
        output_dim=moe_output_dims,
        use_experts_attention=experts_attention_compute_qkv,
        activation=experts_hparams.activation,
        z_loss=experts_hparams.z_loss)
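
# The dense q / kv fallbacks above are initialized with stddev = fan_in ** -0.5. A small
# NumPy check of why that scaling keeps pre-activation variance near 1 for unit-variance
# inputs; sizes are toy values for illustration only.
import numpy as np


def _sketch_fan_in_init_scaling():
    rng = np.random.default_rng(0)
    d_in, d_out, n = 512, 64, 10000                         # toy fan-in / fan-out / samples
    x = rng.standard_normal((n, d_in))                      # unit-variance inputs
    w = rng.standard_normal((d_in, d_out)) * d_in ** -0.5   # stddev = fan_in ** -0.5
    y = x @ w
    print(np.var(y))                                        # close to 1.0
    return y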
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
    """Dot-product attention - doesn't use positional dimensions.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      context: context of the attention layer.
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float.
      dropout_broadcast_dims: an optional list of mtf.Dimension
      extra_logit: an optional scalar or tensor

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias

    query_length_dim = mtf.Dimension("length", memory_length_dim.size)
    doubly_coeff = mtf.get_variable(
        context.mesh, "doubly_coeff", [],
        initializer=tf.constant_initializer(0.5),
        dtype=context.variable_dtype)
    doubly_coeff = mtf.maximum(mtf.minimum(doubly_coeff, 1.), 0.)

    upper_weights = mtf.softmax(
        logits, memory_length_dim, extra_logit=extra_logit)

    lower_log_weights = mtf.log_softmax(
        logits, query_length_dim, extra_logit=extra_logit)
    doubly_weights = mtf.softmax(
        lower_log_weights, memory_length_dim, extra_logit=extra_logit)

    weights = doubly_coeff * doubly_weights + (1. - doubly_coeff) * upper_weights
    weights = mtf.dropout(
        weights, context.train,
        1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)

    outputs_shape = q.shape - key_dim + value_dim
    outputs = mtf.einsum([weights, v], outputs_shape)
    return outputs
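
# `hybrid_attention` blends standard softmax attention with a doubly-normalized variant:
# the logits are log-softmaxed over the query axis, then softmaxed over the memory axis,
# and the two weight matrices are interpolated by a learned scalar clipped to [0, 1].
# A NumPy sketch of that weight computation for a single head; shapes are toy values.
import numpy as np


def _sketch_hybrid_attention_weights():
    def softmax(x, axis):
        x = x - x.max(axis=axis, keepdims=True)
        e = np.exp(x)
        return e / e.sum(axis=axis, keepdims=True)

    def log_softmax(x, axis):
        x = x - x.max(axis=axis, keepdims=True)
        return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

    length, memory_length = 4, 4                      # toy sizes (illustrative)
    logits = np.random.randn(length, memory_length)   # q.k logits for one head
    doubly_coeff = np.clip(0.5, 0.0, 1.0)             # learned scalar in the real layer

    upper = softmax(logits, axis=1)                         # normalize over memory
    doubly = softmax(log_softmax(logits, axis=0), axis=1)   # query axis, then memory axis
    weights = doubly_coeff * doubly + (1 - doubly_coeff) * upper
    print(weights.sum(axis=1))                              # each row still sums to 1
    return weights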
def mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    tf.logging.info("features = %s" % features)
    hparams = self._hparams
    activation_dtype = self.set_activation_type()
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

    # Declare all the dimensions
    batch_dim = mtf.Dimension("batch", hparams.batch_size)
    hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
    filter_h_dim = mtf.Dimension("filter_height", 7)
    filter_w_dim = mtf.Dimension("filter_width", 7)
    filters = mtf.Dimension("filters", hparams.filter_sizes[0])
    rows_dim = mtf.Dimension("rows_size", 32)
    cols_dim = mtf.Dimension("cols_size", 96)
    row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
    col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
    classes_dim = mtf.Dimension("classes", 10)
    one_channel_dim = mtf.Dimension("one_channel", 1)

    inputs = features["inputs"]
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(inputs, [
            hparams.batch_size,
            hparams.row_blocks,
            hparams.rows_size // hparams.row_blocks,
            hparams.col_blocks,
            hparams.num_channels * hparams.cols_size // hparams.col_blocks,
            1
        ]),
        mtf.Shape([
            batch_dim, row_blocks_dim, rows_dim,
            col_blocks_dim, cols_dim, one_channel_dim
        ]))
    x = mtf.transpose(x, [
        batch_dim, row_blocks_dim, col_blocks_dim,
        rows_dim, cols_dim, one_channel_dim
    ])
    x = mtf.to_float(x)

    initial_filters = mtf.get_variable(
        mesh, "init_filters",
        mtf.Shape([filter_h_dim, filter_w_dim, one_channel_dim, filters]))
    x = mtf.conv2d_with_blocks(
        x, initial_filters,
        strides=[1, 1, 1, 1],
        padding="SAME",
        h_blocks_dim=None,
        w_blocks_dim=col_blocks_dim)
    x = batch_norm_relu(x, is_training)

    # Conv blocks: [residual block layer] x n
    for layer in range(hparams.num_layers):
        layer_name = "block_layer_%d" % layer
        with tf.variable_scope(layer_name):
            # Residual block layer
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[0],
                blocks=hparams.layer_sizes[0],
                strides=[1, 1, 1, 1],
                is_training=is_training,
                name="block_layer1",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[1],
                blocks=hparams.layer_sizes[1],
                strides=[1, 2, 2, 1],
                is_training=is_training,
                name="block_layer2",
                row_blocks_dim=None,
                col_blocks_dim=None)
            x = block_layer(
                inputs=x,
                filters=hparams.filter_sizes[2],
                blocks=hparams.layer_sizes[2],
                strides=[1, 2, 2, 1],
                is_training=is_training,
                name="block_layer3",
                row_blocks_dim=None,
                col_blocks_dim=None)

    # Calculate the logits and loss.
    out = x
    outputs = mtf.layers.dense(
        out, hidden_dim,
        reduced_dims=out.shape.dims[-5:],
        activation=mtf.relu,
        name="dense")

    # We assume a fixed vocab size for targets
    labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
    labels = mtf.import_tf_tensor(
        mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

    logits = mtf.layers.dense(outputs, classes_dim, name="logits")
    soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
    loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, classes_dim)

    # Reshape logits so it doesn't break inside t2t.
    logits = mtf.reshape(
        logits, mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
    loss = mtf.reduce_mean(loss)
    return logits, loss
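
# The input reshape at the top of `mtf_model_fn` splits the image into row and column blocks
# so those block dimensions can be laid out across the mesh. A NumPy sketch of the same
# reshape/transpose; the sizes and the [batch, rows, cols*channels] input layout are
# assumptions made for illustration only.
import numpy as np


def _sketch_blocked_reshape():
    batch, row_blocks, col_blocks = 2, 4, 2          # toy values (illustrative)
    rows_size, cols_size, channels = 32, 96, 1

    imgs = np.random.randn(batch, rows_size, cols_size * channels)

    # Split rows and columns into blocks, mirroring the tf.reshape target shape above:
    # [batch, row_blocks, rows/row_blocks, col_blocks, cols*channels/col_blocks, 1]
    x = imgs.reshape(batch, row_blocks, rows_size // row_blocks,
                     col_blocks, cols_size * channels // col_blocks, 1)

    # Bring the block dims next to each other, as mtf.transpose does:
    # [batch, row_blocks, col_blocks, rows, cols, one_channel]
    x = x.transpose(0, 1, 3, 2, 4, 5)
    print(x.shape)  # (2, 4, 2, 8, 48, 1)
    return x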