def benchmark_model(mesh):
    """Initializes a 3D volume with random noise and executes a forward FFT."""
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    x_dim = mtf.Dimension("nx", FLAGS.cube_size)
    y_dim = mtf.Dimension("ny", FLAGS.cube_size)
    z_dim = mtf.Dimension("nz", FLAGS.cube_size)

    tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
    ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
    tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

    # Create field
    field = mtf.random_uniform(mesh, [batch_dim, x_dim, y_dim, z_dim])

    # Apply FFT
    fft_field = mpm.fft3d(mtf.cast(field, tf.complex64), [tx_dim, ty_dim, tz_dim])

    # Inverse FFT
    rfield = mtf.cast(mpm.ifft3d(fft_field, [x_dim, y_dim, z_dim]), tf.float32)

    # Compute errors
    err = mtf.reduce_max(mtf.abs(field - rfield))
    return err
def benchmark_model(mesh):
    """Initializes a 3D volume with random noise and executes forward and inverse FFTs."""
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    x_dim = mtf.Dimension("nx", FLAGS.cube_size)
    y_dim = mtf.Dimension("ny", FLAGS.cube_size)
    z_dim = mtf.Dimension("nz", FLAGS.cube_size)

    tx_dim = mtf.Dimension("tnx", FLAGS.cube_size)
    ty_dim = mtf.Dimension("tny", FLAGS.cube_size)
    tz_dim = mtf.Dimension("tnz", FLAGS.cube_size)

    # Create field
    field = mtf.random_normal(mesh, [batch_dim, x_dim, y_dim, z_dim])

    input_field = field
    field = mtf.cast(field, tf.complex64)
    err = 0
    # Performs several back and forth FFTs in the same session
    for i in range(FLAGS.n_ffts):
        # Apply FFT
        fft_field = mpm.fft3d(field, [tx_dim, ty_dim, tz_dim])
        # Inverse FFT
        field = mpm.ifft3d(fft_field * 1, [x_dim, y_dim, z_dim])
        err += mtf.reduce_max(mtf.abs(mtf.cast(field, tf.float32) - input_field))

    field = mtf.cast(field, tf.float32)
    # Compute errors
    err += mtf.reduce_max(mtf.abs(field - input_field))
    return err
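The two benchmark variants above only build a Mesh TensorFlow graph. A minimal sketch of lowering and running that graph on a single-device placement mesh might look like the following; the helper name, mesh shape, and layout strings are assumptions for illustration, not part of the source.

# Hedged sketch: lower the benchmark graph onto one device and evaluate the error.
def run_benchmark_locally():
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")
    err = benchmark_model(mesh)

    # Hypothetical single-device mesh; layout strings are placeholders.
    mesh_shape = mtf.convert_to_shape("all:1")
    layout_rules = mtf.convert_to_layout_rules("nx:all,tny:all")
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        mesh_shape, layout_rules, [""] * mesh_shape.size)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_err = lowering.export_to_tf_tensor(err)
    with tf.Session() as sess:
        print("max round-trip error:", sess.run(tf_err))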
def compute_mask(self, context, memory_position):
    """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.

    Returns:
      a Tensor or None
    """
    masks = []
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    if max_relative_position is not None or min_relative_position is not None:
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            illegal = mtf.less(relative_position, min_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
        if max_relative_position is not None:
            illegal = mtf.greater(relative_position, max_relative_position)
            masks.append(mtf.cast(illegal, context.activation_dtype) * -1e9)
    if (context.sequence_id is not None
            and isinstance(context.sequence_id, mtf.Tensor)
            and context.length_dim in context.sequence_id.shape):
        masks.append(mtf.cast(
            mtf.not_equal(
                context.sequence_id,
                self.rename_length_to_memory_length(
                    context.sequence_id, context)),
            context.activation_dtype) * -1e9)
    return mtf.add_n(masks) if masks else None
def get_project_to_cluster_length(self, cluster_mask, dtype):
    """Returns projection from length dim to the shorter cluster length dim."""
    seq_length_dim = cluster_mask.shape.get_dim_by_name("length")
    cluster_length_dim = self.get_cluster_length_dim(seq_length_dim)
    return mtf.cast(cluster_mask, dtype) * mtf.one_hot(
        mtf.cumsum(mtf.cast(cluster_mask, tf.int32), seq_length_dim) - 1,
        output_dim=cluster_length_dim,
        dtype=dtype)
def deep(x, mask, float16=None):
    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    # Following MindSpore's approach, use fp16 to compute the dense layers below.
    x = mtf.cast(x, dtype=tf.float16)
    x = mtf.layers.dense(x, mtf.Dimension(name='dense_dim0', size=1024),
                         name="deep-dense-0",
                         reduced_dims=x.shape.dims[-2:],
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, mtf.Dimension(name='dense_dim1', size=512),
                         name="deep-dense-1",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, mtf.Dimension(name='dense_dim2', size=256),
                         name="deep-dense-2",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, mtf.Dimension(name='dense_dim3', size=128),
                         name="deep-dense-3",
                         activation=mtf.relu,
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = mtf.layers.dense(x, mtf.Dimension(name='dense_dim4', size=1),
                         name="deep-dense-4",
                         variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16))
    logger.debug("[output tensor] (name,shape):({},{})".format(x.name, x.shape))
    if not float16:
        # Cast back to float32 when fp16 outputs were not requested.
        x = mtf.cast(x, dtype=tf.float32)
    return x
def model_backbone(features, labels, mesh):
    """The model.

    Args:
      features: a pair of tf.Tensors (id_hldr, wt_hldr), each with shape [batch, field]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor of model outputs
      loss: a mtf.Tensor with shape []
    """
    id_hldr, wt_hldr = features

    batch_dim = mtf.Dimension("batch", args_opt.batch_size)
    field_dim = mtf.Dimension("field", size=39)
    vocab_dim = mtf.Dimension("vocab_size", 200000)
    embed_dim = mtf.Dimension("embed_size", 80)
    outdim = mtf.Dimension("outdim", 1)
    id_hldr = mtf.import_tf_tensor(
        mesh,
        tf.reshape(id_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    wt_hldr = mtf.import_tf_tensor(
        mesh,
        tf.reshape(wt_hldr, [args_opt.batch_size, field_dim.size]),
        mtf.Shape([batch_dim, field_dim]))
    if args_opt.fp16:
        float16 = mtf.VariableDType(tf.float16, tf.float16, tf.float16)
        # id_hldr = mtf.cast(id_hldr, dtype=tf.int32)
        wt_hldr = mtf.cast(wt_hldr, dtype=tf.float16)
    else:
        float16 = None

    logits, embedding_table = network[args_opt.model](
        id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=float16)
    logits = mtf.cast(logits, dtype=tf.float32)
    embedding_table = mtf.cast(embedding_table, dtype=tf.float32)
    if labels is None:
        wide_loss = None
        deep_loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [args_opt.batch_size]), mtf.Shape([batch_dim]))
        wide_loss = mtf.layers.sigmoid_cross_entropy_with_logits(logits, labels)
        deep_loss = mtf.reduce_mean(mtf.square(embedding_table)) / 2
        deep_loss = mtf.reduce_mean(wide_loss) + 8e-5 * deep_loss
        wide_loss = mtf.reduce_mean(wide_loss)

    return logits, wide_loss + deep_loss
def call(self, context, x, losses=None):
    """Call the layer."""
    wq, wk, wv, wo = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    memory_length = mtf.Dimension("memory_length", context.length_dim.size)
    q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
    if context.mode == "incremental":
        m = x
    else:
        m = mtf.rename_dimension(x, context.length_dim.name, "memory_length")
    k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
    v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
    if context.mode == "incremental":
        old_k, old_v = context.get_states(2)
        one_hot = mtf.one_hot(
            context.position, memory_length, dtype=context.activation_dtype)
        inv_one_hot = 1.0 - one_hot
        k = old_k * inv_one_hot + k * one_hot
        v = old_v * inv_one_hot + v * one_hot
    if context.mode == "incremental" or context.mode == "first_part":
        context.record_new_states([k, v])
    masks = []
    if context.autoregressive:
        masks.append(
            mtf.cast(
                mtf.less(
                    context.position,
                    mtf.range(context.mesh, memory_length, dtype=tf.int32)),
                context.activation_dtype) * -1e9)
    if (context.sequence_id is not None
            and isinstance(context.sequence_id, mtf.Tensor)
            and context.length_dim in context.sequence_id.shape):
        masks.append(
            mtf.cast(
                mtf.not_equal(
                    context.sequence_id,
                    mtf.layers.rename_length_to_memory_length(
                        context.sequence_id)),
                context.activation_dtype) * -1e9)
    mask = mtf.add_n(masks) if masks else None
    o = mtf.layers.dot_product_attention_v2(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        self.dropout_rate if context.train else 0.0, [context.length_dim])
    return mtf.einsum(
        [o, wo], x.shape, reduced_dims=[self.heads_dim, self.kv_dim])
def add_position_timing_signal_func(self, context, x, step):
    """Add n-dimensional embedding as the position (horizontal) timing signal.

    Args:
      context: mtf context
      x: a tensor with shape [batch, length, depth]
      step: step

    Returns:
      a Tensor with the same shape as x.
    """
    if not self.position_start_index:
        index = 0
    elif self.position_start_index == "random":
        # Shift all positions randomly
        # TODO(dehghani): What would be reasonable for max number of shift?
        index = mtf.random_uniform(
            context.mesh, [], maxval=x.shape.dims[1].size, dtype=tf.int32)
    elif self.position_start_index == "step":
        # Shift positions based on the step
        if self.recurrence_type == "act":
            num_steps = self.act_max_steps
        else:
            num_steps = self.num_rec_steps
        index = mtf.cast(
            x.shape.dims[1].size * step / num_steps, dtype=tf.int32)

    length = context.length_dim
    channels = context.model.model_dim
    signal = self.get_timing_signal_1d(
        context, length, channels, start_index=index)

    if self.add_or_concat_timing_signal == "add":
        x_with_timing = x + mtf.cast(signal, x.dtype)
    # Unimplemented
    if self.add_or_concat_timing_signal == "concat":
        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        x_with_timing = mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])

    return x_with_timing
def sample_categorical(x, dim=None):
    dim = x.shape[-1] if dim is None else dim
    cdf = mtf.cumsum(x, dim)
    rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1)
    mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32)
    return mtf.argmax(mask, dim)
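sample_categorical draws one index per distribution by inverse-CDF sampling along dim, so feeding it a softmax over a vocabulary dimension yields token ids. A hedged usage sketch, assuming logits and vocab_dim already exist in the caller's scope:

# Hypothetical usage: one sampled token id per batch/length position.
probs = mtf.softmax(logits, reduced_dim=vocab_dim)   # normalize along vocab_dim
token_ids = sample_categorical(probs, vocab_dim)     # int32; vocab_dim is reduced away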
def mnist_model(image, labels, mesh):
    """The model.

    Args:
      image: tf.Tensor with shape [batch, image_height*image_width*num_channels]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, classesnum]
      loss: a mtf.Tensor with shape []
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows_size", image_height)
    cols_dim = mtf.Dimension("cols_size", image_width)
    channel_dim = mtf.Dimension("image_channel", num_channels)
    classes_dim = mtf.Dimension(name='classesnum', size=classesnum)
    x = mtf.import_tf_tensor(
        mesh,
        tf.reshape(image, [FLAGS.batch_size, image_height, image_width, num_channels]),
        mtf.Shape([batch_dim, rows_dim, cols_dim, channel_dim]))
    # x = mtf.transpose(x, [batch_dim, rows_dim, cols_dim, channel_dim])
    # print(x.shape)
    logits = VGG(x, classes_dim=classes_dim, depth=depth)
    logits = mtf.cast(logits, dtype=tf.float32)

    if labels is None:
        loss = None
    else:
        labels = mtf.import_tf_tensor(
            mesh, tf.reshape(labels, [FLAGS.batch_size]), mtf.Shape([batch_dim]))
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def c2r3d(cfield, dims, norm=None, dtype=tf.float32, name=None):
    """
    Converts a complex Fourier-domain field to a real field

    Parameters:
    -----------
    cfield: tensor (batch_size, nc, nc, nc)
      Input 3D complex field

    norm: float
      Normalization factor

    dtype: tf.dtype
      Type of output tensor

    Return:
    -------
    rfield: tensor (batch_size, nc, nc, nc)
      Real valued field
    """
    x_dim, y_dim, z_dim = cfield.shape[-3:]
    if norm is None:
        norm = mtf.constant(cfield.mesh, x_dim.size * y_dim.size * z_dim.size)
    rfield = mtf.cast(mesh_ops.ifft3d(cfield, dims), dtype) * norm
    return rfield
def r2c3d(rfield, k_dims, norm=None, dtype=tf.complex64):
    """
    Converts a real field to its complex Fourier transform

    Parameters:
    -----------
    rfield: tensor (batch_size, nc, nc, nc)
      Input 3D real field

    norm: float
      Normalization factor

    dtype: tf.dtype
      Type of output tensor

    Return:
    -------
    cfield: tensor (batch_size, nc, nc, nc)
      Complex field
    """
    x_dim, y_dim, z_dim = rfield.shape[-3:]
    if norm is None:
        norm = mtf.constant(rfield.mesh, x_dim.size * y_dim.size * z_dim.size)
    cfield = mesh_ops.fft3d(mtf.cast(rfield / norm, dtype), k_dims)
    return cfield
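Because r2c3d divides by the element count before the forward FFT and c2r3d multiplies it back after the inverse FFT, the two helpers round-trip up to FFT precision. A small sketch under that assumption; the tensor and dimension names are illustrative, matching the benchmark above:

# Hypothetical round-trip check: real field -> k-space -> real field.
cfield = r2c3d(rfield, [tx_dim, ty_dim, tz_dim])   # forward FFT of the normalized field
rfield_back = c2r3d(cfield, [x_dim, y_dim, z_dim]) # inverse FFT, normalization undone
roundtrip_err = mtf.reduce_max(mtf.abs(rfield - rfield_back))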
def forward(self, features, return_loss=True, return_logits=False):
    inputs = features["tokens"]
    tokens = self.positional_embedding(
        self.embedding(inputs, "embedding"), "positional_embedding")

    mask = self.get_attn_mask(
        tokens.mesh, tokens.shape[1], self.dimensions["memory_len_dim"])
    out = self.transformer(tokens, mask=mask)
    logits = self.to_logits(out)
    if not return_loss:
        return logits

    labels = pad(inputs, [0, 1], dim_name="total_seq_dim",
                 pad_value=self.eos_token_id)
    indices = mtf.range(labels.mesh,
                        mtf.Dimension("range", labels.shape[1].size - 1),
                        tf.int32, name="labels_indices") + 1
    labels = mtf.gather(labels, indices, dim=labels.shape[1])
    labels = mtf.rename_dimension(labels, "range", "total_seq_dim")
    loss, loss_batch = self._loss(logits, labels)
    if return_logits and return_loss:
        # Cast back to checkpoint dtype
        logits = mtf.cast(logits, self.variable_dtype.master_dtype)
        return loss, loss_batch, logits
    return loss, loss_batch
def to_logits(self, x):
    with tf.variable_scope("to_logits"):
        logits = self.linear(
            self.layer_norm(x), self.dimensions["final_vocab_dim"],
            name="linear_out")
        # Go to full precision for the logits
        return mtf.cast(logits, tf.float32)
def toy_model(features, mesh):
    """A toy model implemented by mesh tensorflow."""
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
    io_dim = mtf.Dimension('io', FLAGS.io_size)

    master_dtype = tf.as_dtype(FLAGS.master_dtype)
    slice_dtype = tf.as_dtype(FLAGS.slice_dtype)
    activation_dtype = tf.as_dtype(FLAGS.activation_dtype)

    x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
    x = mtf.cast(x, activation_dtype)
    h = x
    for lnum in xrange(1, FLAGS.num_hidden_layers + 2):
        if lnum + 1 == FLAGS.num_hidden_layers + 2:
            # output layer
            dim = io_dim
        elif lnum % 2 == 0:
            dim = mtf.Dimension('hidden_even', FLAGS.hidden_size)
        else:
            dim = mtf.Dimension('hidden_odd', FLAGS.hidden_size)
        h = mtf.layers.dense(
            h, dim,
            use_bias=False,
            master_dtype=master_dtype,
            slice_dtype=slice_dtype,
            name='layer_%d' % lnum)
    y = h
    loss = mtf.reduce_mean(mtf.square(y - x))
    return y, loss
def sample(self, features, mesh):
    hparams = self._hparams
    model = self.model()

    def import_feature(key):
        return self._import_feature(features, mesh, key)

    if self.autoregressive:
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = import_feature("inputs")
        if partial_targets is None:
            partial_targets = import_feature("targets")
        if partial_targets:
            partial_targets *= mtf.cast(
                mtf.not_equal(partial_targets, 1), partial_targets.dtype)
        else:
            ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
            partial_targets = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
        if hparams.beam_size > 1:
            raise NotImplementedError(
                "Beam search not implemented for unitransformer.")
        ret = model.sample_autoregressive(
            partial_targets,
            temperature=hparams.sampling_temp,
            variable_dtype=self.variable_dtype)
        return self.combine_batch_dims(ret)
    else:
        raise ValueError(
            "Don't know how to sample from non-autoregressive unitransformer")
def call(self, context, x):
    """Call the layer stack."""
    if isinstance(context.sequence_id, mtf.Tensor):
        # We use this mask to zero out the padding regions at each layer.
        # This "fixes" a bug where extreme values leak from the padding into the
        # non-padding regions.
        # TODO(noam): understand this better and make a more principled fix.
        mask = mtf.cast(
            mtf.not_equal(context.sequence_id, 0), context.activation_dtype)
    else:
        mask = None
    x = self._dropout(context, x)
    context.layer_outputs.append(x)

    if self.mix_with_transformer_before_ut:
        for _ in range(self.num_vanilla_transformer_layers):
            x = self.vanilla_transformer_layer(context, x, mask)

    # Call an ACT layer
    if self.recurrence_type == "act":
        x = self.act_layer(context, x, mask)
    elif self.recurrence_type == "basic":
        x = self.ut_basic(context, x, mask)
    elif self.recurrence_type == "highway":
        layer_inputs = (x, x, x)
        x = self.ut_highway(context, layer_inputs, mask)

    if self.mix_with_transformer_after_ut:
        for _ in range(self.num_vanilla_transformer_layers):
            x = self.vanilla_transformer_layer(context, x, mask)

    x = self._layer_norm(context, x, name="final_layer_norm")
    x = self._dropout(context, x)
    if mask:
        x *= mask
    context.layer_outputs.append(x)
    return x
def add_step_timing_signal_func(self, context, x, step):
    """Add n-dimensional embedding as the step (vertical) timing signal.

    Args:
      context: mtf context
      x: a tensor with shape [batch, length, depth]
      step: step

    Returns:
      a Tensor with the same shape as x.
    """
    if self.recurrence_type == "act":
        num_steps = self.act_max_steps
    else:
        num_steps = self.num_rec_steps
    channels = x.shape.dims[-1]

    if self.step_timing_signal_type == "learned":
        signal = self.get_layer_timing_signal_learned_1d(
            context, channels, step, num_steps)
    elif self.step_timing_signal_type == "sinusoid":
        signal = self.get_layer_timing_signal_sinusoid_1d(
            context, channels, step, num_steps)
    if self.add_or_concat_timing_signal == "add":
        x_with_timing = x + mtf.cast(signal, x.dtype)
    elif self.add_or_concat_timing_signal == "concat":
        batch_dim = x.shape.dims[0]
        out_shape = mtf.Shape([batch_dim] + x.shape.dims[1:])
        signal_tiled = mtf.broadcast(signal, out_shape)
        x_with_timing = mtf.concat(
            (x, signal_tiled),
            concat_dim_name=signal_tiled.dimension_names[-1])

    return x_with_timing
def call(self, context, x):
    """Call the layer stack."""
    if isinstance(context.sequence_id, mtf.Tensor):
        # We use this mask to zero out the padding regions at each layer.
        # This "fixes" a bug where extreme values leak from the padding into the
        # non-padding regions.
        # TODO(noam): understand this better and make a more principled fix.
        mask = mtf.cast(
            mtf.not_equal(context.sequence_id, 0), context.activation_dtype)
    else:
        mask = None
    x = self._dropout(context, x)
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)
    for lnum, layer in enumerate(self._layers):
        with tf.variable_scope("layer_%03d" % lnum):
            norm_x = self._layer_norm(context, (x * mask) if mask else x)
            with tf.variable_scope(layer.__class__.__name__):
                y = layer.call(context, norm_x)
                if y.shape != x.shape:
                    raise ValueError(
                        "Layer %s returned misshaped output x=%s y=%s"
                        % (layer.__class__.__name__, x, y))
            x += self._dropout(context, y)
        if context.layer_outputs is not None and lnum != len(self._layers) - 1:
            context.layer_outputs.append(x)
        context.layer_index += 1
    x = self._layer_norm(context, x, name="final_layer_norm")
    x = self._dropout(context, x)
    if mask:
        x *= mask
    if context.layer_outputs is not None:
        context.layer_outputs.append(x)
    return x
def call(self, context, x, losses=None):
    """Call the layer."""
    params = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    if context.mode == "incremental":
        prev_k, prev_v = context.get_states(2)
        y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
            x, prev_k, prev_v, context.position, params=params)
        context.record_new_states([new_k, new_v])
        return y
    else:
        kv = []
        y = mtf.layers.masked_local_attention_1d(
            x, self.kv_dim, self.heads_dim, self.window_size,
            params=params, return_kv=kv)
        if context.mode == "first_part":
            k = kv[0]
            v = kv[1]
            window_dim = mtf.Dimension("window", self.window_size)
            mesh = k.mesh
            window_pos = mtf.range(mesh, window_dim, tf.int32)
            pos = mtf.range(mesh, context.length_dim, tf.int32)
            select_recent = mtf.cast(
                mtf.equal(window_pos, mtf.mod(pos, self.window_size)), k.dtype)
            select_recent *= mtf.cast(
                mtf.less(pos, context.initial_position), k.dtype)
            select_recent *= mtf.cast(
                mtf.greater_equal(
                    pos, context.initial_position - self.window_size), k.dtype)
            state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
            k_state = mtf.einsum(
                [k, select_recent], output_shape=state_shape,
                reduced_dims=[context.length_dim])
            v_state = mtf.einsum(
                [v, select_recent], output_shape=state_shape,
                reduced_dims=[context.length_dim])
            context.new_states.extend([k_state, v_state])
        return y
def get_attn_mask(self, mesh, nd, ns):
    if not exists(self.attn_mask):
        i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
        j = mtf.range(mesh, ns, tf.int32)
        i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
        self.attn_mask = mtf.cast(
            mtf.less(i, j), self.variable_dtype.activation_dtype) * -1e10
    return self.attn_mask
def nonpadding(self):
    """Tensor with zeros in padding positions and ones elsewhere."""
    if self.sequence_id is None:
        return None
    if self.sequence_id == 1:
        return 1
    else:
        return mtf.cast(
            mtf.not_equal(self.sequence_id, 0), self.activation_dtype)
def compute_loss(self, decoder: transformer.Unitransformer,
                 hidden: mtf.Tensor, targets: mtf.Tensor,
                 context: transformer.Context) -> mtf.Tensor:
    """Returns the loss without computing a softmax over the entire vocab."""
    loss = 0
    tail_cluster_masks = []
    for cluster in self._tail_clusters:
        cluster_mask = cluster.get_cluster_mask(targets)
        tail_cluster_masks.append(cluster_mask)
        if cluster.length_projection_factor == 1:
            targets_in_cluster = mtf.where(cluster_mask, targets, 0)
            hidden_in_cluster = mtf.where(cluster_mask, hidden, 0)
        else:
            # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
            # to reduce the risk of overflowing the projection.
            proj_to_cluster_len = cluster.get_project_to_cluster_length(
                cluster_mask, dtype=targets.dtype)
            targets_in_cluster = mtf.einsum(
                [proj_to_cluster_len, targets],
                reduced_dims=[targets.shape.get_dim_by_name("length")])
            hidden_in_cluster = mtf.einsum(
                [mtf.cast(proj_to_cluster_len, hidden.dtype), hidden],
                reduced_dims=[hidden.shape.get_dim_by_name("length")])
        loss += cluster.compute_loss(
            decoder, hidden_in_cluster, targets_in_cluster, context)

    tail_clusters_dim = mtf.Dimension("tail_clusters", len(tail_cluster_masks))
    tail_node_targets = mtf.reduce_sum(
        mtf.stack([(self._head_cluster.end_token_id + i) *
                   mtf.cast(mask, targets.dtype)
                   for i, mask in enumerate(tail_cluster_masks)],
                  tail_clusters_dim.name),
        reduced_dim=tail_clusters_dim)
    head_targets = mtf.where(
        mtf.cast(tail_node_targets, tf.bool), tail_node_targets, targets)
    loss += self._head_cluster.compute_loss(decoder, hidden, head_targets, context)
    return loss
def biasmask_attn_weights(mesh, nd, ns, variable_dtype):
    # The old mask_attn_weights applied directly to the QK; this returns a bias
    # that the attention code from mtf adds to the attention matrix.
    # w has shape [batch, heads, dst_sequence, src_sequence], where information
    # flows from src to dst.
    # n_src and n_dest are both the same, i.e. equal to sequence length.
    # We rename ns because we want bias to have shape
    # [batch, heads, memory_length, sequence] to match up with QK^T.
    # Information flows from k and v (memory_length) to q (sequence).
    i = mtf.range(mesh, nd, tf.int32) + ns.size - nd.size
    j = mtf.range(mesh, ns, tf.int32)
    i, j = map(lambda t: mtf.broadcast(t, [nd, ns]), (i, j))
    dtype = variable_dtype.activation_dtype
    return mtf.cast(mtf.less(i, j), dtype) * -1e10
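For context, a hedged sketch of how such an additive bias is typically consumed: it is broadcast against the attention logits and added before the softmax. The tensor and dimension names below are assumptions, not this repo's API:

# Illustrative only: q is [batch, heads, sequence, kv]; k and v are
# [batch, heads, memory_length, kv]; all names are hypothetical.
attn_logits = mtf.einsum([q, k], reduced_dims=[kv_dim])
attn_logits += biasmask_attn_weights(mesh, sequence_dim, memory_length_dim, variable_dtype)
attn_weights = mtf.softmax(attn_logits, reduced_dim=memory_length_dim)
attended = mtf.einsum([attn_weights, v], reduced_dims=[memory_length_dim])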
def compute_mask(self, context, memory_position):
    """Compute attention mask.

    Args:
      context: a transformer.Context
      memory_position: an int32 tensor containing memory_length dimension.

    Returns:
      a Tensor or None
    """
    masks = []
    min_relative_position = self.min_relative_position(context)
    max_relative_position = self.max_relative_position(context)
    if max_relative_position is not None or min_relative_position is not None:
        relative_position = memory_position - context.position
        if min_relative_position is not None:
            illegal = mtf.less(relative_position, min_relative_position)
            masks.append(
                mtf.cast(illegal, context.activation_dtype) * -1e9)
        if max_relative_position is not None:
            illegal = mtf.greater(relative_position, max_relative_position)
            masks.append(
                mtf.cast(illegal, context.activation_dtype) * -1e9)

    sequence_id = None
    # Subsequence id should only be set if we are in the decoder and have
    # multiple targets per input. This will allow each sub-target to only
    # attend to itself.
    if isinstance(context.subsequence_id, mtf.Tensor):
        sequence_id = context.subsequence_id
    elif isinstance(context.sequence_id, mtf.Tensor):
        sequence_id = context.sequence_id
    if (sequence_id is not None and context.length_dim in sequence_id.shape):
        masks.append(
            mtf.cast(
                mtf.not_equal(
                    sequence_id,
                    self.rename_length_to_memory_length(sequence_id, context)),
                context.activation_dtype) * -1e9)
    return mtf.add_n(masks) if masks else None
def visibility_mask_to_attention_bias(visible, dtype):
    """Convert a boolean visibility mask to an attention bias.

    The returned Tensor has large negative values in positions where
    visible=False.

    Args:
      visible: a boolean Tensor
      dtype: a dtype

    Returns:
      a Tensor with the given dtype and the same shape as "visible"
    """
    return mtf.cast(mtf.logical_not(visible), dtype) * -1e9
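A short usage sketch with assumed names: build a boolean visibility pattern with comparison ops, then convert it into an additive bias for the attention logits.

# Hypothetical causal visibility: a query may attend to keys at or before its
# own position. query_position and memory_position are assumed int32 tensors
# over the length and memory_length dimensions.
visible = mtf.greater_equal(query_position, memory_position)
bias = visibility_mask_to_attention_bias(visible, activation_dtype)  # 0 where visible, -1e9 elsewhere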
def resnet34(x, classes_dim, float16=None, batch_norm=False):
    if float16:
        x = mtf.cast(x, dtype=tf.float16)
    logger.debug("[input tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = backbone(x, layerlist=[3, 4, 6, 3], chalist=[64, 128, 256, 512],
                 strilist=[1, 2, 2, 2], classes_dim=classes_dim,
                 blocklist=[BasicBlockWithDown, BasicBlock],
                 float16=float16, batch_norm=batch_norm)
    return x
def resnet152(x, classes_dim, float16=None, batch_norm=False):
    if float16:
        x = mtf.cast(x, dtype=tf.float16)
    logger.debug("[input tensor] (name,shape):({},{})".format(x.name, x.shape))
    x = backbone(x, layerlist=[3, 8, 36, 3], chalist=[256, 512, 1024, 2048],
                 strilist=[1, 2, 2, 2], classes_dim=classes_dim,
                 blocklist=[ResidualBlockWithDown, ResidualBlock],
                 float16=float16, batch_norm=batch_norm)
    return x
def _loss(self, logits, labels):
    with tf.variable_scope("loss_final"):
        loss_batch = self.loss_fn(
            logits=logits, targets=labels,
            vocab_dim=logits.shape[-1], z_loss=0.0)
    with tf.variable_scope("reduce_mean_final"):
        loss = mtf.reduce_mean(loss_batch)

    loss /= self.params.get("num_microbatches", 1)
    # Convert to train dtype
    loss = mtf.cast(loss, self.variable_dtype.slice_dtype)
    return loss, loss_batch  # loss_batch must be returned for metric fns
def moe(self, x, layout, mesh_shape, input_mask, is_training):
    """Mixture of experts layer.

    TODO(noam): clean up the mixture-of-experts code in Transformer.

    Args:
      x: layer input
      layout: a mtf.LayoutRules
      mesh_shape: a mtf.Shape
      input_mask: a mtf.Tensor
      is_training: a boolean

    Returns:
      a mtf.Tensor (the layer output)
    """
    hparams = moe.HParams(
        moe_gating="top_2",
        moe_num_experts=self.config.moe_num_experts,
        moe_loss_coef=1e-3,
        moe_hidden_size=self.config.moe_intermediate_size,
        moe_group_size=2048,
        moe_capacity_factor_train=1.25,
        moe_capacity_factor_eval=8.0,
        moe_use_second_place_loss=False,
        moe_second_policy_train="random",
        moe_second_policy_eval="random",
        moe_second_threshold_train=0.2,
        moe_second_threshold_eval=0.2,
        moe_dropout_rate=0.0,
        moe_use_experts_attention=False,
        moe_min_expert_capacity=4)
    layer_output, loss = moe.transformer_moe_layer_v1(
        inputs=x,
        output_dim=self.model_dim,
        hparams=hparams,
        train=is_training,
        variable_dtype=tf.float32,
        layout=layout,
        mesh_shape=mesh_shape,
        nonpadding=(mtf.cast(input_mask, tf.float32) if input_mask else None),
        activation=get_activation(self.config.feedforward_intermediate_act))
    self._extra_losses.append(loss)
    return layer_output
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
    if noising_spec["type"] == "mask":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
        return targets * mtf.cast(
            mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                        noising_spec["prob"]), targets.dtype)
    elif noising_spec["type"] == "random_zipfian":
        # Replace a randomly-chosen noising_spec["prob"] of input tokens.
        # Rather than drawing the replacement tokens uniformly, we sample from
        # a distribution favoring lower token-ids, assuming that the ids have
        # been assigned in frequency order. The probability of choosing an
        # id is proportional to 1/(id+10).
        logits = mtf.log(1.0 / (mtf.range(
            targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
        logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
        r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
        use_noise = mtf.less(
            mtf.random_uniform(targets.mesh, targets.shape),
            noising_spec["prob"])
        return mtf.where(use_noise, r, targets)
    elif noising_spec["type"] == "transformer":
        # Train a small transformer to fill in masked out values, then
        # sample from it.
        hparams = self._hparams
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            raise NotImplementedError("Not implemented")
        noiser_hparams = copy.copy(self._hparams)
        noiser_hparams.del_hparam("mode")
        noiser_hparams.override_from_dict(noising_spec["overrides"])
        with tf.variable_scope("noiser"):
            noiser = MtfTransformer(
                noiser_hparams,
                mode=hparams.mode,
                problem_hparams=self._problem_hparams)
            logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                self._original_features, targets.mesh)
            samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
            losses.append(loss)
        return samples
    else:
        raise ValueError("unknown noising spec %s" % noising_spec)
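A hedged call-site example: the noising spec is a plain dict keyed by "type"; the probability below is an illustrative value, not one taken from the source.

# Hypothetical: zero out roughly 15% of target tokens before training.
noisy_targets = self._noisy_targets_from_spec(
    targets, {"type": "mask", "prob": 0.15})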
def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if hparams.transformer_type == "encdec":
        inputs = features["inputs"]
        while len(inputs.shape.as_list()) > 2:
            inputs = tf.squeeze(inputs, axis=2)
        actual_batch_size = tf.shape(inputs)[0]
        actual_length = tf.shape(inputs)[1]
        inputs = tf.pad(
            inputs, [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
        inputs = self._import_to_batch_by_length(
            inputs, "inputs", mesh, hparams)
        x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
             mtf.reshape(positional_embedding_var,
                         mtf.Shape([self.length_dim, self.model_dim])))
        encoder_attention_mask = (
            mtf.layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        with tf.variable_scope("encoder"):
            x = self._layer_stack(
                x, hparams.encoder_layers,
                self_attention_mask=encoder_attention_mask)
        encoder_output = mtf.rename_dimension(
            x, self.length_dim.name, self.memory_length_dim.name)
        encdec_tensors = []
        for layer_num, layer_type in enumerate(hparams.decoder_layers):
            if layer_type == "enc_att":
                with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
                    q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
                        mesh, self.heads_dim, self.model_dim,
                        self.kv_dim, self.master_dtype, self.slice_dtype,
                        self.activation_dtype)
                    k = mtf.einsum(
                        [encoder_output, k_var],
                        mtf.Shape(
                            self.batch_dims +
                            [self.heads_dim, self.memory_length_dim, self.kv_dim]))
                    v = mtf.einsum(
                        [encoder_output, v_var],
                        mtf.Shape(
                            self.batch_dims +
                            [self.heads_dim, self.memory_length_dim, self.kv_dim]))
                encdec_tensors.append((q_var, o_var, k, v))
            else:
                encdec_tensors.append(None)
        partial_targets = None
    elif hparams.transformer_type == "decoder":
        encdec_tensors = None
        encoder_output = None
        encoder_attention_mask = None
        # Prepare partial targets.
        # In either features["inputs"] or features["targets"].
        # We force the outputs to begin with these sequences.
        partial_targets = features.get("inputs", None)
        if partial_targets is None:
            partial_targets = features.get("targets", None)
        if partial_targets is not None:
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int32(partial_targets)
            partial_targets_batch = tf.shape(partial_targets)[0]
            partial_targets_length = tf.shape(partial_targets)[1]
            partial_targets = tf.pad(
                partial_targets,
                [[0, hparams.batch_size - partial_targets_batch],
                 [0, hparams.max_length - partial_targets_length]])
            partial_targets = self._import_to_batch_by_length(
                partial_targets, "partial_targets", mesh, hparams)
    else:
        raise ValueError(
            "hparams.model_type = %s not yet supported"
            % hparams.transformer_type)

    local_attention_window = mtf.Dimension(
        "local_attention_window", hparams.local_attention_window_size)
    if hparams.beam_size == 1:
        ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
        kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(
            self.batch_dims +
            [self.heads_dim, local_attention_window, self.kv_dim])
    else:
        beam_dim = mtf.Dimension("beam", hparams.beam_size)
        ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
        kv_shape = mtf.Shape(
            self.batch_dims +
            [beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim])
        local_kv_shape = mtf.Shape(
            self.batch_dims +
            [beam_dim, self.heads_dim, local_attention_window, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_states = []
    for layer in hparams.decoder_layers:
        if layer == "att":
            initial_states.extend(
                [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
        elif layer == "local_att":
            initial_states.extend(
                [mtf.zeros(mesh, local_kv_shape,
                           dtype=self.activation_dtype)] * 2)

    def logits_fn(step_num, ids, states):
        """Produce logits for this step, and new states."""
        ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
        x = (mtf.gather(targets_embedding_var, ids_this_step,
                        self.targets_vocab_dim) +
             mtf.gather(positional_embedding_var, step_num,
                        self.max_length_dim))
        with tf.variable_scope("decoder"):
            x, new_states = self._layer_stack(
                x,
                hparams.decoder_layers,
                encdec_attention_mask=encoder_attention_mask,
                step_num=step_num,
                encdec_tensors=encdec_tensors,
                states=states)
        logits = mtf.matmul(x, softmax_var)
        return logits, new_states

    if hparams.beam_size == 1:
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        return mtf.beam_search.greedy_decode(
            logits_fn,
            initial_ids,
            temperature=temperature,
            initial_states=initial_states,
            forced_ids=partial_targets,
            use_tpu=hparams.use_tpu)
    else:
        if hparams.transformer_type == "encdec":
            input_length = mtf.reduce_sum(
                mtf.to_float(mtf.cast(inputs, tf.bool)),
                reduced_dim=self.length_dim)
            max_input_length = mtf.reduce_max(input_length)
            decode_length = mtf.cast(
                max_input_length * hparams.decode_length_multiplier +
                hparams.decode_length_constant, tf.int32)
        else:
            decode_length = None
        beams, unused_scores = mtf.beam_search.beam_search(
            logits_fn,
            initial_ids,
            hparams.alpha,
            states=initial_states,
            decode_length=decode_length,
            use_tpu=hparams.use_tpu,
            dtype=self.activation_dtype)
        return mtf.gather(
            beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)