def attention_internal(self, context, x, m, q, k, v, memory_length, bias):
    """Talking-heads attention core.

    Contracts q and k over `self.key_dim` and mixes the resulting logits
    across heads with `self.talking_heads`. Adds `bias` when given and takes a
    softmax over `memory_length`. Mixes the probabilities across heads again
    and applies dropout; the dropout mask is shared along `context.length_dim`,
    and the rate is 0.0 outside training. Finally contracts with v over
    `memory_length` and hands the result to `self.compute_y`.

    Which of `x` / `m` drive the dynamic talking-heads projections is
    controlled by the tags "x2l", "m2l" (logits) and "x2w", "m2w" (weights)
    in `self.dynamic_projections`.
    """
    def _projection_sources(x_tag, m_tag):
        # Tensors feeding the dynamic projections, always in [x, m] order.
        sources = []
        if x_tag in self.dynamic_projections:
            sources.append(x)
        if m_tag in self.dynamic_projections:
            sources.append(m)
        return sources

    raw_scores = mtf.einsum([q, k], reduced_dims=[self.key_dim])
    logits = self.talking_heads(
        context, raw_scores, "logits",
        self.key_heads_dims, self.softmax_heads_dims,
        dynamic_projections_from=_projection_sources("x2l", "m2l"))
    if bias is not None:
        logits += bias
    probs = mtf.softmax(logits, memory_length)
    weights = self.talking_heads(
        context, probs, "weights",
        self.softmax_heads_dims, self.value_heads_dims,
        dynamic_projections_from=_projection_sources("x2w", "m2w"))
    # TODO(noam): make dropout_broadcast_dims configurable
    dropout_broadcast_dims = [context.length_dim]
    drop_rate = self.dropout_rate if context.train else 0.0
    weights = mtf.dropout(
        weights, rate=drop_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)
    mixed = mtf.einsum([weights, v], reduced_dims=[memory_length])
    return self.compute_y(context, mixed)
def testOptimizeLayoutRepetition(self):
    # Build 100 identical einsums; the optimizer should deduplicate them so
    # its model stays small, and the chosen layout should match the
    # single-einsum case.
    lhs = mtf.zeros(self.mesh, "a:10,b:5")
    rhs = mtf.zeros(self.mesh, "b:5,c:20")
    for _ in six.moves.xrange(100):
        mtf.einsum([lhs, rhs], "a:10,c:20")

    optimizer = self.get_layout_optimizer()
    op_names = list(optimizer._graph.get_all_operation_names())
    self.assertGreaterEqual(len(op_names), 50)
    self.assertLessEqual(len(optimizer._model.Proto().variables), 50)

    best = optimizer.solve()
    self.assertEqual(best, "a:m2;c:m1")
    best_value = optimizer.evaluate_layout(best)
    # No other two-axis assignment may beat the optimum.
    for alternative in ("a:m1;b:m2", "a:m1;c:m2", "b:m1;a:m2",
                        "b:m1;c:m2", "c:m1;b:m2"):
        self.assertLessEqual(best_value, optimizer.evaluate_layout(alternative))
    # The same assignment written in the other order must tie exactly.
    self.assertEqual(best_value, optimizer.evaluate_layout("c:m1;a:m2"))
def mnist_model(image, labels, mesh):
    """The model: a two-layer fully connected MNIST classifier.

    Args:
      image: tf.Tensor holding 100 images of 28x28 pixels, with shape
        [batch, 28*28] (or already [batch, 28, 28]); it is reshaped to
        [100, 28, 28] before being imported onto the mesh.
      labels: a tf.Tensor with shape [batch] and dtype tf.int32
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape []
    """
    batch_dim = mtf.Dimension("batch", 100)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    hidden_dim = mtf.Dimension("hidden", 1024)
    classes_dim = mtf.Dimension("classes", 10)

    # The documented input is flat [batch, 28*28]; reshape so the imported
    # tf tensor really has the [batch, rows, cols] shape we declare. This is a
    # no-op for callers that already pass [batch, 28, 28].
    images = mtf.import_tf_tensor(
        mesh,
        tf.reshape(image, [batch_dim.size, rows_dim.size, cols_dim.size]),
        shape=[batch_dim, rows_dim, cols_dim])
    labels = mtf.import_tf_tensor(mesh, labels, [batch_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, hidden_dim])
    w2 = mtf.get_variable(mesh, "w2", [hidden_dim, classes_dim])

    # einsum is a generalization of matrix multiplication (see numpy.einsum).
    # mtf.einsum takes its input tensors as a *list*; passing them as separate
    # positional arguments would bind the second tensor to `output_shape`.
    hidden = mtf.relu(
        mtf.einsum([images, w1], output_shape=[batch_dim, hidden_dim]))
    logits = mtf.einsum([hidden, w2], output_shape=[batch_dim, classes_dim])
    loss = mtf.reduce_mean(
        mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(labels, classes_dim), classes_dim))
    return logits, loss
def testOptimizeLayoutTiebreak(self):
    a_by_b = mtf.zeros(self.mesh, "a:10,b:5")
    b_by_c = mtf.zeros(self.mesh, "b:5,c:20")
    mtf.einsum([a_by_b, b_by_c], "a:10,c:20")

    # Swap in a mesh with an extra size-1 axis "m3"; the solver is expected to
    # place the otherwise-unassigned dimension "b" on it.
    self.mesh_shape = mtf.convert_to_shape("m1:4,m2:2,m3:1")
    self.assertEqual(self.get_layout_optimizer().solve(), "a:m2;b:m3;c:m1")
def hidden_to_logits(self, hidden):
    """Project hidden states to vocabulary logits.

    `hidden` is first scaled by 1/sqrt(output_dim size). With a factorized
    embedding, the projection goes through `_factor2` (contracting
    `_output_dim`) and then `_factor1` (contracting `_inner_dim`). Otherwise
    it contracts `_output_dim` directly against `_embedding_weights`.
    """
    hidden *= self._output_dim.size**-0.5
    if not self._is_factorized:
        return mtf.einsum([hidden, self._embedding_weights],
                          reduced_dims=[self._output_dim])
    inner = mtf.einsum([hidden, self._factor2],
                       reduced_dims=[self._output_dim])
    return mtf.einsum([inner, self._factor1], reduced_dims=[self._inner_dim])
def attention_internal(self, context, q, m, memory_length, bias):
    """Attention in which the memory `m` serves as both keys and values.

    Computes logits by contracting q and m over `context.model.model_dim`,
    adds `bias` when given, and takes a softmax over `memory_length`. Applies
    dropout whose mask is shared along `context.length_dim`; the rate is 0.0
    outside training. Then contracts the weights with m over `memory_length`
    and passes the result to `self.compute_y`.
    """
    scores = mtf.einsum([q, m], reduced_dims=[context.model.model_dim])
    if bias is not None:
        scores += bias
    attn = mtf.softmax(scores, memory_length)
    # TODO(noam): make dropout_broadcast_dims configurable
    dropout_broadcast_dims = [context.length_dim]
    drop_rate = self.dropout_rate if context.train else 0.0
    attn = mtf.dropout(
        attn, rate=drop_rate,
        noise_shape=attn.shape - dropout_broadcast_dims)
    read = mtf.einsum([attn, m], reduced_dims=[memory_length])
    return self.compute_y(context, read)
def mnist_model(image, labels, mesh, hs_t):
    """The model: a single-layer tanh RNN reading MNIST one row per step.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None, in
        which case no loss is computed.
      mesh: a mtf.Mesh
      hs_t: a tf.Tensor with shape [batch, hidden_size] holding the initial
        hidden state; it is imported onto `mesh` with dims [batch, hidden_1].

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
      hs_t: the final hidden state, a mtf.Tensor with shape [batch, hidden_2]
    """
    input_num = 28
    timesteps_num = 28
    classes_num = 10

    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    input_dim = mtf.Dimension("input", input_num)
    timesteps_dim = mtf.Dimension("timesteps", timesteps_num)
    classes_dim = mtf.Dimension("classes", classes_num)
    hidden_dim_1 = mtf.Dimension("hidden_1", FLAGS.hidden_size)
    hidden_dim_2 = mtf.Dimension("hidden_2", FLAGS.hidden_size)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
                             [batch_dim, timesteps_dim, input_dim])
    hs_t = mtf.import_tf_tensor(mesh, hs_t, [batch_dim, hidden_dim_1])

    Wxh = mtf.get_variable(mesh, "Wxh", [input_dim, hidden_dim_2])
    Whh = mtf.get_variable(mesh, "Whh", [hidden_dim_1, hidden_dim_2])
    Why = mtf.get_variable(mesh, "Why", [hidden_dim_2, classes_dim])
    bh = mtf.get_variable(mesh, "bh", [hidden_dim_2])
    by = mtf.get_variable(mesh, "by", [classes_dim])

    x_list = mtf.unstack(x, timesteps_dim)

    # Whh maps hidden_1 -> hidden_2, so each step's output lives on hidden_2.
    # Before it feeds the next step it must be renamed back to hidden_1.
    # Otherwise einsum([hs_t, Whh], [batch, hidden_2]) would *share* hidden_2
    # (an elementwise product with Whh's column sums) instead of contracting
    # over the hidden units.
    h_in = hs_t  # [batch, hidden_1]
    for step, xs_t in enumerate(x_list):
        if step:
            h_in = mtf.rename_dimension(hs_t, hidden_dim_2.name,
                                        hidden_dim_1.name)
        hs_t = mtf.tanh(
            mtf.einsum([xs_t, Wxh], [batch_dim, hidden_dim_2]) +
            mtf.einsum([h_in, Whh], [batch_dim, hidden_dim_2]) + bh)

    logits = mtf.einsum([hs_t, Why], [batch_dim, classes_dim]) + by

    if labels is None:
        loss = None
    else:
        # Only touch `labels` when it exists; tf.reshape(None, ...) would fail.
        y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                                 [batch_dim])
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss, hs_t
def linear_attention(q, k, v):
    """Non-causal linear attention.

    q is softmax-normalized over its per-head features and k over the
    sequence. k and v are then contracted over the sequence into a
    per-head [dim_in, dim_out] context, which q reads.

    The first four dims of `v` are unpacked as (batch, seq, heads, out
    features). The "features_per_head" dim of q and k is renamed to
    "features_per_head_in" so it can coexist with v's output features.
    """
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[i] for i in range(4))

    def _as_input_features(t):
        return mtf.rename_dimension(t, "features_per_head",
                                    "features_per_head_in")

    q = _as_input_features(q)
    k = _as_input_features(k)
    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.softmax(k, seq_dim)

    kv_summary = mtf.einsum(
        [k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
    return mtf.einsum(
        [q, kv_summary], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def call(self, context, x, losses=None):
    """Call the layer.

    Multi-head self-attention. Queries come from `x`. Keys and values come
    from `x` with its length dimension renamed to "memory_length" (in
    "incremental" mode, `x` is used as-is).

    In "incremental" mode, the new key/value are written into the cached
    state from `context.get_states(2)` at slot `context.position` along
    memory_length. In "incremental" and "first_part" modes, the resulting k/v
    are recorded as new states.

    Args:
      context: the transformer Context (mode, position, states, dims, ...).
      x: input activations containing `context.model_dim`; presumably
        [<batch_dims>, length, model_dim] -- confirm against callers.
      losses: not used by this method.

    Returns:
      a Tensor with shape `x.shape`.
    """
    wq, wk, wv, wo = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    # Separate dimension for key/value positions so query and memory
    # positions stay distinct in the einsums below.
    memory_length = mtf.Dimension("memory_length", context.length_dim.size)
    q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])
    if context.mode == "incremental":
        m = x
    else:
        m = mtf.rename_dimension(x, context.length_dim.name, "memory_length")
    k = mtf.einsum([m, wk], reduced_dims=[context.model_dim])
    v = mtf.einsum([m, wv], reduced_dims=[context.model_dim])
    if context.mode == "incremental":
        old_k, old_v = context.get_states(2)
        # Replace only the cached entry at context.position; keep the rest.
        one_hot = mtf.one_hot(context.position, memory_length,
                              dtype=context.activation_dtype)
        inv_one_hot = 1.0 - one_hot
        k = old_k * inv_one_hot + k * one_hot
        v = old_v * inv_one_hot + v * one_hot
    if context.mode == "incremental" or context.mode == "first_part":
        context.record_new_states([k, v])
    masks = []
    if context.autoregressive:
        # -1e9 wherever the memory index is greater than the query position.
        masks.append(
            mtf.cast(
                mtf.less(
                    context.position,
                    mtf.range(context.mesh, memory_length, dtype=tf.int32)),
                context.activation_dtype) * -1e9)
    if (context.sequence_id is not None and
            isinstance(context.sequence_id, mtf.Tensor) and
            context.length_dim in context.sequence_id.shape):
        # -1e9 between positions that belong to different packed sequences.
        masks.append(
            mtf.cast(
                mtf.not_equal(
                    context.sequence_id,
                    mtf.layers.rename_length_to_memory_length(
                        context.sequence_id)),
                context.activation_dtype) * -1e9)
    mask = mtf.add_n(masks) if masks else None

    # Dropout is disabled (0.0) outside training and broadcast along length.
    o = mtf.layers.dot_product_attention_v2(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        self.dropout_rate if context.train else 0.0,
        [context.length_dim])
    return mtf.einsum([o, wo], x.shape,
                      reduced_dims=[self.heads_dim, self.kv_dim])
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None):
    """Dot-product attention - doesn't use positional dimensions.

    key_dim is a Dimension representing the channels in the queries and keys
    value_dim is a Dimension representing the channels in values
    memory_length_dim is a Dimension representing the different key/value pairs.

    Dimensions of q: other_query_dims + {key_dim}
    Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
    Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
    other_memory_dims is a subset of other_query_dims

    Typically, other_query_dims={batch, heads, length}
    Typically, other_memory_dims={batch, heads}

    Args:
      q: a Tensor
      k: a Tensor
      v: a Tensor
      memory_length_dim: a Dimension
      key_dim: a Dimension
      value_dim: a Dimension
      bias: a Tensor to be added into the attention logits.
      dropout_rate: a float. Dropout is skipped entirely when it is 0.0.
      dropout_broadcast_dims: an optional list of mtf.Dimension along which
        the dropout mask is shared.
      extra_logit: an optional scalar or tensor, forwarded to mtf.softmax.

    Returns:
      Tensor with shape q.shape - key_dim + value_dim
    """
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
    if bias is not None:
        logits += bias
    weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
    if dropout_rate != 0.0:
        keep_prob = 1.0 - dropout_rate
        weights = mtf.dropout(
            weights, keep_prob,
            noise_shape=weights.shape - dropout_broadcast_dims)
    return mtf.einsum([weights, v], q.shape - key_dim + value_dim)
def testLayout(self):
    # Build a tiny Mesh TensorFlow graph: one [a,b] x [b,c] -> [a,c] einsum.
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    lhs = mtf.zeros(mesh, "a:10,b:5")
    rhs = mtf.zeros(mesh, "b:5,c:20")
    product = mtf.einsum([lhs, rhs], "a:10,c:20")

    mesh_shape = mtf.convert_to_shape("m1:4,m2:2")
    # The outputs are passed in because they cannot be freed, which matters
    # to the optimizer's memory accounting.
    layout = mtf.auto_mtf.layout(graph, mesh_shape, [product])

    def mesh_axis_of(name, size):
        dim = mtf.convert_to_dimension((name, size))
        return layout.tensor_dimension_to_mesh_axis(dim, mesh_shape)

    self.assertEqual(mesh_axis_of("a", 10), 1)
    self.assertIsNone(mesh_axis_of("b", 5))
    self.assertEqual(mesh_axis_of("c", 20), 0)
def ids_to_embedding(self, ids):
    """Look up embeddings for `ids` along `_vocab_dim`.

    With a factorized embedding, gathers rows of `_factor1` and projects
    them through `_factor2` (contracting `_inner_dim`).
    """
    if not self._is_factorized:
        return mtf.gather(self._embedding_weights, ids, self._vocab_dim)
    low_rank = mtf.gather(self._factor1, ids, self._vocab_dim)
    return mtf.einsum([low_rank, self._factor2], reduced_dims=[self._inner_dim])
def testLayoutAndMeshShape(self):
    # Same graph as testLayout, but auto_mtf chooses the mesh shape for 8
    # devices instead of being given a 4x2 mesh.
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    lhs = mtf.zeros(mesh, "a:10,b:5")
    rhs = mtf.zeros(mesh, "b:5,c:20")
    product = mtf.einsum([lhs, rhs], "a:10,c:20")

    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(graph, 8, [product])
    a_dim, b_dim, c_dim = [
        mtf.convert_to_dimension(spec)
        for spec in (("a", 10), ("b", 5), ("c", 20))
    ]
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(a_dim, mesh_shape), 1)
    self.assertIsNone(layout.tensor_dimension_to_mesh_axis(b_dim, mesh_shape))
    self.assertEqual(layout.tensor_dimension_to_mesh_axis(c_dim, mesh_shape), 0)
    self.assertCountEqual(mesh_shape.dims,
                          [mtf.Dimension("mesh_0", 4), mtf.Dimension("mesh_1", 2)])

    # With the extra argument 1 (presumably a cap on the number of mesh
    # dimensions), the mesh is a single axis of 8 and nothing is split.
    layout, mesh_shape = mtf.auto_mtf.layout_and_mesh_shape(
        graph, 8, [product], 1)
    for dim in (a_dim, b_dim, c_dim):
        self.assertIsNone(layout.tensor_dimension_to_mesh_axis(dim, mesh_shape))
    self.assertCountEqual(mesh_shape.dims, [mtf.Dimension("mesh_0", 8)])
def mnist_model(image, labels, mesh):
    """The model: single-layer softmax regression on MNIST.

    Args:
      image: tf.Tensor with shape [batch, 28*28]
      labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None, in
        which case no loss is computed.
      mesh: a mtf.Mesh

    Returns:
      logits: a mtf.Tensor with shape [batch, 10]
      loss: a mtf.Tensor with shape [], or None when `labels` is None
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    rows_dim = mtf.Dimension("rows", 28)
    cols_dim = mtf.Dimension("cols", 28)
    classes_dim = mtf.Dimension("classes", 10)

    x = mtf.import_tf_tensor(mesh, tf.reshape(image, [FLAGS.batch_size, 28, 28]),
                             [batch_dim, rows_dim, cols_dim])

    w1 = mtf.get_variable(mesh, "w1", [rows_dim, cols_dim, classes_dim])
    b1 = mtf.get_variable(mesh, "b1", [classes_dim])

    # These are softmax logits, so they must not pass through a ReLU.
    # Clipping them at 0 would flatten every negative class score to the
    # same value and stop gradients for those classes.
    logits = mtf.einsum([x, w1], [batch_dim, classes_dim]) + b1

    if labels is None:
        loss = None
    else:
        # Only touch `labels` when it exists; tf.reshape(None, ...) would fail.
        y = mtf.import_tf_tensor(mesh, tf.reshape(labels, [FLAGS.batch_size]),
                                 [batch_dim])
        loss = mtf.layers.softmax_cross_entropy_with_logits(
            logits, mtf.one_hot(y, classes_dim), classes_dim)
        loss = mtf.reduce_mean(loss)
    return logits, loss
def get_indices(self, keys: mtf.Tensor,
                query: mtf.Tensor) -> Tuple[mtf.Tensor, mtf.Tensor]:
    """Generate score and indices for the query.

    Product-key lookup:
    1. Score each of the two query halves against its own sub-key table and
       keep the top `self.knn` per half.
    2. Form all knn*knn pairwise sums of the kept scores. The combined index
       is first_index * self.n_keys + second_index.
    3. Keep the overall top `self.knn` pairs.

    Returns:
      (scores, indices), each with a trailing "knn" dimension.
    """
    score_shape = mtf.Shape(query.shape.dims[:-1] + keys.shape.dims[2:3])
    raw_scores = mtf.einsum([query, keys], output_shape=score_shape)
    # raw_scores: [b, l, h, 2, n_keys]
    knn_dim = mtf.Dimension("knn", self.knn)
    half_scores, half_indices = mtf.top_k(raw_scores, score_shape.dims[-1],
                                          knn_dim)
    # half_scores / half_indices: [b, l, h, 2, knn]

    knn_square_dim = mtf.Dimension("knn_square_dim", self.knn**2)
    first_scores, second_scores = mtf.unstack(half_scores,
                                              half_scores.shape.dims[-2])
    second_scores = mtf.rename_dimension(second_scores, "knn", "knn2")
    pair_shape = mtf.Shape(first_scores.shape.dims + second_scores.shape.dims[-1:])

    def _pairwise_flat(first, second):
        # Broadcast-add over the knn x knn2 grid, then flatten it to knn^2.
        grid = mtf.add(first, second, output_shape=pair_shape)
        return mtf.replace_dimensions(grid, pair_shape[-2:], knn_square_dim)

    pair_scores = _pairwise_flat(first_scores, second_scores)

    first_idx, second_idx = mtf.unstack(half_indices,
                                        half_indices.shape.dims[-2])
    first_idx = mtf.multiply(first_idx, self.n_keys)
    second_idx = mtf.rename_dimension(second_idx, "knn", "knn2")
    pair_indices = _pairwise_flat(first_idx, second_idx)

    top_scores, best = mtf.top_k(pair_scores, pair_scores.shape.dims[-1], knn_dim)
    return top_scores, mtf.gather(pair_indices, best, knn_square_dim)
def causal_linear_attention(q, k, v, epsilon=1e-6):
    """Causal linear attention using running (cumulative) sums over sequence.

    q is softmax-normalized over its per-head features and k is
    exponentiated. For every position, the outer products k (x) v are summed
    over all positions up to and including it. Each feature's sum is divided
    by the matching running sum of k (plus `epsilon`), and q reads the result.

    The first four dims of `v` are unpacked as (batch, seq, heads, out
    features). The "features_per_head" dim of q and k is renamed to
    "features_per_head_in".
    """
    batch_dim, seq_dim, head_dim, dim_out = (v.shape[i] for i in range(4))

    def _as_input_features(t):
        return mtf.rename_dimension(t, "features_per_head",
                                    "features_per_head_in")

    q = _as_input_features(q)
    k = _as_input_features(k)
    dim_in = k.shape[-1]

    q = mtf.softmax(q, dim_in)
    k = mtf.exp(k)
    running_k = mtf.cumsum(k, seq_dim)

    kv_outer = mtf.einsum(
        [k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
    running_kv = mtf.cumsum(kv_outer, seq_dim) / (running_k + epsilon)

    return mtf.einsum(
        [q, running_kv], output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def call(self, context, x, losses=None):
    """Call the layer.

    Masked local (windowed) self-attention.

    - "incremental" mode: attends using cached keys/values from
      `context.get_states(2)` and records the updated k/v as new states.
    - Other modes: runs `mtf.layers.masked_local_attention_1d` over the full
      input.
    - "first_part" mode additionally builds window-shaped k/v states so that
      later incremental steps can continue from `context.initial_position`.

    Args:
      context: the transformer Context.
      x: input activations.
      losses: not used by this method.

    Returns:
      the attention output `y`.
    """
    params = mtf.layers.multihead_attention_params(context.mesh,
                                                   self.heads_dim,
                                                   context.model_dim,
                                                   self.kv_dim,
                                                   context.variable_dtype)
    if context.mode == "incremental":
        prev_k, prev_v = context.get_states(2)
        y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
            x, prev_k, prev_v, context.position, params=params)
        context.record_new_states([new_k, new_v])
        return y
    else:
        # Presumably masked_local_attention_1d appends its k and v to this
        # list via return_kv (they are read back as kv[0], kv[1]) -- confirm.
        kv = []
        y = mtf.layers.masked_local_attention_1d(x,
                                                 self.kv_dim,
                                                 self.heads_dim,
                                                 self.window_size,
                                                 params=params,
                                                 return_kv=kv)
        if context.mode == "first_part":
            k = kv[0]
            v = kv[1]
            window_dim = mtf.Dimension("window", self.window_size)
            mesh = k.mesh
            window_pos = mtf.range(mesh, window_dim, tf.int32)
            pos = mtf.range(mesh, context.length_dim, tf.int32)
            # select_recent[pos, w] is 1 iff w == pos % window_size and
            # initial_position - window_size <= pos < initial_position.
            # This scatters the last window_size prefix positions into a
            # circular buffer indexed by position modulo window_size.
            select_recent = mtf.cast(
                mtf.equal(window_pos, mtf.mod(pos, self.window_size)), k.dtype)
            select_recent *= mtf.cast(
                mtf.less(pos, context.initial_position), k.dtype)
            select_recent *= mtf.cast(
                mtf.greater_equal(
                    pos, context.initial_position - self.window_size), k.dtype)
            # Replace the last two dims of k (presumably [length, kv]) with
            # [window, kv].
            state_shape = k.shape.dims[:-2] + [window_dim, self.kv_dim]
            k_state = mtf.einsum([k, select_recent],
                                 output_shape=state_shape,
                                 reduced_dims=[context.length_dim])
            v_state = mtf.einsum([v, select_recent],
                                 output_shape=state_shape,
                                 reduced_dims=[context.length_dim])
            # NOTE(review): the incremental branch uses
            # context.record_new_states(); confirm this direct extend is
            # equivalent.
            context.new_states.extend([k_state, v_state])
        return y
def testOptimizeLayout(self):
    lhs = mtf.zeros(self.mesh, "a:10,b:5")
    rhs = mtf.zeros(self.mesh, "b:5,c:20")
    mtf.einsum([lhs, rhs], "a:10,c:20")
    optimizer = self.get_layout_optimizer()

    # Cut dimensions to make them equally sized.
    best = optimizer.solve()
    self.assertEqual(best, "a:m2;c:m1")

    # The optimum must score no worse than any other two-axis assignment.
    best_value = optimizer.evaluate_layout(best)
    for alternative in ("a:m1;b:m2", "a:m1;c:m2", "b:m1;a:m2",
                        "b:m1;c:m2", "c:m1;b:m2"):
        self.assertLessEqual(best_value, optimizer.evaluate_layout(alternative))
    # The same assignment written in the other order must tie exactly.
    self.assertEqual(best_value, optimizer.evaluate_layout("c:m1;a:m2"))
def talking_heads(
        self, context, inp, name, input_heads_dims, output_heads_dims,
        dynamic_projections_from=None):
    """Project `inp` from `input_heads_dims` to `output_heads_dims`.

    Heads dims present in both lists are kept ("shared"). Dims only in the
    input are contracted away ("reduced"). Dims only in the output are
    created ("new").

    If `dynamic_projections_from` is non-empty, each tensor in it produces an
    input-dependent projection via `mtf.layers.dense`, contracting
    `context.model.model_dim`. A learned static projection kernel is added
    to the first of these, and the einsums of `inp` with each projection are
    summed. Otherwise a single static dense projection is applied.

    Args:
      context: the transformer Context.
      inp: the Tensor to project.
      name: variable-scope / layer name.
      input_heads_dims: list of Dimension present on `inp`.
      output_heads_dims: list of Dimension wanted on the output.
      dynamic_projections_from: optional list of Tensors that generate
        dynamic projections.

    Returns:
      the projected Tensor, or `inp` itself when the input and output heads
      dims are identical.
    """
    shared_dims = [d for d in input_heads_dims if d in output_heads_dims]
    reduced_dims = [d for d in input_heads_dims if d not in output_heads_dims]
    new_dims = [d for d in output_heads_dims if d not in input_heads_dims]
    if not (reduced_dims or new_dims):
        # Output dimensions are same as input dimensions.  Return the input
        return inp
    elif dynamic_projections_from:
        # There are one or more dynamic talking-heads-projections
        with tf.variable_scope(name):
            # static projection - this is the same as the static projection in the
            # "else" case below.  We create the weight matrix with get_variable
            # instead of calling mtf.layers.dense() so that we can fold the
            # static projection into one of the dynamic projections.
            static_p_initializer = mtf.layers.VarianceScalingInitializer()(
                reduced_dims, new_dims)
            static_p_shape = (
                context.model.ensemble_dims + shared_dims + reduced_dims + new_dims)
            static_p = mtf.get_variable(inp.mesh,
                                        "kernel",
                                        static_p_shape,
                                        initializer=static_p_initializer,
                                        dtype=context.variable_dtype)
            ps = []
            for i, dp_from in enumerate(dynamic_projections_from):
                # Initializer scale is divided by the number of reduced heads.
                kernel_initializer = mtf.layers.VarianceScalingInitializer(
                    self.dynamic_projections_init_scale
                    / mtf.Shape(reduced_dims).size)
                ps.append(
                    mtf.layers.dense(
                        dp_from, reduced_dims=[context.model.model_dim],
                        new_dims=shared_dims + reduced_dims + new_dims,
                        use_bias=False, activation=None,
                        variable_dtype=context.variable_dtype,
                        name="%s_dynamic_%d" % (name, i),
                        expert_dims=context.model.ensemble_dims,
                        kernel_initializer=kernel_initializer))
            # Fold the static projection into one of the dynamic projections.
            # Mathematically, we could add all the dynamic projections together
            #   here, but it would create a very large tensor which contained
            #   both the query-length and memory-length dimensions, and would
            #   probably be slower in practice.
            ps[0] += static_p
            return mtf.add_n(
                [mtf.einsum([inp, p], reduced_dims=reduced_dims) for p in ps])
    else:
        # No dynamic projections.  Static talking-heads projection only
        return mtf.layers.dense(
            inp, reduced_dims=reduced_dims,
            new_dims=new_dims,
            use_bias=False, activation=None,
            variable_dtype=context.variable_dtype,
            name=name, expert_dims=context.model.ensemble_dims + shared_dims)
def compute_loss(self, decoder: transformer.Unitransformer,
                 hidden: mtf.Tensor, targets: mtf.Tensor,
                 context: transformer.Context) -> mtf.Tensor:
    """Returns the loss without computing a softmax over the entire vocab.

    The loss is the sum of a per-tail-cluster loss and a head-cluster loss:

    - Tail clusters: each cluster's loss is computed only on the positions
      whose target falls in that cluster. Other positions are zeroed out, or,
      when `length_projection_factor != 1`, the masked positions are
      projected onto a shorter length via an einsum over the "length" dim.
    - Head cluster: targets that belong to tail cluster i are replaced by the
      token id `end_token_id + i` before computing the head-cluster loss.

    Args:
      decoder: the decoder Unitransformer, passed through to each cluster.
      hidden: final hidden states; expected to have a "length" dimension.
      targets: target token ids; expected to have a "length" dimension.
      context: the transformer Context, passed through to each cluster.

    Returns:
      the summed loss Tensor.
    """
    loss = 0
    tail_cluster_masks = []
    for cluster in self._tail_clusters:
        cluster_mask = cluster.get_cluster_mask(targets)
        tail_cluster_masks.append(cluster_mask)

        if cluster.length_projection_factor == 1:
            targets_in_cluster = mtf.where(cluster_mask, targets, 0)
            hidden_in_cluster = mtf.where(cluster_mask, hidden, 0)
        else:
            # TODO(mmatena): Unfold the batch dim to get a super long sequence dim
            # to reduce the risk of overflowing the projection.
            proj_to_cluster_len = cluster.get_project_to_cluster_length(
                cluster_mask, dtype=targets.dtype)
            targets_in_cluster = mtf.einsum(
                [proj_to_cluster_len, targets],
                reduced_dims=[targets.shape.get_dim_by_name("length")])
            hidden_in_cluster = mtf.einsum(
                [mtf.cast(proj_to_cluster_len, hidden.dtype), hidden],
                reduced_dims=[hidden.shape.get_dim_by_name("length")])

        loss += cluster.compute_loss(decoder, hidden_in_cluster,
                                     targets_in_cluster, context)

    tail_clusters_dim = mtf.Dimension("tail_clusters", len(tail_cluster_masks))

    # Per position: end_token_id + i if the target is in tail cluster i, else 0
    # (assumes the cluster masks are disjoint -- confirm).
    tail_node_targets = mtf.reduce_sum(
        mtf.stack([(self._head_cluster.end_token_id + i) *
                   mtf.cast(mask, targets.dtype)
                   for i, mask in enumerate(tail_cluster_masks)],
                  tail_clusters_dim.name),
        reduced_dim=tail_clusters_dim)
    # Keep the original target wherever no tail cluster claimed the position.
    head_targets = mtf.where(
        mtf.cast(tail_node_targets, tf.bool), tail_node_targets, targets)
    loss += self._head_cluster.compute_loss(decoder, hidden, head_targets,
                                            context)

    return loss
def call(self, context, x, losses=None):
    """Call the layer.

    Encoder-decoder attention: queries come from `x`, keys and values from
    `context.encoder_output`. In "first_part" mode, k/v/memory_length are
    recorded as constant state; in "incremental" mode they are read back from
    it. When both encoder and decoder sequence ids are set, positions from
    different sequences are masked with -1e9.
    """
    memory_input_dim = context.encoder_output.shape[-1]
    if memory_input_dim != context.model_dim:
        raise NotImplementedError(
            "TODO(noam): support different model_dim in encoder and decoder.")

    wq, wk, wv, wo = mtf.layers.multihead_attention_params(
        context.mesh, self.heads_dim, context.model_dim,
        self.kv_dim, context.variable_dtype)
    q = mtf.einsum([x, wq], reduced_dims=[context.model_dim])

    if context.mode == "incremental":
        k, v, memory_length = context.get_constant_state()
    else:
        memory = context.encoder_output
        [memory_length] = [
            d for d in memory.shape.dims if d.name == "memory_length"
        ]
        k = mtf.einsum([memory, wk], reduced_dims=[context.model_dim])
        v = mtf.einsum([memory, wv], reduced_dims=[context.model_dim])
        if context.mode == "first_part":
            context.record_constant_state((k, v, memory_length))

    mask = None
    if context.encoder_sequence_id and context.sequence_id:
        crosses_sequences = mtf.not_equal(context.sequence_id,
                                          context.encoder_sequence_id)
        mask = mtf.cast(crosses_sequences, context.activation_dtype) * -1e9

    attended = mtf.layers.dot_product_attention_v2(
        q, k, v, memory_length, self.kv_dim, self.kv_dim, mask,
        dropout=self.dropout_rate if context.train else 0.0,
        dropout_broadcast_dims=[context.length_dim])
    return mtf.einsum([attended, wo], x.shape,
                      reduced_dims=[self.heads_dim, self.kv_dim])
def hidden_to_logits(self, hidden: mtf.Tensor,
                     context: transformer.Context) -> mtf.Tensor:
    """Function called by mtf transformer to get the logits.

    Note that we are taking the log of a mixture of softmaxes. The logits will
    then go through a softmax. This could potentially run into numerical
    stability issues. If that happens, try setting the activation_dtype to
    float32.

    Args:
      hidden: hidden model states of the final decoder layer.
      context: the context used for the call to the transformer.

    Returns:
      The logits.
    """
    del context
    hidden *= self._output_dim.size**-0.5

    # Mixture (prior) logits over components.
    prior_logits = mtf.einsum([hidden, self._mixture_weights],
                              reduced_dims=[self._output_dim])

    # Per-component context vectors. hidden's output dim is renamed so it can
    # be contracted against the context weights.
    hidden_as_copy = mtf.rename_dimension(hidden, self._output_dim.name,
                                          self._copy_output_dim.name)
    contexts = mtf.tanh(
        mtf.einsum([hidden_as_copy, self._context_weights],
                   reduced_dims=[self._copy_output_dim]))

    per_component_logits = mtf.einsum([contexts, self._embedding_weights],
                                      reduced_dims=[self._output_dim])

    log_prior = mtf.log_softmax(prior_logits, reduced_dim=self._components_dim)
    log_component = mtf.log_softmax(per_component_logits,
                                    reduced_dim=self._vocab_dim)
    # log sum_c prior_c * softmax_c
    return mtf.reduce_logsumexp(log_prior + log_component,
                                reduced_dim=self._components_dim)
def get_log_softmax_prefix(self, log_softmax, end_index):
    """Returns first end_index entries in log_softmax along the vocab dim."""
    vocab_ids = mtf.mtf_range(
        log_softmax.mesh, dim=self._vocab_dim, dtype=tf.int32)
    # Ids >= end_index are mapped to -1, an out-of-range one-hot index, so
    # they presumably contribute an all-zero row to the selector.
    kept_ids = mtf.where(mtf.less(vocab_ids, end_index), vocab_ids, -1)
    prefix_dim = mtf.Dimension(self._vocab_dim.name, end_index)
    selector = mtf.one_hot(kept_ids, prefix_dim, dtype=log_softmax.dtype)
    return mtf.einsum([log_softmax, selector], reduced_dims=[self._vocab_dim])
def wide(x, mask, float16=None):
    """Wide branch: masked product of `x` and `mask`, plus a zero scalar.

    The einsum keeps only the first and last dims of `x`. The added constant
    is 0 in float16 when `float16` is truthy, else in float32.
    """
    product = mtf.einsum([x, mask],
                         output_shape=[x.shape.dims[0], x.shape.dims[-1]],
                         name='wide_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(
        product.name, product.shape))

    bias_dtype = np.float16 if float16 else np.float32
    wide_b = np.array(0, dtype=bias_dtype)
    summed = mtf.add(product, wide_b, name="wide_sum")
    logger.debug("[output tensor] (name,shape):({},{})".format(
        summed.name, summed.shape))
    return summed
def deep(x, mask, float16=None):
    """Deep branch: masked input through an fp16 MLP 1024-512-256-128-1.

    The dense layers always run in float16. The result is cast back to
    float32 unless `float16` is truthy.
    """
    x = mtf.einsum([x, mask], output_shape=x.shape.dims, name='deep_mul')
    logger.debug("[output tensor] (name,shape):({},{})".format(
        x.name, x.shape))

    # Following MindSpore, compute the dense layers below in fp16.
    x = mtf.cast(x, dtype=tf.float16)

    # (new dim name, width, layer name, activation); the last layer is linear.
    layer_specs = (
        ('dense_dim0', 1024, "deep-dense-0", mtf.relu),
        ('dense_dim1', 512, "deep-dense-1", mtf.relu),
        ('dense_dim2', 256, "deep-dense-2", mtf.relu),
        ('dense_dim3', 128, "deep-dense-3", mtf.relu),
        ('dense_dim4', 1, "deep-dense-4", None),
    )
    for index, (dim_name, width, layer_name, act) in enumerate(layer_specs):
        extra_kwargs = {}
        if index == 0:
            # The first layer contracts the last two input dims.
            extra_kwargs["reduced_dims"] = x.shape.dims[-2:]
        if act is not None:
            extra_kwargs["activation"] = act
        x = mtf.layers.dense(
            x, mtf.Dimension(name=dim_name, size=width),
            name=layer_name,
            variable_dtype=mtf.VariableDType(tf.float16, tf.float16, tf.float16),
            **extra_kwargs)
        logger.debug("[output tensor] (name,shape):({},{})".format(
            x.name, x.shape))

    if not float16:
        x = mtf.cast(x, dtype=tf.float32)
    return x
def lpt_init_single(lr_field, a0, kvec_lr, halo_size, lr_shape, hr_shape, part_shape, antialias=True, order=1, post_filtering=True, cosmology=Planck15):
    """Displace a particle grid by first-order LPT from a low-res field.

    Steps, as implemented below:
    1. Build particle positions X on `part_shape` (via mesh_ops.mtf_indices),
       repeated over the batch dim.
    2. FFT `lr_field` and apply a gradient-Laplace kernel.
    3. For each of the three components: inverse FFT, reshape into
       high-res blocks, pad and halo-exchange, then CIC-read it out at X.
    4. Scale the stacked displacement by growth factors at a = a0.

    Second-order LPT is not implemented (see the TODO below).

    Args:
      lr_field: low-resolution mtf field; its first dim is used as batch.
      a0: scale factor at which growth quantities are evaluated.
      kvec_lr: wave-vector tensors for the low-res grid.
      halo_size: halo width used for padding / halo_reduce / cic_readout.
      lr_shape, hr_shape, part_shape: mtf shapes of the low-res field, the
        high-res block decomposition, and the particle grid.
      antialias, order, post_filtering: NOTE(review): not referenced in this
        function body.
      cosmology: cosmology passed to PerturbationGrowth.

    Returns:
      (X, P, F): displaced positions, and two tensors proportional to the
      displacement DX (presumably momenta and forces -- confirm).
    """
    a = a0
    batch_dim = lr_field.shape[0]
    lnc = lr_shape[-1].size

    # Create particles on the high resolution grid
    mstate = mesh_ops.mtf_indices(lr_field.mesh, shape=part_shape, dtype=tf.float32)
    # Multiplying by ones over batch_dim repeats the grid for every batch entry.
    X = mtf.einsum([mtf.ones(lr_field.mesh, [batch_dim]), mstate], output_shape=[batch_dim] + mstate.shape[:])

    k_dims_lr = [d.shape[0] for d in kvec_lr]
    k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
    lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)

    grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)

    # Reorder the low-res FFT components, which were transposed (y, z, x).
    grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]

    displacement = []
    for f in grad_kfield_lr:
        f = mesh_utils.c2r3d(f, lr_shape[-3:])
        # Insert three singleton axes after the leading one, relabelled as
        # per-block sub-grid dims of size lnc // (number of blocks).
        f = mtf.slicewise(lambda x:tf.expand_dims(tf.expand_dims(tf.expand_dims(x, axis=1),axis=1),axis=1),
                          [f],
                          output_dtype=tf.float32,
                          output_shape=mtf.Shape(hr_shape[0:4]+[
                              mtf.Dimension('sx_block', lnc//hr_shape[1].size),
                              mtf.Dimension('sy_block', lnc//hr_shape[2].size),
                              mtf.Dimension('sz_block', lnc//hr_shape[3].size)]),
                          name='my_reshape',
                          splittable_dims=lr_shape[:-1]+hr_shape[1:4]+part_shape[1:3])

        # Pad each block by halo_size on both sides, then exchange halos.
        for block_size_dim in hr_shape[-3:]:
            f = mtf.pad(f, [halo_size, halo_size], block_size_dim.name)

        for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
            f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim, halo_size)

        d = mesh_utils.cic_readout(f, X, halo_size)
        displacement.append(d)

    # Readout to particle positions
    displacement = mtf.stack([ d for d in displacement],"ndim",axis=4)

    pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
    DX = pt.D1(a) * displacement
    P = (a ** 2 * pt.f1(a) * pt.E(a)) * DX
    F = (a ** 2 * pt.E(a) * pt.gf(a) / pt.D1(a)) * DX
    # TODO: Implement 2nd order LPT

    # Moves the particles according to displacement
    X = X + DX

    return X, P, F
def _get_decoder_inputs(self, context):
    """Computes the inputs to the decoder when using transparent attention.

    We must cache on the context in order to ensure that we are not replicating
    variables when the layer's call function is called in different tf variable
    scopes.

    The encoder "module outputs" are the first encoder layer output plus every
    `layers_per_encoder_module`-th layer output. They are stacked on an
    "encoder_module_outputs" dim of size `encoder_num_modules + 1`. Each
    decoder module input is a softmax-weighted mix of them, using a learned
    [encoder_module_outputs, decoder_module_inputs] variable "w". During
    training, when `dropout_rate` is nonzero, dropout is applied to "w".

    Args:
      context: a Context

    Returns:
      a list of `self.decoder_num_modules` tensors. Each has the dims of an
      element of `context.encoder_layer_outputs` after
      rename_length_to_memory_length (presumably
      [<batch_dims>, memory_length, <channels>] -- confirm).
    """
    if hasattr(context, "decoder_layers_per_module"):
        return context.decoder_layers_per_module

    encoder_layer_outputs = [
        mtf.layers.rename_length_to_memory_length(output)
        for output in context.encoder_layer_outputs
    ]

    layers_per_module = self.layers_per_encoder_module
    encoder_module_outputs_dim = mtf.Dimension(
        "encoder_module_outputs", size=self.encoder_num_modules + 1)
    decoder_module_inputs_dim = mtf.Dimension(
        "decoder_module_inputs", size=self.decoder_num_modules)
    encoder_module_outputs = mtf.stack(
        [encoder_layer_outputs[0]] +
        encoder_layer_outputs[layers_per_module::layers_per_module],
        dim_name="encoder_module_outputs")
    w = mtf.get_variable(
        context.mesh,
        "w",
        mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
        initializer=tf.random_normal_initializer(
            stddev=(encoder_module_outputs_dim.size *
                    decoder_module_inputs_dim.size)**-0.5),
        dtype=context.variable_dtype)
    if context.train and self.dropout_rate != 0.0:
        w = mtf.dropout(w, 1.0 - self.dropout_rate)
    s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
    z = mtf.einsum([s, encoder_module_outputs],
                   reduced_dims=[encoder_module_outputs_dim])
    # Split z into one slice per decoder module. The reshape drops the leading
    # (size-1) decoder_module_inputs dim; this relies on it being first in z.
    input_per_decoder = mtf.split(
        z,
        split_dim=decoder_module_inputs_dim,
        num_or_size_splits=decoder_module_inputs_dim.size)
    context.decoder_layers_per_module = [
        mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
    ]
    return context.decoder_layers_per_module
def compute_q(self, query_antecedent):
    """Compute query Tensor q.

    Contracts `query_input_dim` against `self.wq`. When `self.combine_dims`
    is set, the last dim of the result is split back into `self.q_dims`.

    Args:
      query_antecedent: a Tensor with dimensions
         {query_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
         query_heads_dims + {key_dim} + other_dims
    """
    q = mtf.einsum([query_antecedent, self.wq],
                   reduced_dims=[self.query_input_dim])
    if not self.combine_dims:
        return q
    return mtf.replace_dimensions(q, q.shape.dims[-1], self.q_dims)
def compute_v(self, memory_antecedent):
    """Compute value Tensor v.

    Contracts `memory_input_dim` against `self.wv`. When `self.combine_dims`
    is set, the last dim of the result is split back into `self.v_dims`.

    Args:
      memory_antecedent: a Tensor with dimensions
        {memory_input_dim} + other_dims
    Returns:
      a Tensor with dimensions
        memory_heads_dims + {value_dim} + other_dims
    Raises:
      ValueError: if the layer shares keys and values (`self.shared_kv`).
    """
    if self.shared_kv:
        raise ValueError("compute_v cannot be called with shared_kv")
    v = mtf.einsum([memory_antecedent, self.wv],
                   reduced_dims=[self.memory_input_dim])
    if not self.combine_dims:
        return v
    return mtf.replace_dimensions(v, v.shape.dims[-1], self.v_dims)
def call(self, context, x: mtf.Tensor) -> mtf.Tensor:
    """Call the layer.

    Product-key memory:
    1. Project x to per-head queries split into two halves.
    2. Layer-normalize the queries.
    3. Look up the top-k memory slots with `self.get_indices`.
    4. Return the softmax-weighted sum of the selected value embeddings,
       summed over heads.
    """
    # Initialize Memory Keys and Values
    n_key_dim = mtf.Dimension("n_keys", self.n_keys)
    n_value_dim = mtf.Dimension("n_values", self.n_values)
    key_dim = mtf.Dimension("key", self.key_size // 2)
    value_dim = x.shape.dims[-1]
    head_dim = mtf.Dimension("n_heads", self.n_heads)
    product_dim = mtf.Dimension("product_key", 2)

    memory_keys = mtf.get_variable(
        context.mesh,
        name="keys",
        shape=mtf.Shape([head_dim, product_dim, n_key_dim, key_dim]),
        dtype=context.variable_dtype)
    memory_values = mtf.layers.embedding_weights(
        context.mesh,
        vocab_dim=n_value_dim,
        output_dim=value_dim,
        variable_dtype=context.variable_dtype,
        name="values")

    # Compute query: [b, l, h, 2, k]
    query = mtf.layers.dense(
        x, [head_dim, product_dim, key_dim],
        reduced_dims=x.shape.dims[-1:],
        activation=None,
        use_bias=True,
        variable_dtype=context.variable_dtype,
        name="query")
    # Note: We use layer norm instead of batch norm to normalize queries.
    # The main advantage is that layer norm works well with the codebase
    # whereas the implementation of batch norm requires handling of tf ops.
    query = mtf.layers.layer_norm(query, query.shape.dims[-1])

    # Retrieve indices and scores: [b, l, h, k]
    top_scores, top_indices = self.get_indices(memory_keys, query)
    top_weights = mtf.softmax(top_scores, reduced_dim=top_scores.shape.dims[-1])

    selected = mtf.gather(memory_values, top_indices, n_value_dim)  # [b, l, h, k, v]
    # Weighted sum over the last two score dims (heads and knn): [b, l, v]
    return mtf.einsum([selected, top_weights],
                      reduced_dims=top_weights.shape.dims[-2:])
def _sample(self, features, mesh):
  """Decode from the model with greedy decoding or beam search.

  Args:
    features: a dict of tf.Tensors. For transformer_type "encdec",
      features["inputs"] is required. For "decoder", features["inputs"]
      (or, if absent, features["targets"]) is optional and is used as a
      forced prefix of the output.
    mesh: a mtf.Mesh

  Returns:
    if hparams.beam_size == 1, the result of mtf.beam_search.greedy_decode;
    otherwise beam 0 of the beams returned by mtf.beam_search.beam_search.

  Raises:
    ValueError: if hparams.transformer_type is neither "encdec" nor
      "decoder".
  """
  hparams = self._hparams
  (inputs_embedding_var, targets_embedding_var,
   softmax_var, positional_embedding_var) = self._embedding_and_softmax_vars(
       mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    # Squeeze axis 2 until inputs are rank 2.
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    # Zero-pad up to [hparams.batch_size, hparams.max_length].
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    # Compute the encoder-decoder attention k/v once, up front; the list
    # (one entry per decoder layer, None for non-"enc_att" layers) is
    # reused by logits_fn on every decode step.
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims + [self.heads_dim,
                                     self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                            [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    # The message names the hparam that is actually checked.
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [self.heads_dim,
                                local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [beam_dim, self.heads_dim,
                                local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  # Two zero-initialized state tensors per "att"/"local_att" decoder layer
  # (presumably the cached keys and values -- confirm in _layer_stack).
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  # NOTE(review): partial_targets is not passed to beam_search, so a
  # forced prefix only takes effect with greedy decoding.
  if hparams.transformer_type == "encdec":
    # Count non-zero input ids per example (tf.pad above pads with zeros)
    # and scale the longest input length to get the decode length.
    input_length = mtf.reduce_sum(
        mtf.to_float(mtf.cast(inputs, tf.bool)),
        reduced_dim=self.length_dim)
    max_input_length = mtf.reduce_max(input_length)
    decode_length = mtf.cast(
        max_input_length * hparams.decode_length_multiplier
        + hparams.decode_length_constant, tf.int32)
  else:
    decode_length = None
  beams, unused_scores = mtf.beam_search.beam_search(
      logits_fn,
      initial_ids,
      hparams.alpha,
      states=initial_states,
      decode_length=decode_length,
      use_tpu=hparams.use_tpu,
      dtype=self.activation_dtype)
  return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)