def linear_attention(q, k, v):
  """Non-causal linear attention with softmax feature maps.

  Queries are softmax-normalized over their feature dimension and keys over
  the sequence dimension.  Keys and values are first contracted into a
  per-head [features_in, features_out] context, which the queries then read,
  so no [sequence x sequence] attention matrix is ever formed.

  Args:
    q: a Tensor with a "features_per_head" dimension.
    k: a Tensor with a "features_per_head" dimension (expected last).
    v: a Tensor whose first four dimensions are used as
      [batch, sequence, heads, output features].

  Returns:
    a Tensor with shape [batch, sequence, heads, output features].
  """
  batch_dim = v.shape[0]
  seq_dim = v.shape[1]
  head_dim = v.shape[2]
  dim_out = v.shape[3]

  # Give the query/key feature axis its own name so it can be contracted
  # independently of the value feature axis.
  q, k = [
      mtf.rename_dimension(t, "features_per_head", "features_per_head_in")
      for t in (q, k)
  ]
  dim_in = k.shape[-1]

  q_features = mtf.softmax(q, dim_in)
  k_features = mtf.softmax(k, seq_dim)

  kv_context = mtf.einsum(
      [k_features, v], output_shape=[batch_dim, head_dim, dim_in, dim_out])
  return mtf.einsum(
      [q_features, kv_context],
      output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def attention_internal(self, context, x, m, q, k, v, memory_length, bias):
  """Talking-heads attention core.

  Computes q.k logits, mixes them across heads with `self.talking_heads`
  (before and after the softmax over `memory_length`), applies dropout during
  training, reads out the values and hands the result to `self.compute_y`.

  Args:
    context: the layer Context (reads `train` and `length_dim`).
    x: the layer input; used for dynamic projections when "x2l"/"x2w" is in
      `self.dynamic_projections`.
    m: the memory antecedent; used for dynamic projections when "m2l"/"m2w"
      is in `self.dynamic_projections`.
    q: query Tensor.
    k: key Tensor.
    v: value Tensor.
    memory_length: the Dimension the softmax normalizes over.
    bias: optional Tensor added to the logits.

  Returns:
    the result of `self.compute_y(context, u)`.
  """
  def _projection_sources(x_key, m_key):
    # Tensors enabled as sources for the dynamic talking-heads projections.
    sources = []
    if x_key in self.dynamic_projections:
      sources.append(x)
    if m_key in self.dynamic_projections:
      sources.append(m)
    return sources

  raw_scores = mtf.einsum([q, k], reduced_dims=[self.key_dim])
  logits = self.talking_heads(
      context,
      raw_scores,
      "logits",
      self.key_heads_dims,
      self.softmax_heads_dims,
      dynamic_projections_from=_projection_sources("x2l", "m2l"))
  if bias is not None:
    logits += bias

  probs = mtf.softmax(logits, memory_length)
  weights = self.talking_heads(
      context,
      probs,
      "weights",
      self.softmax_heads_dims,
      self.value_heads_dims,
      dynamic_projections_from=_projection_sources("x2w", "m2w"))

  # TODO(noam): make dropout_broadcast_dims configurable
  dropout_broadcast_dims = [context.length_dim]
  weights = mtf.dropout(
      weights,
      rate=self.dropout_rate if context.train else 0.0,
      noise_shape=weights.shape - dropout_broadcast_dims)

  u = mtf.einsum([weights, v], reduced_dims=[memory_length])
  return self.compute_y(context, u)
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None):
  """Dot-product attention (no positional dimensions), 2d-sharding aware.

  q, k, v and bias first go through
  `maybe_reshape_attention_input_for_2d_sharding`; the output is reshaped
  back to the caller's original query layout at the end.

  Dimensions:
    q: other_query_dims + {key_dim}
    k: other_memory_dims + {memory_length_dim, key_dim}
    v: other_memory_dims + {memory_length_dim, value_dim}
  where other_memory_dims is a subset of other_query_dims; typically
  other_query_dims={batch, heads, length} and other_memory_dims={batch, heads}.

  Args:
    q: a Tensor of queries.
    k: a Tensor of keys.
    v: a Tensor of values.
    memory_length_dim: Dimension indexing the key/value pairs.
    key_dim: Dimension of the query/key channels (contracted).
    value_dim: Dimension of the value channels.
    bias: optional Tensor added to the attention logits.
    dropout_rate: float; dropout on the weights is applied when nonzero.
    dropout_broadcast_dims: optional list of mtf.Dimension along which the
      dropout mask is shared.
    extra_logit: optional scalar or Tensor forwarded to mtf.softmax.
    context: optional Transformer.Context used for the 2d-sharding reshape.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  original_query_shape = q.shape
  q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
      context, q, k, v, bias, [key_dim, value_dim])

  scores = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    scores += bias
  probs = mtf.softmax(scores, memory_length_dim, extra_logit=extra_logit)

  if dropout_rate != 0.0:
    keep_prob = 1.0 - dropout_rate
    probs = mtf.dropout(
        probs, keep_prob, noise_shape=probs.shape - dropout_broadcast_dims)

  attended = mtf.einsum([probs, v], q.shape - key_dim + value_dim)
  # The sharding helper may have changed q's layout; restore the original.
  return mtf.reshape(attended, original_query_shape - key_dim + value_dim)
def attention_internal(self, context, q, m, memory_length, bias):
  """Attention in which the memory tensor `m` acts as both keys and values.

  Args:
    context: the layer Context (reads `model.model_dim`, `train`,
      `length_dim`).
    q: query Tensor; contracted with m over the model dimension.
    m: memory Tensor used for both keys and values.
    memory_length: the Dimension the softmax normalizes over.
    bias: optional Tensor added to the logits.

  Returns:
    the result of `self.compute_y(context, u)`.
  """
  model_dim = context.model.model_dim
  scores = mtf.einsum([q, m], reduced_dims=[model_dim])
  if bias is not None:
    scores += bias
  weights = mtf.softmax(scores, memory_length)

  # TODO(noam): make dropout_broadcast_dims configurable
  dropout_broadcast_dims = [context.length_dim]
  effective_rate = self.dropout_rate if context.train else 0.0
  weights = mtf.dropout(
      weights,
      rate=effective_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  u = mtf.einsum([weights, m], reduced_dims=[memory_length])
  return self.compute_y(context, u)
def _get_decoder_inputs(self, context):
  """Computes (once) the inputs to each decoder module for transparent attention.

  The result is cached on `context` so that calling the layer from several tf
  variable scopes does not create the mixing variable more than once.

  The first encoder layer output and every `self.layers_per_encoder_module`-th
  output after it are stacked along "encoder_module_outputs".  Each decoder
  module then receives its own softmax-weighted combination of them, weighted
  by a learned [encoder_module_outputs, decoder_module_inputs] variable "w".

  Args:
    context: a Context

  Returns:
    a list of `self.decoder_num_modules` tensors.  Each has the shape of the
    encoder layer outputs after `rename_length_to_memory_length`.
  """
  if hasattr(context, "decoder_layers_per_module"):
    return context.decoder_layers_per_module

  encoder_layer_outputs = [
      mtf.layers.rename_length_to_memory_length(layer_output)
      for layer_output in context.encoder_layer_outputs
  ]
  stride = self.layers_per_encoder_module
  encoder_module_outputs_dim = mtf.Dimension(
      "encoder_module_outputs", size=self.encoder_num_modules + 1)
  decoder_module_inputs_dim = mtf.Dimension(
      "decoder_module_inputs", size=self.decoder_num_modules)

  selected_outputs = [encoder_layer_outputs[0]]
  selected_outputs.extend(encoder_layer_outputs[stride::stride])
  encoder_module_outputs = mtf.stack(
      selected_outputs, dim_name="encoder_module_outputs")

  mixing_shape = mtf.Shape(
      [encoder_module_outputs_dim, decoder_module_inputs_dim])
  init_stddev = (encoder_module_outputs_dim.size *
                 decoder_module_inputs_dim.size)**-0.5
  w = mtf.get_variable(
      context.mesh,
      "w",
      mixing_shape,
      initializer=tf.random_normal_initializer(stddev=init_stddev),
      dtype=context.variable_dtype)
  if context.train and self.dropout_rate != 0.0:
    w = mtf.dropout(w, 1.0 - self.dropout_rate)
  mixing_weights = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)

  z = mtf.einsum([mixing_weights, encoder_module_outputs],
                 reduced_dims=[encoder_module_outputs_dim])
  pieces = mtf.split(
      z,
      split_dim=decoder_module_inputs_dim,
      num_or_size_splits=decoder_module_inputs_dim.size)
  # Drop z's leading dimension (assumed to be decoder_module_inputs) from
  # each piece.
  per_module_shape = z.shape.dims[1:]
  context.decoder_layers_per_module = [
      mtf.reshape(piece, per_module_shape) for piece in pieces
  ]
  return context.decoder_layers_per_module
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              mask=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None):
  """Dot-product attention that ignores positional dimensions.

  Dimensions:
    q: other_query_dims + {key_dim}
    k: other_memory_dims + {memory_length_dim, key_dim}
    v: other_memory_dims + {memory_length_dim, value_dim}
  where other_memory_dims is a subset of other_query_dims; typically
  other_query_dims={batch, heads, length} and other_memory_dims={batch, heads}.

  Args:
    q: a Tensor of queries.
    k: a Tensor of keys.
    v: a Tensor of values.
    memory_length_dim: Dimension indexing the key/value pairs.
    key_dim: Dimension of the query/key channels (contracted).
    value_dim: Dimension of the value channels.
    mask: optional additive mask Tensor (see attention_mask()), added to the
      logits.
    dropout_rate: float; dropout on the weights is applied when nonzero.
    dropout_broadcast_dims: optional list of mtf.Dimension along which the
      dropout mask is shared.
    extra_logit: optional scalar or Tensor forwarded to mtf.softmax.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  scores = mtf.einsum([q, k], reduced_dims=[key_dim])
  if mask is not None:
    scores += mask
  probs = mtf.softmax(scores, memory_length_dim, extra_logit=extra_logit)
  if dropout_rate != 0.0:
    noise_shape = probs.shape - dropout_broadcast_dims
    probs = mtf.dropout(probs, 1.0 - dropout_rate, noise_shape=noise_shape)
  return mtf.einsum([probs, v], q.shape - key_dim + value_dim)
def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
                 labels, num_labels_dim, layout, mesh_shape):
  """Creates a segment-level classification model on top of BERT.

  Args:
    bert_config: BERT configuration passed to bert_lib.BertModel.
    is_training: whether to build the training graph (enables 10% dropout on
      the pooled output).
    input_ids: token id Tensor; its mesh is used for the new variables.
    input_mask: input mask Tensor.
    segment_ids: token type ids.
    labels: label Tensor for softmax_cross_entropy_with_logits.
    num_labels_dim: Dimension over the classes.
    layout: mesh layout passed to the model.
    mesh_shape: mesh shape passed to the model.

  Returns:
    a tuple (loss, per_example_loss, logits, probabilities).
  """
  model = bert_lib.BertModel(
      config=bert_config,
      is_training=is_training,
      input_ids=input_ids,
      input_mask=input_mask,
      token_type_ids=segment_ids,
      layout=layout,
      mesh_shape=mesh_shape)

  # Classify the whole segment from the pooled output.  Token-level tasks
  # should use model.get_sequence_output() instead.
  pooled = model.get_pooled_output()
  hidden_dim = pooled.shape[-1]
  mesh = input_ids.mesh

  output_weights = mtf.get_variable(
      mesh,
      "output_weights",
      shape=[num_labels_dim, hidden_dim],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = mtf.get_variable(
      mesh,
      "output_bias",
      shape=[num_labels_dim],
      initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):
    if is_training:
      pooled = mtf.dropout(pooled, keep_prob=0.9)  # i.e. 0.1 dropout
    projected = mtf.einsum([pooled, output_weights],
                           reduced_dims=[hidden_dim])
    logits = projected + output_bias
    probabilities = mtf.softmax(logits, reduced_dim=num_labels_dim)
    per_example_loss = mtf.layers.softmax_cross_entropy_with_logits(
        logits, labels, vocab_dim=num_labels_dim)
    loss = mtf.reduce_mean(per_example_loss) + model.get_extra_loss()
    return (loss, per_example_loss, logits, probabilities)
def causal_linear_attention(q, k, v, epsilon=1e-6):
  """Causal (prefix-only) linear attention.

  Queries are softmax-normalized over their feature dimension.  For each
  position t, the key/value outer products exp(k_s) v_s are summed over
  s <= t and divided by the matching prefix sum of exp(k_s) (plus
  `epsilon`).  This acts as a causal softmax over the sequence for each key
  feature.  The queries then read out of that normalized prefix context.

  NOTE: keys are exponentiated without max-subtraction, so very large key
  values can overflow.  The intermediate context has shape
  [batch, seq, heads, features_in, features_out].

  Args:
    q: a Tensor with a "features_per_head" dimension.
    k: a Tensor with a "features_per_head" dimension (expected last).
    v: a Tensor whose first four dimensions are used as
      [batch, sequence, heads, output features].
    epsilon: small constant added to the prefix-sum denominator.

  Returns:
    a Tensor with shape [batch, sequence, heads, output features].
  """
  batch_dim, seq_dim, head_dim, dim_out = [v.shape[i] for i in range(4)]
  q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in")
  k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in")
  dim_in = k.shape[-1]

  q = mtf.softmax(q, dim_in)
  exp_k = mtf.exp(k)
  prefix_k = mtf.cumsum(exp_k, seq_dim)

  kv = mtf.einsum(
      [exp_k, v],
      output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out])
  normalized_prefix_kv = mtf.cumsum(kv, seq_dim) / (prefix_k + epsilon)

  return mtf.einsum(
      [q, normalized_prefix_kv],
      output_shape=[batch_dim, seq_dim, head_dim, dim_out])
def call(self, context, x: mtf.Tensor) -> mtf.Tensor:
  """Call the layer: a product-key style memory lookup.

  Creates a key table [heads, 2, n_keys, key_size // 2] and a value embedding
  table [n_values, d].  It projects x to one query per (head, key half),
  layer-normalizes the queries, and lets `self.get_indices` pick scores and
  value indices.  The gathered values are then weighted by the softmaxed
  scores.

  Args:
    context: the layer Context (reads `mesh` and `variable_dtype`).
    x: input Tensor; its last dimension is both the query input dimension and
      the value output dimension.

  Returns:
    a Tensor shaped like x (annotated as [b, l, v]).
  """
  n_key_dim = mtf.Dimension("n_keys", self.n_keys)
  n_value_dim = mtf.Dimension("n_values", self.n_values)
  key_dim = mtf.Dimension("key", self.key_size // 2)
  value_dim = x.shape.dims[-1]
  head_dim = mtf.Dimension("n_heads", self.n_heads)
  product_dim = mtf.Dimension("product_key", 2)

  # Memory keys and values.
  keys = mtf.get_variable(
      context.mesh,
      name="keys",
      shape=mtf.Shape([head_dim, product_dim, n_key_dim, key_dim]),
      dtype=context.variable_dtype)
  values = mtf.layers.embedding_weights(
      context.mesh,
      vocab_dim=n_value_dim,
      output_dim=value_dim,
      variable_dtype=context.variable_dtype,
      name="values")

  # Queries: [b, l, h, 2, k]
  query = mtf.layers.dense(
      x, [head_dim, product_dim, key_dim],
      reduced_dims=x.shape.dims[-1:],
      activation=None,
      use_bias=True,
      variable_dtype=context.variable_dtype,
      name="query")
  # Note: We use layer norm instead of batch norm to normalize queries.
  # The main advantage is that layer norm works well with the codebase
  # whereas the implementation of batch norm requires handling of tf ops.
  query = mtf.layers.layer_norm(query, query.shape.dims[-1])

  scores, indices = self.get_indices(keys, query)  # [b, l, h, k]
  score_dim = scores.shape.dims[-1]
  scores = mtf.softmax(scores, reduced_dim=score_dim)
  top_values = mtf.gather(values, indices, n_value_dim)  # [b, l, h, k, v]
  # Sum over heads and the selected slots -> [b, l, v]
  return mtf.einsum([top_values, scores], reduced_dims=scores.shape.dims[-2:])
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
  """Multi-head self-attention over x.

  Args:
    x: a Tensor whose shape unpacks into exactly (batch, seq, dim).
    dim_head: Dimension giving the number of heads.
    dim_features_head: Dimension giving the features per head.
    scope: name of the tf variable scope for the projections.
    causal: if True, query position i may only attend to memory positions
      j <= i + (memory_length - seq).

  Returns:
    the output of the final `linear` projection to `dim` (presumably shaped
    [batch, seq, dim]; `linear` is defined elsewhere).
  """
  with tf.variable_scope(scope):
    mesh, batch, seq, dim = x.mesh, *x.shape
    dim_heads = mtf.Dimension('dim_heads',
                              dim_head.size * dim_features_head.size)
    dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)

    qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')
    q, k, v = mtf.split(qkv, dim_intermediate, 3)
    q, k, v = [
        mtf.reshape(t, [batch, seq, dim_head, dim_features_head])
        for t in (q, k, v)
    ]
    q, k, v = [
        mtf.transpose(t, [batch, dim_head, seq, dim_features_head])
        for t in (q, k, v)
    ]
    # Keys/values get a distinct length dimension so q.k keeps both axes.
    k, v = [
        mtf.rename_dimension(t, seq.name, 'memory_length') for t in (k, v)
    ]
    mem_len_dim = v.shape[-2]

    # NOTE: logits are not scaled by 1/sqrt(features per head).
    dots = mtf.layers.us_einsum([q, k], [batch, dim_head, seq, mem_len_dim])

    if causal:
      i = mtf.range(mesh, seq, tf.int32)
      j = mtf.range(mesh, mem_len_dim, tf.int32)
      i, j = [mtf.broadcast(t, [seq, mem_len_dim]) for t in (i, j)]
      mask = mtf.less(i + mem_len_dim.size - seq.size, j)
      # Build the additive mask in the logits' own dtype.  mtf binary ops
      # require equal dtypes, so a hard-coded float32 mask fails when the
      # activations are e.g. bfloat16.
      dots += mtf.cast(mask, dots.dtype) * -1e10

    attn = mtf.softmax(dots, mem_len_dim)
    out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])
    out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
    out = mtf.reshape(out, [batch, seq, dim_heads])
    combined_out = linear(out, dim, scope='combine_output')
  return combined_out
def hybrid_attention(q,
                     k,
                     v,
                     context,
                     memory_length_dim,
                     key_dim,
                     value_dim,
                     bias=None,
                     dropout_rate=0.0,
                     dropout_broadcast_dims=None,
                     extra_logit=None):
  """Dot-product attention blending softmax and doubly-normalized weights.

  The attention weights are a learned convex combination of:
    * ordinary softmax attention over `memory_length_dim`, and
    * doubly-normalized attention: logits are log-softmax-normalized over the
      query axis (named "length") and then softmax-normalized over
      `memory_length_dim`.
  The mixing coefficient is a scalar variable "doubly_coeff", initialized to
  0.5 and clipped to [0, 1].  Positional dimensions are not used.

  Dimensions:
    q: other_query_dims + {key_dim}
    k: other_memory_dims + {memory_length_dim, key_dim}
    v: other_memory_dims + {memory_length_dim, value_dim}
  where other_memory_dims is a subset of other_query_dims; typically
  other_query_dims={batch, heads, length} and other_memory_dims={batch, heads}.
  The query axis must be named "length" and have the same size as
  `memory_length_dim`.

  Args:
    q: a Tensor of queries.
    k: a Tensor of keys.
    v: a Tensor of values.
    context: context of the attention layer (reads `mesh`, `variable_dtype`,
      `train`).
    memory_length_dim: Dimension indexing the key/value pairs.
    key_dim: Dimension of the query/key channels (contracted).
    value_dim: Dimension of the value channels.
    bias: optional Tensor added to the attention logits.
    dropout_rate: float dropout rate (applied when context.train).
    dropout_broadcast_dims: optional list of mtf.Dimension along which the
      dropout mask is shared.
    extra_logit: optional scalar or Tensor forwarded to the normalizations.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += bias

  query_length_dim = mtf.Dimension("length", memory_length_dim.size)

  coeff = mtf.get_variable(
      context.mesh,
      "doubly_coeff", [],
      initializer=tf.constant_initializer(0.5),
      dtype=context.variable_dtype)
  coeff = mtf.maximum(mtf.minimum(coeff, 1.), 0.)

  # Standard attention: normalize over memory positions only.
  row_weights = mtf.softmax(
      logits, memory_length_dim, extra_logit=extra_logit)
  # Doubly-normalized attention: normalize over queries (in log space),
  # then over memory positions.
  column_log_weights = mtf.log_softmax(
      logits, query_length_dim, extra_logit=extra_logit)
  doubly_weights = mtf.softmax(
      column_log_weights, memory_length_dim, extra_logit=extra_logit)

  mixed = coeff * doubly_weights + (1. - coeff) * row_weights
  mixed = mtf.dropout(
      mixed,
      context.train,
      1.0 - dropout_rate,
      noise_shape=mixed.shape - dropout_broadcast_dims)
  return mtf.einsum([mixed, v], q.shape - key_dim + value_dim)
def synthetic_attention(q,
                        k,
                        v,
                        memory_length_dim,
                        key_dim,
                        value_dim,
                        bias=None,
                        dropout_rate=0.0,
                        dropout_broadcast_dims=None,
                        extra_logit=None,
                        synthesize=True,
                        synthesize_mode="random_plus_alpha",
                        factorized_dim=16,
                        max_length=512,
                        context=None):
  """Synthetic Attention from Synthesizers (https://arxiv.org/abs/2005.00743).

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Supported `synthesize_mode` values (when `synthesize` is True):
    "random": a learned [length, heads, memory_length] logit table R.
    "factorized": R = R1^T R2, contracted over a `factorized_dim` axis.
    "dense_minus": logits predicted from relu(q) by a dense layer "pi".
    "random_plus", "random_plus_alpha": dot-product logits plus a random
      table; with "alpha", a learned sigmoid gate mixes the two.
    "dense_plus", "dense_plus_alpha": dot-product logits plus dense
      synthesized logits, mixed the same way.
  The "random" and "dense_minus" tables are sized by `max_length`; the other
  modes use a fixed size of 512.

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    synthesize: flag to use synthetic attention; if False, plain dot-product
      attention is computed.
    synthesize_mode: which variant of synthesizer to use (see above).
    factorized_dim: size of the factor axis for the "factorized" mode.
    max_length: max length of input sequence ("random"/"dense_minus" modes).
    context: the layer Context; required (reads mesh, mode, position, train,
      variable_dtype).

  Returns:
    Tensor with shape q.shape - key_dim + value_dim for dot-product and
    "plus" modes.  For "dense_minus" the shape is q's leading dims + value_dim;
    for "random"/"factorized" it is q's leading dims + [heads, value_dim].

  Raises:
    ValueError: if `synthesize` is True and `synthesize_mode` is unsupported.
  """

  def _crop_to_current_lengths(table, gather_position):
    """Slices a synthesized logit table to the current memory/query lengths."""
    table = mtf.slice(table, 0, memory_length_dim.size, memory_length_dim.name)
    if context.mode == "incremental":
      if gather_position:
        # Keep only the query row(s) at the current decode position.
        table = mtf.gather(table, context.position,
                           table.shape.get_dim_by_name("length"))
      return table
    length_dim = q.shape.get_dim_by_name("length")
    return mtf.slice(table, 0, length_dim.size, "length")

  def _mix_with_dot_product(dot_logits, synthesized):
    """Adds synthesized logits, gated by a learned alpha for *_alpha modes."""
    if "alpha" not in synthesize_mode:
      return dot_logits + synthesized
    alpha = mtf.get_variable(context.mesh, "alpha",
                             mtf.Shape([mtf.Dimension("alpha", 1)]),
                             initializer=tf.zeros_initializer(),
                             dtype=context.variable_dtype)
    alpha = mtf.sigmoid(alpha)
    return ((1 - alpha) * dot_logits) + (alpha * synthesized)

  if not synthesize:
    # Plain dot-product attention.  Without this branch `logits` would be
    # unbound below.
    logits = mtf.einsum([q, k], reduced_dims=[key_dim])
  else:
    num_heads = v.shape.get_dim_by_name("heads")
    tf.logging.info("Using synthesizer")
    if synthesize_mode == "random":
      tf.logging.info("Using Random Synthesizers")
      r_shape = mtf.Shape([mtf.Dimension("length", max_length),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", max_length)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      logits = _crop_to_current_lengths(r, gather_position=True)
    elif synthesize_mode == "factorized":
      tf.logging.info("Using Factorized Random Synthesizers")
      factor_dim = mtf.Dimension("tmp", factorized_dim)
      heads_dim = mtf.Dimension("heads", num_heads.size)
      # R = R1^T R2, contracting over factor_dim.  R1 is indexed by the query
      # position ("length") and R2 by the memory position, so R varies with
      # both.  If both factors were indexed by "memory_length", the "length"
      # axis of R would not appear in either einsum input.  R1 keeps its
      # [factorized_dim, heads, 512] size.
      r1_shape = mtf.Shape([factor_dim, heads_dim,
                            mtf.Dimension("length", 512)])
      r2_shape = mtf.Shape([factor_dim, heads_dim,
                            mtf.Dimension("memory_length", 512)])
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           heads_dim,
                           mtf.Dimension("memory_length", 512)])
      r1 = mtf.get_variable(context.mesh, "R1", r1_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r2 = mtf.get_variable(context.mesh, "R2", r2_shape,
                            initializer=None,
                            dtype=context.variable_dtype)
      r = mtf.einsum([r1, r2], r_shape)
      logits = _crop_to_current_lengths(r, gather_position=True)
    elif synthesize_mode == "dense_minus":
      # Dense Synthesizer Model: logits predicted from the queries alone.
      tmp_dim = mtf.Dimension("memory_length", max_length)
      logits = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                                use_bias=False,
                                name="pi",
                                reduced_dims=[key_dim],
                                variable_dtype=None)
      logits = _crop_to_current_lengths(logits, gather_position=False)
    elif synthesize_mode in ("random_plus_alpha", "random_plus"):
      # Mixture Random Synthesizer with learnable Alpha
      tf.logging.info("Using Random Plus Alpha")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      num_heads = logits.shape.get_dim_by_name("heads")
      r_shape = mtf.Shape([mtf.Dimension("length", 512),
                           mtf.Dimension("heads", num_heads.size),
                           mtf.Dimension("memory_length", 512)])
      r = mtf.get_variable(context.mesh, "R", r_shape,
                           initializer=None,
                           dtype=context.variable_dtype)
      r = _crop_to_current_lengths(r, gather_position=True)
      logits = _mix_with_dot_product(logits, r)
    elif synthesize_mode in ("dense_plus_alpha", "dense_plus"):
      # Mixture Dense Synthesizer with learnable alpha
      tf.logging.info("Using Dense Plus Alpha Scaling")
      logits = mtf.einsum([q, k], reduced_dims=[key_dim])
      tmp_dim = mtf.Dimension("memory_length", 512)
      r = mtf.layers.dense(mtf.relu(q), [tmp_dim],
                           use_bias=False,
                           name="pi",
                           reduced_dims=[key_dim],
                           variable_dtype=None)
      r = _crop_to_current_lengths(r, gather_position=False)
      logits = _mix_with_dot_product(logits, r)
    else:
      raise ValueError("Unknown synthesize_mode: %s" % synthesize_mode)

  if bias is not None:
    logits += bias
  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  weights = mtf.dropout(
      weights, context.train,
      1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)

  if synthesize and "plus" not in synthesize_mode:
    if synthesize_mode == "dense_minus":
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [value_dim])
    else:
      outputs_shape = mtf.Shape(q.shape.dims[:-1] + [num_heads, value_dim])
  else:
    outputs_shape = q.shape - [key_dim] + value_dim

  return mtf.einsum([weights, v], outputs_shape)
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
  """Builds the spatially-partitioned UNet graph, its loss and eval metrics.

  The image is split into FLAGS.image_nx_block x FLAGS.image_ny_block spatial
  blocks.  A UNet of depth FLAGS.network_depth is built with block-aware
  convolutions.  The function computes a weighted cross-entropy + Dice loss
  and lesion/liver metrics.

  Args:
    mesh: a MeshTensorflow.mesh object.
    mesh_impl: a mesh implementation, such as SimdMeshImpl and
      PlacementMeshImpl.
    dataset_str: either 'train' or 'eval'.  It selects the batch size
      (FLAGS.batch_size_train / FLAGS.batch_size_eval) and sets the
      `is_training` flag passed to the conv blocks (used for batch_norm).
    images: a laid out Tensor with shape [batch, x, y, num_channels]
      or [batch, x, y, z, num_channels].
    labels: a laid out Tensor with shape [batch, x, y, num_classes]
      or [batch, x, y, z, num_classes].

  Returns:
    A tuple (preds, loss, eval_metrics, all_bn_update_ops):
      preds: list of the down-sampled liver and lesion probability maps,
        the down-sampled lesion ground truth (only if
        FLAGS.output_ground_truth), then the per-case lesion `intersect`
        and `area_sum` Tensors.
      loss: scalar; FLAGS.dice_loss_weight * Dice loss +
        (1 - FLAGS.dice_loss_weight) * weighted cross-entropy.
      eval_metrics: dict of scalar metric Tensors.
      all_bn_update_ops: the `bn_update_ops` collected from every conv block.
  """
  is_training = (dataset_str == 'train')
  if dataset_str == 'train':
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
  else:
    # NOTE(review): assert-based validation is stripped under `python -O`.
    assert dataset_str == 'eval'
    batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
  # x and y are split into n*_block blocks of ct_resolution // n*_block
  # voxels each; z is not split.
  image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
  image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
  image_sx_dim = mtf.Dimension('image_sx_block',
                               FLAGS.ct_resolution // FLAGS.image_nx_block)
  image_sy_dim = mtf.Dimension('image_sy_block',
                               FLAGS.ct_resolution // FLAGS.image_ny_block)
  image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
  image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
  label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
  mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

  # Activations, variables and master weights all use FLAGS.mtf_dtype.
  mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
  variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

  # Import input features.
  x = mtf.import_laid_out_tensor(mesh,
                                 mesh_impl.LaidOutTensor(images),
                                 mtf_images_shape)
  x = mtf.cast(x, mtf_dtype)

  # Import ground truth labels.
  t = mtf.import_laid_out_tensor(mesh,
                                 mesh_impl.LaidOutTensor(labels),
                                 mtf_labels_shape)
  t = mtf.cast(t, mtf_dtype)

  # Transpose the blocks.
  if FLAGS.sampled_2d_slices:
    x = mtf.transpose(x, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_c_dim
    ])

    t = mtf.transpose(t, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        label_c_dim
    ])
  else:
    x = mtf.transpose(x, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_sz_dim, image_c_dim
    ])

    t = mtf.transpose(t, [
        batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
        image_sz_dim, label_c_dim
    ])

  # Network.
  # levels[depth] holds the encoder output at each depth (skip connections).
  levels = []
  all_bn_update_ops = []
  # add levels with convolution or down-sampling
  for depth in range(FLAGS.network_depth):
    for n_conv in range(FLAGS.n_conv_per_block):
      if depth == 0 and n_conv == 0:
        # no dropout in 1st layer.
        dropout_keep_p = 1.0
      else:
        dropout_keep_p = FLAGS.dropout_keep_p
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_down_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)
    levels.append(x)

    # Down-sample between levels (not after the deepest one).
    if depth < FLAGS.network_depth - 1:
      if FLAGS.sampled_2d_slices:
        x = mtf.layers.max_pool2d(x, ksize=(2, 2))
      else:
        x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

  # add levels with up-convolution or up-sampling
  for depth in range(FLAGS.network_depth - 1)[::-1]:
    x = deconv_with_spatial_partition(
        x, FLAGS.sampled_2d_slices,
        image_nx_dim, image_ny_dim,
        FLAGS.n_base_filters * (2**depth),
        FLAGS.dropout_keep_p,
        'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
        variable_dtype,
        'deconv_{}_0'.format(depth))
    # Skip connection: concatenate with the encoder output at this depth.
    x = mtf.concat(
        [x, levels[depth]],
        concat_dim_name='conv_{}_{}'.format(depth,
                                            FLAGS.n_conv_per_block - 1))

    for n_conv in range(FLAGS.n_conv_per_block):
      x, bn_update_ops = conv_with_spatial_partition(
          x, FLAGS.sampled_2d_slices,
          image_nx_dim, image_ny_dim,
          FLAGS.n_base_filters * (2**depth),
          FLAGS.dropout_keep_p,
          FLAGS.with_batch_norm,
          is_training,
          'conv_{}_{}'.format(depth, n_conv),
          variable_dtype,
          'conv_up_{}_{}'.format(depth, n_conv))
      all_bn_update_ops.extend(bn_update_ops)

  # no dropout in the final layer.
  # A 1x1(x1) conv maps features to FLAGS.label_c class logits.
  if FLAGS.sampled_2d_slices:
    y = mtf.layers.conv2d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1), strides=(1, 1), padding='SAME',
        h_blocks_dim=image_nx_dim, w_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )
  else:
    y = mtf.layers.conv3d_with_blocks(
        x, mtf.Dimension('label_c', FLAGS.label_c),
        filter_size=(1, 1, 1), strides=(1, 1, 1), padding='SAME',
        d_blocks_dim=image_nx_dim, h_blocks_dim=image_ny_dim,
        variable_dtype=variable_dtype,
        name='final_conv_{}'.format(FLAGS.label_c),
    )

  # use mtf.constant to make sure there is no CPU-side constants.
  def scalar(v, dtype):
    return mtf.constant(mesh, v, shape=[], dtype=dtype)

  # Class index 1 is treated as liver and 2 as lesion.
  argmax_t = mtf.argmax(t, label_c_dim)
  liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
  lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

  argmax_y = mtf.argmax(y, label_c_dim)
  lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

  # summary of class ratios.
  lesion_pred_ratio = mtf.reduce_mean(lesion_y)
  lesion_label_ratio = mtf.reduce_mean(lesion_t)

  # summary of accuracy.
  accuracy = mtf.reduce_mean(
      mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

  # Cross-entropy loss. Up-weight the liver region.
  pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
      y, t, label_c_dim)
  # NOTE(review): liver_t and lesion_t come from the same argmax, so they are
  # mutually exclusive.  A lesion pixel therefore gets weight
  # 1 + xen_lesion_weight - xen_liver_weight, not xen_lesion_weight; confirm
  # this is intended.
  pixel_weight = scalar(1, mtf_dtype) + \
      liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
      lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                        mtf_dtype)
  loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

  # Dice loss
  # Soft Dice on the lesion-class probability, negated so lower is better.
  # It is averaged from a per-case term and a global (batch-pooled) term.
  y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
  lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                               reduced_dim=mtf.Dimension('label_c', 1))
  prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                  output_shape=mtf.Shape([batch_dim]))
  prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                 output_shape=mtf.Shape([batch_dim]))
  loss_dice_per_case = mtf.reduce_mean(
      scalar(-2, mtf_dtype) * prob_intersect / (
          prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
  loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
      prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                         scalar(FLAGS.dice_epsilon, mtf_dtype))
  loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
      0.5, mtf_dtype)

  loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
      1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

  # Hard (argmax-based) lesion Dice statistics, per case.
  intersect = mtf.reduce_sum(lesion_y * lesion_t,
                             output_shape=mtf.Shape([batch_dim]))
  area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                            output_shape=mtf.Shape([batch_dim]))
  # summary of dice.
  dice_per_case = mtf.reduce_mean(
      scalar(2, mtf_dtype) * intersect / (
          area_sum + scalar(0.000001, mtf_dtype)))
  dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
      mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

  eval_metrics = {
      'lesion_pred_ratio': lesion_pred_ratio,
      'lesion_label_ratio': lesion_label_ratio,
      'accuracy_of_all_classes': accuracy,
      'lesion_dice_per_case': dice_per_case,
      'lesion_dice_global': dice_global,
      'loss_xen': loss_xen,
      'loss_dice': loss_dice,
      'loss_dice_per_case': loss_dice_per_case,
      'loss_dice_global': loss_dice_global,
  }

  # Average-pool the class probabilities (and optionally the lesion ground
  # truth) by FLAGS.pred_downsample along each spatial axis for output.
  if FLAGS.sampled_2d_slices:
    y_prob_downsampled = mtf.layers.avg_pool2d(
        y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool2d(
          mtf.slice(t, 2, 1, 'label_c'),
          ksize=(FLAGS.pred_downsample, ) * 2)
  else:
    y_prob_downsampled = mtf.layers.avg_pool3d(
        y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
    if FLAGS.output_ground_truth:
      lesion_gt_downsampled = mtf.layers.avg_pool3d(
          mtf.slice(t, 2, 1, 'label_c'),
          ksize=(FLAGS.pred_downsample, ) * 3)

  liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
  lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
  # reduce_sum over the size-1 'label_c' slice just drops that dimension.
  preds = [
      mtf.reduce_sum(liver_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1)),
      mtf.reduce_sum(lesion_prob_downsampled,
                     reduced_dim=mtf.Dimension('label_c', 1))
  ]

  if FLAGS.output_ground_truth:
    preds.append(
        mtf.reduce_sum(lesion_gt_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)))

  preds.extend([intersect, area_sum])

  return preds, loss, eval_metrics, all_bn_update_ops
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-k gating with no-token-left behind behavior.

  Each token's top `hparams.moe_switch_top_k` experts are selected from a
  softmax over `experts_dim`, and a load-balancing loss is computed.  Tokens
  are then routed iteratively: every token's top-1 choice is tried first
  (subject to expert capacity), then the top-2 choice of tokens still
  unrouted, and so on.

  Args:
    inputs: a Tensor; dims 0, 1 and -2 are used as the outer-batch, batch and
      group-size dimensions.
    outer_expert_dims: passed as `expert_dims` to the gating dense layer.
    experts_dim: Dimension over the experts.
    expert_capacity_dim: Dimension whose size is the per-expert capacity.
    hparams: hyperparameters; reads moe_rand_1_policy_train,
      moe_rand_1_policy_eval, moe_rand_1_jitter and moe_switch_top_k.
    train: whether in training mode (selects the policy, enables the
      "input_jitter" policy and the summaries).
    variable_dtype: dtype for the gating dense layer.
    importance: optional Tensor.  Tokens whose importance is not exactly 1.0
      are masked out of expert_mask, expert_gate and the load-balancing
      proxy.
    name: name of the gating dense layer; also the entropy summary prefix.
    num_microbatches: if > 1, the load-balancing loss is divided by it.

  Returns:
    dispatch_tensor: combine_tensor binarized (1 where nonzero, else 0).
    combine_tensor: gate-weighted one-hot assignment of tokens to
      (expert, capacity slot), cast to inputs.dtype.
    loss: the load-balancing loss, cast to inputs.dtype.
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  # density_1: fraction of tokens whose top-k choices include each expert.
  # density_1_proxy: mean router probability per expert (differentiable).
  # density_1 is computed before the importance masking below.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before then trying top-(i+1).
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors cumulative values over the iterative process.
  # NOTE(review): combine_tensor starts without group_size_dim; the `+=` in
  # the loop presumably relies on mtf broadcasting to extend its shape.
  # Confirm.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  # Tokens already assigned to each expert by earlier iterations.
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  # 1.0 for each token not yet routed, 0.0 once it has been placed.
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    # Drop the size-1 "k" dimension left by the split.
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    # Despite its name, this is 1.0 where the token still FITS in the
    # expert's capacity (count <= capacity) and 0.0 where it overflows.
    expert_overflow = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * expert_overflow

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    # Zero-based slot of the token inside its expert's capacity buffer.
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Binarize: any nonzero combine weight -> 1, else 0.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def gradient_based_subword_tokenization(x,
                                        length_dim,
                                        max_subword_length=4,
                                        downsample=None,
                                        use_offsets=False,
                                        consider_chars_as_blocks=False,
                                        use_block_pos_embedding=False,
                                        share_block_kernel=False,
                                        memory_embeddings=0,
                                        context=None,
                                        block_mixing_mode=None,
                                        activation="softmax",
                                        downsample_function="mean"):
  """Implements GBSWT from Charformer.

  Candidate subword blocks of width 2, 3 and 4 (plus 1 when
  `consider_chars_as_blocks`) are formed by mean-pooling `x` along
  `length_dim`. Each block is scored by a dense projection, and the block
  representations are broadcast back to every position they cover. The
  scores are then normalized across candidates ("softmax" or "tanh"), and
  each position receives the score-weighted sum of its candidate block
  representations. The result is optionally mean-pooled by `downsample`.

  Args:
    x: a Tensor containing length_dim; its last dimension is used as the
      model (feature) dimension.
    length_dim: a Dimension
    max_subword_length: integer. Currently unused.
    downsample: optional integer. If set, the output is right-padded to a
      multiple of `downsample` and mean-pooled along the length dimension.
    use_offsets: boolean. If True, every shifted segmentation
      (offsets 0 .. width-1) is considered for each block width.
    consider_chars_as_blocks: boolean. If True, width-1 blocks are added.
    use_block_pos_embedding: boolean. If True, a sinusoidal within-block
      position embedding is added before pooling.
    share_block_kernel: boolean. If True, one scoring kernel is shared by
      all block widths.
    memory_embeddings: integer. Currently unused.
    context: Context. Must be provided when share_block_kernel or
      use_block_pos_embedding is set (its mesh / variable_dtype are read).
    block_mixing_mode: Str for block mixing; "score_attention" or None.
    activation: Str for block ranking; "softmax" or "tanh". Other values
      apply no normalization to the scores.
    downsample_function: Str; only "mean" is currently implemented.

  Returns:
    a Tensor with the same shape as x. When `downsample` is set, the length
    dimension is instead reduced to ceil(length / downsample).

  Raises:
    ValueError: if `downsample` is set and `downsample_function` is not
      "mean".
  """
  # don't use this for now.
  del max_subword_length
  del memory_embeddings
  all_blocks = []
  all_scores = []
  tf.logging.info("GSW block layer")

  def _tile(x, n, tile_dim):
    # Concatenate n copies of x along tile_dim: [a, b] -> [a, b, a, b].
    return mtf.concat([x] * n, tile_dim.name)

  def _repeat(x, n, repeat_dim):
    # Repeat every entry of x n times along repeat_dim: [a, b] -> [a, a, b, b].
    # mtf.reshape is row-major in the order of the shape's dims, so the size-n
    # copy axis must sit immediately after repeat_dim. Appending it after the
    # trailing (e.g. feature) dims would interleave copies with features.
    tmp_dim = mtf.Dimension("tmp", 1)
    dims = x.shape.dims
    axis = [d.name for d in dims].index(repeat_dim.name)
    x = mtf.reshape(
        x, mtf.Shape(dims[:axis + 1] + [tmp_dim] + dims[axis + 1:]))
    x = _tile(x, n, tmp_dim)
    output_shape = []
    for dim in x.shape.dims:
      if dim.name == "tmp":
        continue
      if dim.name == repeat_dim.name:
        dim = mtf.Dimension(dim.name, dim.size * n)
      output_shape.append(dim)
    return mtf.reshape(x, mtf.Shape(output_shape))

  def _combined_dim(dims):
    # Single dim (named after dims[0]) whose size is the product of dims.
    return mtf.Dimension(dims[0].name, mtf.Shape(dims).size)

  # compute all subword blocks
  # TODO(yitay): handle offsets to get all blocks
  if activation == "sigtanh":
    # one score for sigmoid
    tmp_dim = mtf.Dimension("block_score", 2)
  else:
    tmp_dim = mtf.Dimension("block_score", 1)

  model_dim = x.shape[-1]
  subword_blocks_width = [2, 3, 4]

  if consider_chars_as_blocks:
    subword_blocks_width += [1]

  if share_block_kernel:
    block_kernel_shape = mtf.Shape([model_dim, tmp_dim])
    block_kernel = mtf.get_variable(x.mesh, "block_kernel", block_kernel_shape,
                                    initializer=None,
                                    dtype=context.variable_dtype)
  else:
    block_kernel = None

  # Every candidate (width, offset) gets one slot along this dim; candidates
  # are concatenated along it after the loops.
  new_tmp_dim = mtf.Dimension("extra_dim", 1)

  for subword_len in subword_blocks_width:
    if use_block_pos_embedding:
      # this is turned off by default. It is meant to support cases like
      # parameterized pooling or other features.
      block_len_dim = mtf.Dimension(length_dim.name, subword_len)
      # TODO(vqtran): Consider other positional embeddings.
      base_block_pos_emb = sinusoid_positional_embedding_weights(
          context.mesh, block_len_dim, x.shape[-1],
          context.variable_dtype.activation_dtype)
      # Tile (not repeat) so that the token at index i inside every block of
      # `subword_len` tokens receives embedding i.
      base_block_pos_emb = _tile(
          base_block_pos_emb, math.ceil(length_dim.size / float(subword_len)),
          block_len_dim)
    if use_offsets:
      offset_space = subword_len
    else:
      offset_space = 1
    for offsets in range(offset_space):
      if use_block_pos_embedding:
        block_pos_emb = base_block_pos_emb
      if offsets > 0:
        xoff = mtf.shift(x, offsets, length_dim, wrap=False)
        if use_block_pos_embedding:
          # Shift by the same amount as x, starting from the unshifted
          # embedding (shifting cumulatively would drift by 1+2+...+offsets).
          block_pos_emb = mtf.shift(base_block_pos_emb, offsets,
                                    base_block_pos_emb.shape[-2], wrap=False)
      else:
        xoff = x
      tf.logging.info("SW len=%d offset=%d", subword_len, offsets)
      if length_dim.size % subword_len != 0:
        tf.logging.info("Not divisible by length")
        # add extra padding tokens
        pad_amt = int(subword_len) - int(length_dim.size % subword_len)
        kp = mtf.pad(xoff, [0, pad_amt], length_dim.name)
      else:
        kp = xoff

      if use_block_pos_embedding:
        kp += block_pos_emb

      bx = mtf.pool_tensor_1d(
          kp,
          pool_dim=kp.shape.get_dim_by_name(length_dim.name),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(subword_len))
      block_score = mtf.layers.dense(
          bx, [tmp_dim],
          use_bias=False,
          name="bx",
          reduced_dims=[model_dim],
          variable_dtype=None,
          kernel_weights=block_kernel)

      # Broadcast each block's representation / score back to its tokens.
      expand_bx = _repeat(bx, subword_len, length_dim)
      expand_scores = _repeat(block_score, subword_len, length_dim)
      if offsets > 0:
        # add offset.
        expand_bx = mtf.pad(expand_bx, [offsets, 0], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [offsets, 0], length_dim.name)
      # Pad or crop back to exactly length_dim.size so that all candidates
      # share one shape before concatenation.
      new_len = expand_bx.shape.get_dim_by_name(length_dim.name)
      if new_len.size < length_dim.size:
        pad_amt = length_dim.size - new_len.size
        expand_bx = mtf.pad(expand_bx, [0, pad_amt], length_dim.name)
        expand_scores = mtf.pad(expand_scores, [0, pad_amt], length_dim.name)
      elif new_len.size > length_dim.size:
        expand_bx = mtf.slice(expand_bx, 0, length_dim.size, length_dim.name)
        expand_scores = mtf.slice(expand_scores, 0, length_dim.size,
                                  length_dim.name)

      expand_shape = mtf.Shape(expand_bx.shape.dims + [new_tmp_dim])
      expand_scores_shape = mtf.Shape(expand_scores.shape.dims + [new_tmp_dim])
      expand_bx = mtf.reshape(expand_bx, expand_shape)
      expand_scores = mtf.reshape(expand_scores, expand_scores_shape)
      all_blocks.append(expand_bx)
      all_scores.append(expand_scores)

  all_blocks = mtf.concat(all_blocks, new_tmp_dim.name)
  all_scores = mtf.concat(all_scores, new_tmp_dim.name)
  tf.logging.info(all_blocks)

  new_tmp_dim = all_blocks.shape.get_dim_by_name(new_tmp_dim.name)
  combined_dim = _combined_dim([new_tmp_dim, tmp_dim])
  block_net_shape = all_scores.shape - tmp_dim - new_tmp_dim + combined_dim
  block_net = mtf.reshape(all_scores, block_net_shape)

  if block_mixing_mode == "score_attention":
    tf.logging.info("Using score attention")
    att = mtf.einsum([block_net, block_net], reduced_dims=[new_tmp_dim])
    tf.logging.info(block_net)
    att = mtf.softmax(att, reduced_dim=att.shape[-1])
    block_net = mtf.einsum([att, block_net], output_shape=block_net.shape)
    tf.logging.info(block_net)

  if activation == "softmax":
    block_net = mtf.softmax(block_net, reduced_dim=new_tmp_dim)
  elif activation == "tanh":
    tf.logging.info("Using tanh")
    block_net = mtf.tanh(block_net)

  # Score-weighted sum over all candidate blocks.
  all_blocks = block_net * all_blocks
  all_blocks = mtf.reduce_sum(all_blocks, reduced_dim=new_tmp_dim)
  output = all_blocks

  if downsample:
    output_length = output.shape.get_dim_by_name(length_dim.name)
    if output_length.size % int(downsample) != 0:
      pad_amt = int(downsample) - int(output_length.size % int(downsample))
      output = mtf.pad(output, [0, pad_amt], output_length.name)
    if downsample_function == "mean":
      output = mtf.pool_tensor_1d(
          output,
          pool_dim=output.shape.get_dim_by_name(length_dim.name),
          reduce_fn=mtf.reduce_mean,
          pool_size=int(downsample))
    else:
      raise ValueError("Downsampling function not implemeneted.")

  return output
'triangle_relax': lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin( 7 * x) / 49, 'square_relax': lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos( 7 * x) / 7, 'spike': lambda x: 1 / (1 + x**2), 'spike2': lambda x: mtf.exp(-x**2), 'tanhshrink': lambda x: x - tanh(x), 'softsign': lambda x: x / (mtf.abs(x) + 1), 'softmax': lambda x: mtf.softmax(x, x.shape[-1]), 'logsoftmax': lambda x: mtf.log_softmax(x, x.shape[-1]), 'bipolarsigmoid': lambda x: mtf.sigmoid(x) * 2 - 1, 'rrelu': _rrelu, 'elish': _elish, 'arcsinh': _arcsinh, 'aria': lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x)** (1 / _pos_var(x, 1)))), 'prelu':
def self_attention(self, x, attention_bias):
  """Performs multi-headed self-attention with output projection.

  Args:
    x: output of previous layer
    attention_bias: optional float32 Tensor broadcastable to shape
      x.shape - self.model_dim + self.memory_seq_dim
      to be added to attention logits.
      This may be used to mask out padding regions of the memory.

  Returns:
    float Tensor with the same shape as x
  """
  queries = mtf.layers.dense(
      x,
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="query",
      use_bias=self.config.use_bias)
  # Keys and values read x with its sequence dim renamed to the memory dim,
  # so the query and memory lengths are distinct dimensions in the einsums.
  memory_antecedent = mtf.replace_dimensions(
      x, self.seq_dim, self.memory_seq_dim)
  keys = mtf.layers.dense(
      memory_antecedent,
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="key",
      use_bias=self.config.use_bias)
  values = mtf.layers.dense(
      memory_antecedent,
      reduced_dims=[self.model_dim],
      new_dims=[self.num_heads_dim, self.size_per_head_dim],
      kernel_initializer=self.dense_initializer,
      name="value",
      use_bias=self.config.use_bias)

  # Take the dot product between "query" and "key" to get the raw
  # attention scores, scaled by 1/sqrt(size_per_head).
  attention_scores = mtf.einsum(
      [queries, keys], reduced_dims=[self.size_per_head_dim])
  attention_scores *= self.size_per_head_dim.size ** -0.5

  if attention_bias is not None:
    attention_scores += attention_bias

  # Normalize the attention scores to probabilities.
  attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim)

  # This is actually dropping out entire tokens to attend to, which might
  # seem a bit unusual, but is taken from the original Transformer paper.
  # Dropout is enabled whenever a non-zero probability is configured (the
  # previous `== 0.0` test made dropout a no-op in every case).
  # NOTE(review): assumes attention_probs_dropout_prob is set to 0.0 outside
  # of training by the config owner -- not visible here; confirm.
  attention_probs = mtf.dropout(
      attention_probs,
      is_training=(self.config.attention_probs_dropout_prob != 0.0),
      keep_prob=1.0 - self.config.attention_probs_dropout_prob)

  output = mtf.einsum([attention_probs, values],
                      reduced_dims=[self.memory_seq_dim])

  # linear transformation back to shape of query_antecedent
  output = mtf.layers.dense(
      output,
      reduced_dims=[self.num_heads_dim, self.size_per_head_dim],
      new_dims=[self.model_dim],
      kernel_initializer=self.dense_initializer,
      name="output",
      use_bias=self.config.use_bias)
  output = mtf.transpose(output, x.shape)
  return output
def get_activation_fn(params):
    """Return the activation function named by ``params["activation_fn"]``.

    Args:
        params: a mapping; only the "activation_fn" key is read. It defaults
            to "gelu" when absent.

    Returns:
        A callable applied to a single mtf Tensor. The parametric activations
        ("aria", "prelu", "parcsinh", "psoftplus", "proottanh") create new
        scalar variables via ``mtf.get_variable`` every time the returned
        callable is invoked.

    Raises:
        ValueError: if the activation name is not recognised.
    """
    activation_fn = params.get("activation_fn", "gelu")

    def _arcsinh(x):
        # arcsinh(x) = log(x + sqrt(1 + x^2))
        return mtf.log(x + mtf.sqrt(1 + x ** 2))

    def _var(x, init):
        # Scalar variable on x's mesh, initialised to `init`.
        # NOTE(review): the variable name is random per call, so names are not
        # reproducible across graph constructions unless `random` is seeded;
        # confirm this is acceptable for checkpoint restore.
        return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [],
                                initializer=tf.constant_initializer(init), dtype=x.dtype)

    def _pos_var(x, val):
        # Learnable scalar constrained to exceed `val`: softplus(v) + val.
        return mtf.softplus(_var(x, 0)) + val

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "lrelu001":
        return lambda x: mtf.leaky_relu(x, alpha=0.01)
    elif activation_fn == "lrelu020":
        return lambda x: mtf.leaky_relu(x, alpha=0.20)

    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x ** 2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x ** 2)

    elif activation_fn == "tanhshrink":
        # Must be mtf.tanh: a bare `tanh` is not defined in this scope.
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853
        def _rrelu_fn(x):
            # One scale is drawn per call of the returned function.
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1
        def _elish_fn(x):
            # ELiSH: x * sigmoid(x) for x > 0, (exp(x) - 1) * sigmoid(x)
            # otherwise. exp is only needed on the non-positive side, so its
            # argument is zeroed for x > 0; this keeps large positive inputs
            # from overflowing to inf (0 * inf would give NaN).
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp_neg = mtf.exp(x * (1 - cond))
            return mtf.sigmoid(x) * (cond * x + (1 - cond) * (exp_neg - 1))

        return _elish_fn
    elif activation_fn == "silu":  # https://arxiv.org/abs/1710.05941
        return mtf.swish
    elif activation_fn == "arcsinh":
        return _arcsinh

    # parametric
    elif activation_fn == "aria":  # https://arxiv.org/abs/1805.08878
        return lambda x: x * (_var(x, 0) + _var(x, 1) / (
                _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1))))
    elif activation_fn == "prelu":  # https://arxiv.org/abs/1502.01852
        return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2))
    elif activation_fn == "parcsinh":
        return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1))
    elif activation_fn == "psoftplus":
        return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0)
    elif activation_fn == "proottanh":
        return lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x)

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x ** 2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x) ** 2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
def get_activation_fn(params):
    """Return the activation function named by ``params["activation_fn"]``.

    Args:
        params: a mapping; only the "activation_fn" key is read. It defaults
            to "gelu" when absent.

    Returns:
        A callable applied to a single mtf Tensor.

    Raises:
        ValueError: if the activation name is not recognised.
    """
    activation_fn = params.get("activation_fn", "gelu")

    if activation_fn == "gelu":  # https://arxiv.org/abs/1606.08415
        return mtf.gelu
    elif activation_fn == "relu":
        return mtf.relu
    elif activation_fn == "sigmoid":
        return mtf.sigmoid
    elif activation_fn == "tanh":
        return mtf.tanh
    elif activation_fn == "selu":  # https://arxiv.org/abs/1706.02515
        return mtf.selu
    elif activation_fn == "elu":  # https://arxiv.org/abs/1511.07289
        return mtf.elu
    elif activation_fn == "abs":
        return mtf.abs
    elif activation_fn == "id":
        return lambda x: x
    elif activation_fn == "sin":
        return mtf.sin
    elif activation_fn == "cos":
        return mtf.cos
    elif activation_fn == "sign":
        return mtf.sign
    elif activation_fn == "triangle_relax":
        return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(
            5 * x) / 25 - mtf.sin(7 * x) / 49
    elif activation_fn == "square_relax":
        return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(
            5 * x) / 5 - mtf.cos(7 * x) / 7
    elif activation_fn == "spike":
        return lambda x: 1 / (1 + x ** 2)
    elif activation_fn == "spike2":
        return lambda x: mtf.exp(-x ** 2)
    elif activation_fn == "tanhshrink":
        # Must be mtf.tanh: a bare `tanh` is not defined in this scope.
        return lambda x: x - mtf.tanh(x)
    elif activation_fn == "softsign":
        return lambda x: x / (mtf.abs(x) + 1)
    elif activation_fn == "softmax":
        return lambda x: mtf.softmax(x, x.shape[-1])
    elif activation_fn == "logsoftmax":
        return lambda x: mtf.log_softmax(x, x.shape[-1])
    elif activation_fn == "bipolarsigmoid":
        return lambda x: mtf.sigmoid(x) * 2 - 1
    elif activation_fn == "rrelu":  # https://arxiv.org/abs/1505.00853
        def _rrelu_fn(x):
            # One scale is drawn per call of the returned function.
            negative_scale = random.random()
            return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale)

        return _rrelu_fn
    elif activation_fn == "elish":  # https://arxiv.org/abs/1808.00783v1
        def _elish_fn(x):
            # ELiSH: x * sigmoid(x) for x > 0, (exp(x) - 1) * sigmoid(x)
            # otherwise. exp is only needed on the non-positive side, so its
            # argument is zeroed for x > 0; this keeps large positive inputs
            # from overflowing to inf (0 * inf would give NaN).
            cond = mtf.cast(mtf.greater(x, 0), x.dtype)
            exp_neg = mtf.exp(x * (1 - cond))
            return mtf.sigmoid(x) * (cond * x + (1 - cond) * (exp_neg - 1))

        return _elish_fn

    # swish activations
    elif activation_fn == "swish":  # https://arxiv.org/abs/1710.05941
        return mtf.swish

    # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671
    elif activation_fn == "maxsig":
        return lambda x: mtf.maximum(x, mtf.sigmoid(x))
    elif activation_fn == "cosid":
        return lambda x: mtf.cos(x) - x
    elif activation_fn == "minsin":
        return lambda x: mtf.minimum(x, mtf.sin(x))
    elif activation_fn == "maxtanh":
        return lambda x: mtf.maximum(x, mtf.tanh(x))

    elif activation_fn == "softplus":
        return mtf.softplus
    elif activation_fn == "mish":  # https://arxiv.org/abs/1908.08681
        return lambda x: x * mtf.tanh(mtf.softplus(x))
    elif activation_fn == "tanhexp":  # https://arxiv.org/abs/2003.09855
        return lambda x: x * mtf.tanh(mtf.exp(x))
    elif activation_fn == "lisht":  # https://arxiv.org/abs/1901.05894
        return lambda x: x * mtf.tanh(x)
    elif activation_fn == "seagull":  # https://arxiv.org/abs/2011.11713
        return lambda x: mtf.log(1 + x ** 2)
    elif activation_fn == "snake":  # https://arxiv.org/abs/2006.08195
        return lambda x: x + mtf.sin(x) ** 2

    elif activation_fn == "roottanh":  # made up
        return lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x)
    elif activation_fn == "softplusmone":  # made up
        return lambda x: mtf.softplus(x) - 1
    else:
        raise ValueError(
            f'unknown activation function "{activation_fn}" in config')
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Compute a random top-1 gating.

  Each token is routed to a single expert. The expert is the argmax of the
  gate softmax for the "argmax", "input_dropout" and "input_jitter" policies,
  or is sampled from the gate logits for the "sample" policy. Tokens whose
  position within their expert exceeds the expert capacity are dropped.

  Args:
    inputs: a mtf.Tensor whose second-to-last dimension (`inputs.shape[-2]`)
      is the group dimension.
    outer_expert_dims: an optional list of dimensions, forwarded to the
      gating dense layer.
    experts_dim: a Dimension (the number of experts).
    expert_capacity_dim: a Dimension (maximum tokens per expert).
    hparams: model hyperparameters. Reads moe_rand_1_policy_train/eval, and
      moe_rand_1_dropout, moe_rand_1_jitter or moe_rand_1_temperature
      depending on the policy.
    train: a boolean.
    variable_dtype: a mtf.VariableDType.
    importance: an optional Tensor; only positions whose importance equals
      1.0 are routed.
    name: an optional string.
    num_microbatches: optional int; the load-balancing loss is divided by it
      when greater than 1.

  Returns:
    dispatch_tensor: a 0/1 Tensor with the same shape as combine_tensor.
    combine_tensor: gate values, one-hot over experts_dim and
      expert_capacity_dim, cast to inputs.dtype.
    loss: the load-balancing loss, cast to inputs.dtype.

  Raises:
    ValueError: if the selected policy is unknown.
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations
  if train and policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy in ("argmax", "input_dropout", "input_jitter"):
    # mtf.top_1 returns (indices, values); the previous unpacking had them
    # swapped, feeding float gate values into mtf.one_hot as indices.
    expert_index, expert_gate = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  group_size_dim = inputs.shape[-2]
  # density_1 is taken before the importance mask below is applied.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity, ensure we do not exceed it. Construct
  # the batch indices, to each expert, with position_in_expert
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  expert_mask *= mtf.cast(
      mtf.less(position_in_expert, expert_capacity_float),
      dtype=raw_gates.dtype)
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  # Mask out the experts that have overflowed expert capacity. Sparsify the
  # expert_gate.
  expert_gate *= expert_mask_flat

  combine_tensor = (
      expert_gate * expert_mask_flat *
      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
      mtf.one_hot(
          mtf.to_int32(position_in_expert),
          expert_capacity_dim,
          dtype=raw_gates.dtype))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def _top_2_gating(inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
                  hparams, train, variable_dtype, importance=None,
                  name="top_2_gating"):
  """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py. Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The policy is one of "all", "none", "threshold" or "random" and controls
  whether the second-place expert is used for each position.

  The returned dispatch_tensor is used to map (via einsum) from the inputs to
  the expert_inputs. Likewise, the returned combine_tensor is used to map
  (via einsum) from the expert outputs to the outputs. Both tensors are
  mostly zeros.

  The shapes of the tensors are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector. If the importance of an input is 1.0, then we send it to
  up to 2 experts. If 0.0 < importance < 1.0, then we send it to at most
  one expert. If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts. Inputs to the first-choice expert-group get importance
  1.0. Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions. This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean; selects the *_train or *_eval policy and threshold.
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters (an unknown second-place policy).
  """
  group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

  raw_gates = mtf.layers.dense(inputs, experts_dim, use_bias=False,
                               expert_dims=outer_expert_dims,
                               variable_dtype=variable_dtype,
                               name=name)
  raw_gates = mtf.softmax(raw_gates, experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  expert_capacity_f = float(expert_capacity_dim.size)

  # FIND TOP 2 EXPERTS PER POSITION
  # Find the top expert for each position. shape=[batch, group]
  index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
  # [batch, group, experts]
  mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
  density_1_proxy = raw_gates
  if importance is not None:
    # Only importance == 1.0 positions keep a first-place assignment.
    mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
    density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
  # Where mask_1 was zeroed by importance, nothing is removed here, so the
  # "second" choice below is the top expert: those positions are sent to at
  # most one expert.
  gates_without_top_1 = raw_gates * (1.0 - mask_1)
  # [batch, group]
  index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
  # [batch, group, experts]
  mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)

  if importance is not None:
    # importance == 0.0 (e.g. padding) gets no second-place expert either.
    mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

  # Renormalize the two selected gate values so they sum to (about) 1.
  denom = gate_1 + gate_2 + 1e-9
  gate_1 /= denom
  gate_2 /= denom

  # BALANCING LOSSES
  # shape = [batch, experts]
  # We want to equalize the fraction of the batch assigned to each expert
  density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
  # Something continuous that is correlated with what we want to equalize.
  density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                    reduced_dim=group_size_dim)
  loss = (mtf.reduce_mean(density_1_proxy * density_1)
          * float(experts_dim.size * experts_dim.size))

  if hparams.moe_use_second_place_loss:
    # Also add a loss to encourage all experts to be used equally also as the
    # second-place expert.  Experimentally, this seems to be a wash.
    # We want to equalize the fraction of the batch assigned to each expert:
    density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
    # As a proxy for density_2, we renormalize the raw gates after the top one
    # has been removed.
    normalized = gates_without_top_1 / (
        mtf.reduce_sum(gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
    density_2_proxy = mtf.reduce_mean(normalized, reduced_dim=group_size_dim)
    loss_2 = (mtf.reduce_mean(density_2_proxy * density_2)
              * float(experts_dim.size * experts_dim.size))
    loss += loss_2 * 0.5

  # Depending on the policy in the hparams, we may drop out some of the
  # second-place experts.
  if train:
    policy = hparams.moe_second_policy_train
    threshold = hparams.moe_second_threshold_train
  else:
    policy = hparams.moe_second_policy_eval
    threshold = hparams.moe_second_threshold_eval
  if policy == "all":
    # Use second-place experts for all examples.
    pass
  elif policy == "none":
    # Never use second-place experts.
    mask_2 = mtf.zeros_like(mask_2)
  elif policy == "threshold":
    # Use second-place experts if gate_2 > threshold.
    mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
  elif policy == "random":
    # Use second-place experts with probability min(1.0, gate_2 / threshold).
    mask_2 *= mtf.to_float(
        mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                 gate_2 / max(threshold, 1e-9)))
  else:
    raise ValueError("Unknown policy %s" % policy)

  # COMPUTE ASSIGNMENT TO EXPERTS
  # [batch, group, experts]
  # This is the position within the expert's mini-batch for this sequence
  position_in_expert_1 = mtf.cumsum(
      mask_1, group_size_dim, exclusive=True) * mask_1
  # Remove the elements that don't fit. [batch, group, experts]
  mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
  # [batch, experts]
  # How many examples in this sequence go to this expert (after dropping
  # overflow).
  mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
  # [batch, group]
  position_in_expert_1 = mtf.reduce_sum(
      position_in_expert_1, reduced_dim=experts_dim)
  # Weight assigned to first expert.  [batch, group]
  gate_1 *= mask_1_flat

  # [batch, group, experts]
  # Second-place tokens are placed after all first-place tokens that were
  # kept for the same expert, hence the mask_1_count offset.
  position_in_expert_2 = (
      mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
  position_in_expert_2 *= mask_2
  mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
  mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
  gate_2 *= mask_2_flat
  position_in_expert_2 = mtf.reduce_sum(
      position_in_expert_2, reduced_dim=experts_dim)

  # [batch, group, experts, expert_capacity]
  # gate_1 / gate_2 were already multiplied by their 0/1 flat masks above;
  # multiplying again here does not change the values.
  combine_tensor = (
      gate_1 * mask_1_flat
      * mtf.one_hot(index_1, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
      gate_2 * mask_2_flat
      * mtf.one_hot(index_2, experts_dim)
      * mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

  # Match the inputs dtype; dispatch is the 0/1 support of combine.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)

  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None,
              float32_logits=True,
              z_loss_coeff=None):
  """Dot-product attention - doesn't use positional dimensions.

  Queries are compared to keys over `key_dim`; the softmax over
  `memory_length_dim` then weights the values.

    q: other_query_dims + {key_dim}
    k: other_memory_dims + {memory_length_dim, key_dim}
    v: other_memory_dims + {memory_length_dim, value_dim}

  other_memory_dims must be a subset of other_query_dims; typically
  other_query_dims = {batch, heads, length} and
  other_memory_dims = {batch, heads}.

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension indexing the key/value pairs
    key_dim: a Dimension, channels shared by queries and keys
    value_dim: a Dimension, channels of the values
    bias: optional Tensor added to the attention logits.
    dropout_rate: a float, dropout applied to the attention weights.
    dropout_broadcast_dims: optional list of mtf.Dimension along which the
      dropout mask is shared.
    extra_logit: optional scalar or tensor forwarded to mtf.softmax.
    context: an optional Transformer.Context (its `train` flag is read).
    float32_logits: if True, q and k are cast to float32 before the logits
      are computed, to avoid bfloat16 numerical issues.
    z_loss_coeff: optional float. When set and training, an auxiliary loss
      (coefficient * mean squared log-partition) is appended to
      context.losses to keep the logits near zero.

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  original_query_shape = q.shape
  q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
      context, q, k, v, bias, [key_dim, value_dim])

  if float32_logits:
    k = mtf.cast(k, tf.float32)
    q = mtf.cast(q, tf.float32)
  logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += mtf.cast(bias, logits.dtype)

  # Auxiliary z-loss pushing log(sum(exp(logits))) toward zero; this branch
  # only runs while training.
  if z_loss_coeff is not None and context.train:
    tf.logging.info("attention z_loss being added: {}".format(
        tf.get_variable_scope().name))
    log_z = mtf.reduce_logsumexp(logits, memory_length_dim)
    z_loss = mtf.reduce_mean(
        mtf.square(log_z) * mtf.cast(context.nonpadding, log_z.dtype))
    microbatches = context.num_microbatches
    if microbatches and microbatches > 1:
      tf.logging.info(
          "Dividing attention z-loss loss by num_microbatches={}".format(
              microbatches))
      z_loss /= microbatches
    # context.train is already known to be truthy inside this branch.
    mtf.scalar_summary("attention_z_loss", z_loss)
    context.losses.append(mtf.cast(z_loss * z_loss_coeff, v.dtype))

  weights = mtf.cast(
      mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit),
      v.dtype)
  weights = mtf.dropout(
      weights, context.train, 1.0 - dropout_rate,
      noise_shape=weights.shape - dropout_broadcast_dims)
  attended = mtf.einsum([weights, v], q.shape - key_dim + value_dim)
  return mtf.reshape(attended, original_query_shape - key_dim + value_dim)