def call(self, context, x, losses=None):
  """Call the layer."""
  if context.model.ensemble_dim:
    raise NotImplementedError("MoE not yet implemented with ensembles")

  has_length_dim = context.length_dim in x.shape.dims
  if not has_length_dim:
    x_shape = x.shape
    shape_with_length = mtf.Shape(
        x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:])
    x = mtf.reshape(x, shape_with_length)
  y, loss = transformer_moe_layer_v1(
      x,
      context.model.model_dim,
      self._hparams,
      context.train,
      context.variable_dtype,
      layout=context.model.layout,
      mesh_shape=context.model.mesh_shape,
      nonpadding=context.nonpadding)
  if context.losses is not None:
    context.losses.append(loss)
  if not has_length_dim:
    y = mtf.reshape(y, x_shape)
  return y
def widedeep(id_hldr, wt_hldr, vocab_dim, embed_dim, outdim, float16=None):
  logger.debug("[input tensor] (name,shape):({},{})".format(
      id_hldr.name, id_hldr.shape))
  logger.debug("[input tensor] (name,shape):({},{})".format(
      wt_hldr.name, wt_hldr.shape))
  if float16:
    deep_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim,
        variable_dtype=float16,
        name="deep_embedding")
  else:
    fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
    deep_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim,
        variable_dtype=fp32,
        name="deep_embedding")
  logger.debug("[output tensor] (name,shape):({},{})".format(
      deep_output.name, deep_output.shape))

  expend_dim = mtf.Dimension('expend', size=1)
  embed_dim_one = mtf.Dimension('embed_dim_one', size=1)
  mask = mtf.reshape(
      wt_hldr,
      new_shape=[wt_hldr.shape.dims[0], wt_hldr.shape.dims[1], expend_dim],
      name='mask_reshape')
  logger.debug("[output tensor] (name,shape):({},{})".format(
      mask.name, mask.shape))
  if float16:
    wide_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim_one,
        variable_dtype=float16,
        name="wide_embedding")
  else:
    fp32 = mtf.VariableDType(tf.float32, tf.float32, tf.float32)
    wide_output = mtf.layers.embedding(
        id_hldr,
        vocab_dim=vocab_dim,
        output_dim=embed_dim_one,
        variable_dtype=fp32,
        name="wide_embedding")
  logger.debug("[output tensor] (name,shape):({},{})".format(
      wide_output.name, wide_output.shape))

  wide_output = wide(wide_output, mask=mask, float16=float16)
  deep_output = deep(deep_output, mask=mask, float16=float16)

  result = mtf.add(wide_output, deep_output)
  result = mtf.reshape(
      result,
      new_shape=[wide_output.shape.dims[0], outdim],
      name='result_reshape')
  logger.debug("[output tensor] (name,shape):({},{})".format(
      result.name, result.shape))
  return result
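# Usage sketch for widedeep (hedged: dimension names and sizes are
# illustrative assumptions, and wide(), deep(), and logger must come from
# this module). Each example carries 26 categorical ids plus per-feature
# weights, as in a Criteo-style wide & deep setup.
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def _widedeep_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "widedeep_mesh")
  batch_dim = mtf.Dimension("batch", 8)
  field_dim = mtf.Dimension("field", 26)
  id_hldr = mtf.import_tf_tensor(
      mesh, tf.zeros([8, 26], dtype=tf.int32),
      mtf.Shape([batch_dim, field_dim]))
  wt_hldr = mtf.import_tf_tensor(
      mesh, tf.ones([8, 26], dtype=tf.float32),
      mtf.Shape([batch_dim, field_dim]))
  # float16=None selects the fp32 embedding path above.
  return widedeep(id_hldr, wt_hldr,
                  vocab_dim=mtf.Dimension("vocab", 1000),
                  embed_dim=mtf.Dimension("embed", 16),
                  outdim=mtf.Dimension("out", 1))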
def call(self, context, x, losses=None):
  """Call the layer."""
  # Dim cheat sheet:
  # <B>: batch dims, e.g.
  #      [outer_batch_size, batch_size] or
  #      [beam_size, batch_size]
  # L: original length
  # M: model dim
  #
  # x: <B>LM Tensor
  has_length_dim = context.length_dim in x.shape.dims
  if not has_length_dim:
    x_shape = x.shape
    shape_with_length = mtf.Shape(
        x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:])
    x = mtf.reshape(x, shape_with_length)
  y, loss = transformer_moe_layer_v1(
      x,
      context.model.model_dim,
      self._hparams,
      context.train,
      context.variable_dtype,
      layout=context.model.layout,
      mesh_shape=context.model.mesh_shape,
      nonpadding=context.nonpadding)
  if context.losses is not None:
    context.losses.append(loss)
  if not has_length_dim:
    y = mtf.reshape(y, x_shape)
  return y
def local_attention1d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
  """Image Transformer decoder with local1D spatial layers."""
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_length)
  num_w_blocks_dim = mtf.Dimension("num_wblocks",
                                   length_dim.size // blocks_w_dim.size)
  x = mtf.reshape(
      x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              memory_w_dim=blocks_w_dim,
              mask_right=True,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)
  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
def call(self, context, x, losses=None):
  """Call the layer."""
  io_channels = x.shape.dims[-1]
  hidden_channels = mtf.Dimension("d_ff", self.hidden_size)

  h = dense_product_fixup(
      x,
      reduced_dims=x.shape.dims[-1:],
      new_dims=hidden_channels,
      activation_functions=self.activation,
      use_bias=self.use_bias,
      variable_dtype=context.variable_dtype,
      name="wi",
      kernel_initializer=self.upproject_initializer,
      expert_dims=context.model.ensemble_dims)
  if context.train and self.dropout_rate != 0.0:
    h = mtf.dropout(
        h, 1.0 - self.dropout_rate,
        noise_shape=h.shape - context.length_dim)
  shift = get_single_scalar_bias(x, "shift")
  h_res = mtf.add(h, shift)
  h = mtf.reshape(h_res, h.shape)
  return mtf.layers.dense(
      h,
      io_channels,
      use_bias=self.use_bias,
      activation=None,
      variable_dtype=context.variable_dtype,
      reduced_dims=h.shape.dims[-1:],
      name="wo",
      expert_dims=context.model.ensemble_dims,
      kernel_initializer=self.downproject_initializer)
def _repeat(x, n, repeat_dim):
  # repeat function in MTF
  tmp_dim = mtf.Dimension("tmp", 1)
  expand_shape = mtf.Shape(x.shape.dims + [tmp_dim])
  x = mtf.reshape(x, expand_shape)
  x = _tile(x, n, tmp_dim)
  output_shape = []
  for dim in x.shape.dims:
    if dim.name == "tmp":
      continue
    if dim.name == repeat_dim.name:
      dim = mtf.Dimension(dim.name, dim.size * n)
    output_shape.append(dim)
  output_shape = mtf.Shape(output_shape)
  x = mtf.reshape(x, output_shape)
  return x
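# A minimal usage sketch for _repeat (hedged: relies on the module-level
# _tile used above, and the dimension name/size are illustrative).
# Repeating a size-4 dimension 3 times yields a size-12 dimension.
def _repeat_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "repeat_mesh")
  rows = mtf.Dimension("rows", 4)
  x = mtf.ones(mesh, mtf.Shape([rows]))
  y = _repeat(x, 3, rows)
  # y.shape == Shape[Dimension(name='rows', size=12)]
  return y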
def axial_positional_emb(embd_dim, mesh, params, variable_dtype):
  # Use axial position encoding
  axial_dim_1, axial_dim_2 = params["axial_pos_emb"]

  axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2)
  dim_axials = [
      mtf.Dimension(f"axial_dim_{i}", t)
      for i, t in enumerate((axial_dim_1, axial_dim_2))
  ]

  axial_wpe_1 = mtf.get_variable(
      mesh,
      "axial_wpe_1",
      mtf.Shape([dim_axials[0], embd_dim]),
      initializer=tf.random_normal_initializer(stddev=0.01),
      master_dtype=variable_dtype.master_dtype,
      slice_dtype=variable_dtype.slice_dtype,
      activation_dtype=variable_dtype.activation_dtype)
  axial_wpe_2 = mtf.get_variable(
      mesh,
      "axial_wpe_2",
      mtf.Shape([dim_axials[1], embd_dim]),
      initializer=tf.random_normal_initializer(stddev=0.01),
      master_dtype=variable_dtype.master_dtype,
      slice_dtype=variable_dtype.slice_dtype,
      activation_dtype=variable_dtype.activation_dtype)

  axial_wpe_1, axial_wpe_2 = map(
      lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]),
      (axial_wpe_1, axial_wpe_2))
  wpe = (axial_wpe_1 + axial_wpe_2) / 2
  wpe = mtf.reshape(wpe, [axial_dim, embd_dim])
  return wpe
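# Worked example of the parameter saving above (illustrative numbers, not
# from the source): with params["axial_pos_emb"] = (32, 32) and an embedding
# dim of 512, the two tables hold (32 + 32) * 512 weights, yet after the
# broadcast-and-reshape they cover 32 * 32 = 1024 positions; a full
# positional table would need 1024 * 512 weights.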
def downsample_hr_to_lr(field, lr_shape, hr_shape, downsampling_factor,
                        halo_size, splittables, mesh):
  # Reshaping array into high resolution mesh
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
  low = mtf.reshape(low, low.shape[:-1])

  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)
  # Hack using a custom reshape because mesh is pretty dumb
  low = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=field.dtype,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=splittables)
  return low
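# Worked example of the halo arithmetic above (illustrative numbers): with
# halo_size = 16, downsampling_factor = 2 and a high-res block of size 128,
# the loop keeps block_size_dim.size // 2**2 = 32 low-res cells per block,
# starting at offset 16 // 2**2 = 4, i.e. it trims the downsampled halo
# region from each block before the final reshape.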
def reshape(x, new_shape):
  old_shape = x.shape
  assert len(old_shape) == len(new_shape)
  for o, n in zip(old_shape.dims, new_shape.dims):
    if (o.name != n.name) and (o.name.startswith('axis') and
                               n.name.startswith('axis')):
      x = mtf.rename_dimension(x, o.name, utils.RandName())
  return mtf.reshape(x, new_shape)
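# Usage sketch (hedged: dimension names/sizes are illustrative, and
# utils.RandName() is this module's random-name helper). Renaming the
# "axis*" dims to throwaway names first keeps mtf.reshape from treating the
# differently named split axes as the same mesh-split dimension:
def _reshape_usage_sketch(x):
  # x is assumed to have shape [axis0=4, dim=8]; we want [axis1=4, dim=8].
  new_shape = mtf.Shape(
      [mtf.Dimension('axis1', 4), mtf.Dimension('dim', 8)])
  return reshape(x, new_shape)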
def attention(q,
              k,
              v,
              memory_length_dim,
              key_dim,
              value_dim,
              bias=None,
              dropout_rate=0.0,
              dropout_broadcast_dims=None,
              extra_logit=None,
              context=None):
  """Dot-product attention - doesn't use positional dimensions.

  key_dim is a Dimension representing the channels in the queries and keys
  value_dim is a Dimension representing the channels in values
  memory_length_dim is a Dimension representing the different key/value pairs.

  Dimensions of q: other_query_dims + {key_dim}
  Dimensions of k: other_memory_dims + {memory_length_dim, key_dim}
  Dimensions of v: other_memory_dims + {memory_length_dim, value_dim}
  other_memory_dims is a subset of other_query_dims

  Typically, other_query_dims={batch, heads, length}
  Typically, other_memory_dims={batch, heads}

  Args:
    q: a Tensor
    k: a Tensor
    v: a Tensor
    memory_length_dim: a Dimension
    key_dim: a Dimension
    value_dim: a Dimension
    bias: a Tensor to be added into the attention logits.
    dropout_rate: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension
    extra_logit: an optional scalar or tensor
    context: an optional Transformer.Context

  Returns:
    Tensor with shape q.shape - key_dim + value_dim
  """
  orig_q_shape = q.shape
  q, k, v, bias = maybe_reshape_attention_input_for_2d_sharding(
      context, q, k, v, bias, [key_dim, value_dim])
  logits = mtf.layers.us_einsum([q, k], reduced_dims=[key_dim])
  if bias is not None:
    logits += bias
  weights = mtf.softmax(logits, memory_length_dim, extra_logit=extra_logit)
  if dropout_rate != 0.0:
    weights = mtf.dropout(
        weights, 1.0 - dropout_rate,
        noise_shape=weights.shape - dropout_broadcast_dims)
  outputs_shape = q.shape - key_dim + value_dim
  outputs = mtf.einsum([weights, v], outputs_shape)
  outputs = mtf.reshape(outputs, orig_q_shape - key_dim + value_dim)
  return outputs
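# Shape-contract sketch for attention() above (hedged: dimension names and
# sizes are illustrative assumptions, and this assumes
# maybe_reshape_attention_input_for_2d_sharding passes tensors through
# unchanged when context is None):
def _dot_product_attention_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "attention_mesh")
  batch = mtf.Dimension("batch", 2)
  heads = mtf.Dimension("heads", 4)
  length = mtf.Dimension("length", 8)
  memory_length = mtf.Dimension("memory_length", 8)
  key_dim = mtf.Dimension("d_k", 16)
  value_dim = mtf.Dimension("d_v", 16)
  q = mtf.zeros(mesh, mtf.Shape([batch, heads, length, key_dim]))
  k = mtf.zeros(mesh, mtf.Shape([batch, heads, memory_length, key_dim]))
  v = mtf.zeros(mesh, mtf.Shape([batch, heads, memory_length, value_dim]))
  out = attention(q, k, v, memory_length, key_dim, value_dim)
  # out.shape == [batch, heads, length, value_dim], i.e.
  # q.shape - key_dim + value_dim.
  return out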
def sublayer_fixup_scale(x, layer_stack, context):
  """Multiply by single one-initialized scalar."""
  del layer_stack
  dim = mtf.Dimension("single_scale", 1)
  fixup_weight = mtf.get_variable(
      x.mesh,
      "fixup_scale_weight",
      shape=mtf.Shape([dim]),
      dtype=context.variable_dtype,
      initializer=tf.constant_initializer(1.))
  return mtf.reshape(x * fixup_weight, x.shape)
def mtf_model_fn(self, features, mesh):
  logits, loss = self._mtf_model_fn(features, mesh)
  # combine batch dims
  if len(self.batch_dims) > 1:
    combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                       mtf.Shape(self.batch_dims).size)
    logits = mtf.reshape(logits,
                         [combined_batch_dim] + logits.shape.dims[-2:])
  return logits, loss
def mtf_model_fn(self, features, mesh):
  with tf.variable_scope("transformer"):
    logits, loss = self._mtf_model_fn(features, mesh)
    # combine batch dims
    if len(self.batch_dims) > 1:
      combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                         mtf.Shape(self.batch_dims).size)
      logits = mtf.reshape(logits,
                           [combined_batch_dim] + logits.shape.dims[-2:])
    return logits, loss
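# Worked sketch of the batch-dim combination above (illustrative sizes):
# two batch dims [outer_batch=2, batch=4] collapse into a single dim that
# keeps the first dim's name and takes the product size.
#
#   batch_dims = [mtf.Dimension("outer_batch", 2), mtf.Dimension("batch", 4)]
#   combined = mtf.Dimension(batch_dims[0].name, mtf.Shape(batch_dims).size)
#   # combined == Dimension(name='outer_batch', size=8); logits is then
#   # reshaped to [combined] + logits.shape.dims[-2:].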
def sublayer_fixup_shift(x, layer_stack, context):
  """Shift by single zero-initialized scalar."""
  del layer_stack
  dim = mtf.Dimension("single_bias", 1)
  fixup_bias = mtf.get_variable(
      x.mesh,
      "fixup_bias",
      shape=mtf.Shape([dim]),
      dtype=context.variable_dtype,
      initializer=tf.zeros_initializer())
  res = mtf.add(x, fixup_bias)
  res = mtf.reshape(res, x.shape)
  return res
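# Note on the two fixup sublayers above: multiplying or adding a tensor by
# the [1]-sized "single_scale"/"single_bias" variable broadcasts, which
# appends that dim to the result shape; the trailing mtf.reshape(..., x.shape)
# collapses it again. A hedged sketch of how they might wrap a residual
# branch f (the composition order here is an assumption, not from the
# source):
#
#   y = x + sublayer_fixup_scale(
#       f(sublayer_fixup_shift(x, None, context)), None, context)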
def split_scales(field, downsampling_factor=2., antialias=True):
  """Performs a multiresolution decomposition of the input field.

  The input field will be decomposed into a low resolution approximation,
  and a details component.
  """
  low = downsample(field, downsampling_factor, antialias)
  high = upsample(low, downsampling_factor)
  high = field - mtf.reshape(high, field.shape)
  return low, high
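# Reconstruction sketch: the decomposition above is invertible by
# construction, since high = field - upsample(downsample(field)) (hedged:
# downsample and upsample are this module's helpers):
#
#   low, high = split_scales(field)
#   recon = mtf.reshape(upsample(low, 2.), field.shape) + high
#   # recon == field (up to floating point error)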
def call(self, context, x, losses=None):
  """Call the layer."""
  has_length_dim = context.length_dim in x.shape.dims
  if not has_length_dim:
    x_shape = x.shape
    shape_with_length = mtf.Shape(
        x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:])
    x = mtf.reshape(x, shape_with_length)
  y, loss = transformer_moe_layer_v1(
      x, context.model_dim, self._hparams, context.train,
      context.variable_dtype)
  if context.losses is not None:
    context.losses.append(loss)
  if not has_length_dim:
    y = mtf.reshape(y, x_shape)
  return y
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
  with tf.variable_scope(scope):
    mesh, batch, seq, dim = x.mesh, *x.shape

    dim_heads = mtf.Dimension('dim_heads',
                              dim_head.size * dim_features_head.size)
    dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
    qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

    q, k, v = mtf.split(qkv, dim_intermediate, 3)
    q, k, v = map(
        lambda t: mtf.reshape(t, [batch, seq, dim_head, dim_features_head]),
        (q, k, v))
    q, k, v = map(
        lambda t: mtf.transpose(t, [batch, dim_head, seq, dim_features_head]),
        (q, k, v))

    k, v = map(
        lambda t: mtf.rename_dimension(t, seq.name, 'memory_length'), (k, v))
    mem_len_dim = v.shape[-2]

    dots = mtf.layers.us_einsum([q, k], [batch, dim_head, seq, mem_len_dim])

    if causal:
      i = mtf.range(mesh, seq, tf.int32)
      j = mtf.range(mesh, mem_len_dim, tf.int32)
      i, j = map(lambda t: mtf.broadcast(t, [seq, mem_len_dim]), (i, j))
      mask = mtf.less(i + mem_len_dim.size - seq.size, j)
      mask = mtf.cast(mask, tf.float32) * -1e10
      dots += mask

    attn = mtf.softmax(dots, mem_len_dim)
    out = mtf.einsum([attn, v], [batch, dim_head, seq, dim_features_head])

    out = mtf.transpose(out, [batch, seq, dim_head, dim_features_head])
    out = mtf.reshape(out, [batch, seq, dim_heads])

    combined_out = linear(out, dim, scope='combine_output')
    return combined_out
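# Usage sketch for attention() above (hedged: dimension names/sizes are
# illustrative, and linear() is the module-level projection used above).
def _mtf_attention_usage_sketch():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "attn_mesh")
  batch = mtf.Dimension("batch", 2)
  seq = mtf.Dimension("sequence", 16)
  dim = mtf.Dimension("dim", 64)
  x = mtf.zeros(mesh, mtf.Shape([batch, seq, dim]))
  # 8 heads of 8 features each; the output keeps the [batch, seq, dim] shape.
  return attention(x,
                   dim_head=mtf.Dimension("dim_head", 8),
                   dim_features_head=mtf.Dimension("dim_features_head", 8),
                   causal=True)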
def _compute_output(hidden, layer_name):
  """Compute the output of the attention layer from the hidden vector."""
  expert_output = mtf.layers.dense(
      hidden,
      output_dim,
      expert_dims=[experts_dim],
      use_bias=False,
      reduced_dims=hidden.shape.dims[-1:],
      variable_dtype=variable_dtype,
      name=layer_name)

  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim,
          experts_dim_unsplit,
          num_groups_dim,
          expert_capacity_dim,
          output_dim,
      ]))
  moe_output_dims = moe_input_dims[:-1] + [output_dim]
  output = mtf.einsum([expert_output, combine_tensor],
                      mtf.Shape(moe_output_dims))
  output = mtf.reshape(output, batch_and_length_dims + [output_dim])
  return output
def BasicBlock(x, order, out_channels, strides):
  name = "BasicBlock"
  expansion = 1
  out_chls = out_channels // expansion
  identity = x

  x = mtf.layers.conv2d(
      x,
      output_dim=mtf.Dimension(
          name=name + '-' + str(order) + '-' + 'filters1', size=out_chls),
      filter_size=(3, 3),
      strides=strides,
      name="conv3x3_BB_1" + '-' + str(order),
      variable_dtype=float16)
  print(x.name)
  print(x.dtype)
  x, _ = mtf.layers.batch_norm(
      x,
      is_training=True,
      momentum=0.99,
      epsilon=1e-5,
      name="batch_norm_BB_1" + '-' + str(order))
  x = mtf.relu(x, name="relu_BB_1" + '-' + str(order))

  x = mtf.layers.conv2d(
      x,
      output_dim=mtf.Dimension(
          name=name + '-' + str(order) + '-' + 'filters2', size=out_channels),
      filter_size=(3, 3),
      strides=(1, 1),
      name="conv3x3_BB_2" + '-' + str(order),
      variable_dtype=float16)
  print(x.name)
  print(x.dtype)
  x, _ = mtf.layers.batch_norm(
      x,
      is_training=True,
      momentum=0.99,
      epsilon=1e-5,
      name="batch_norm_BB_2" + '-' + str(order))

  identity = mtf.reshape(
      identity,
      new_shape=[
          identity.shape.dims[0], identity.shape.dims[1],
          identity.shape.dims[2], x.shape.dims[3]
      ],
      name="reshape_BB" + str(order))
  x = mtf.add(x, identity, output_shape=x.shape,
              name="add_BB_1" + '-' + str(order))
  x = mtf.relu(x, name="relu_BB_2" + '-' + str(order))
  print(x.name)
  print(x.dtype)
  return x
def _get_decoder_inputs(self, context):
  """Computes the inputs to the decoder when using transparent attention.

  We must cache on the context in order to ensure that we are not replicating
  variables when the layer's call function is called in different tf variable
  scopes.

  Args:
    context: a Context

  Returns:
    a list containing `self.num_decoder_modules` of tensors with shape
    [<batch_dims>, length_dim, output_vocab_dim]
  """
  if hasattr(context, "decoder_layers_per_module"):
    return context.decoder_layers_per_module

  encoder_layer_outputs = [
      mtf.layers.rename_length_to_memory_length(output)
      for output in context.encoder_layer_outputs
  ]

  layers_per_module = self.layers_per_encoder_module
  encoder_module_outputs_dim = mtf.Dimension(
      "encoder_module_outputs", size=self.encoder_num_modules + 1)
  decoder_module_inputs_dim = mtf.Dimension(
      "decoder_module_inputs", size=self.decoder_num_modules)
  encoder_module_outputs = mtf.stack(
      [encoder_layer_outputs[0]] +
      encoder_layer_outputs[layers_per_module::layers_per_module],
      dim_name="encoder_module_outputs")
  w = mtf.get_variable(
      context.mesh,
      "w",
      mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
      initializer=tf.random_normal_initializer(
          stddev=(encoder_module_outputs_dim.size *
                  decoder_module_inputs_dim.size)**-0.5),
      dtype=context.variable_dtype)
  if context.train and self.dropout_rate != 0.0:
    w = mtf.dropout(w, 1.0 - self.dropout_rate)
  s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
  z = mtf.einsum([s, encoder_module_outputs],
                 reduced_dims=[encoder_module_outputs_dim])
  input_per_decoder = mtf.split(
      z,
      split_dim=decoder_module_inputs_dim,
      num_or_size_splits=decoder_module_inputs_dim.size)
  context.decoder_layers_per_module = [
      mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
  ]
  return context.decoder_layers_per_module
def call(self, context, x, losses=None):
  """Call the layer."""
  if context.model.ensemble_dim:
    raise NotImplementedError("MoE not yet implemented with ensembles")

  has_length_dim = context.length_dim in x.shape.dims
  if not has_length_dim:
    x_shape = x.shape
    shape_with_length = mtf.Shape(
        x_shape.dims[:-1] + [mtf.Dimension("length", 1)] + x_shape.dims[-1:])
    x = mtf.reshape(x, shape_with_length)

  # Extract the MoE output dimension
  if self._hparams.moe_output_dim is not None:
    output_dim = self._hparams.moe_output_dim
  else:
    output_dim = context.model.model_dim
  y, loss = transformer_moe_layer_v1(
      x,
      output_dim,
      self._hparams,
      context.train,
      context.variable_dtype,
      layout=context.model.layout,
      mesh_shape=context.model.mesh_shape,
      nonpadding=context.nonpadding,
      activation=self._activation,
      num_microbatches=context.num_microbatches)
  if context.losses is not None:
    context.losses.append(loss)
  if not has_length_dim:
    if self._hparams.moe_use_experts_attention:
      y_reshape = [mtf.reshape(y_out, x_shape) for y_out in y]
      y = y_reshape
    else:
      y = mtf.reshape(y, x_shape)
  return y
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
  """Image Transformer decoder with local2D spatial layers."""
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
  num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                   hparams.img_len // hparams.block_height)
  num_w_blocks_dim = mtf.Dimension(
      "num_w_blocks",
      hparams.img_len * hparams.num_channels // hparams.block_width)
  x = mtf.transpose(
      mtf.reshape(
          x,
          mtf.Shape([
              batch_dim, num_h_blocks_dim, blocks_h_dim, num_w_blocks_dim,
              blocks_w_dim, model_dim
          ])),
      mtf.Shape([
          batch_dim, num_h_blocks_dim, num_w_blocks_dim, blocks_h_dim,
          blocks_w_dim, model_dim
      ]))
  mode = getattr(hparams, "mode", tf_estimator.ModeKeys.TRAIN)
  is_training = mode == tf_estimator.ModeKeys.TRAIN
  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_2d_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              is_training,
              memory_h_dim=num_h_blocks_dim,
              memory_w_dim=num_w_blocks_dim,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)
  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
def local_attention2d_spatial_decoder(x, kv_dim, heads_dim, feedforward_dim,
                                      hparams):
  """Image Transformer decoder with local2D spatial layers."""
  batch_dim, length_dim, model_dim = x.shape.dims
  blocks_h_dim = mtf.Dimension("blocksh", hparams.block_height)
  blocks_w_dim = mtf.Dimension("blocksw", hparams.block_width)
  num_h_blocks_dim = mtf.Dimension("num_h_blocks",
                                   hparams.img_len // hparams.block_height)
  num_w_blocks_dim = mtf.Dimension(
      "num_w_blocks",
      hparams.img_len * hparams.num_channels // hparams.block_width)
  x = mtf.transpose(
      mtf.reshape(
          x,
          mtf.Shape([
              batch_dim, num_h_blocks_dim, blocks_h_dim, num_w_blocks_dim,
              blocks_w_dim, model_dim
          ])),
      mtf.Shape([
          batch_dim, num_h_blocks_dim, num_w_blocks_dim, blocks_h_dim,
          blocks_w_dim, model_dim
      ]))
  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  for layer in range(hparams.num_decoder_layers):
    layer_name = "decoder_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Self attention layer
      x += layer_prepostprocess_dropout(
          mtf.layers.local_2d_self_attention_spatial_blocks(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
              kv_dim,
              heads_dim,
              memory_h_dim=num_h_blocks_dim,
              memory_w_dim=num_w_blocks_dim,
              name="self_att"), hparams)
      # ffn layer
      x += layer_prepostprocess_dropout(
          mtf.layers.dense_relu_dense(
              mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
              feedforward_dim,
              hparams.dropout,
              dropout_broadcast_dims=[length_dim]), hparams)
  output = mtf.layers.layer_norm(x, model_dim, name="final_layer_norm")
  return output
def VGG(x, classes_dim, depth, batch_norm=True):
  if depth not in vgg_dict:
    raise ValueError("VGG-{} is not supported!".format(depth))
  x = make_conv_layers(x, mode=vgg_dict[depth], batch_norm=batch_norm)
  x = mtf.reshape(
      x,
      new_shape=[
          x.shape.dims[0],
          mtf.Dimension(
              name="flatten",
              size=x.shape.dims[1].size * x.shape.dims[2].size *
              x.shape.dims[3].size)
      ],
      name="flatten")
  x = make_dense_layers(x, classes_dim=classes_dim)
  print(x.name)
  print(x.shape)
  return x
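# Worked sketch of the flatten above (illustrative sizes, not from the
# source): a final conv output of shape [batch, rows=7, cols=7, filters=512]
# is reshaped to [batch, flatten=7*7*512=25088] before the dense head, since
# the "flatten" dim takes the product of the three trailing dim sizes.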
def model_fn(features, labels, mode, params):
  """A model is called by TpuEstimator."""
  del labels
  del features

  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

  ctx = params['context']
  num_hosts = ctx.num_hosts
  host_placement_fn = ctx.tpu_host_placement_function
  device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
  tf.logging.info('device_list = %s' % device_list)

  mesh_devices = [''] * mesh_shape.size
  mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                              mesh_devices,
                                              ctx.device_assignment)

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "fft_mesh")

  with mtf.utils.outside_all_rewrites():
    field = nbody_model(mesh)
    batch_dim, x_dim, y_dim, z_dim = field.shape
    x_dim_nosplit = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
    y_dim_nosplit = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

    # Until we implement distributed outputs, we only return one example
    field_slice, _ = mtf.split(field, batch_dim, [1, FLAGS.batch_size - 1])
    field_slice = mtf.reshape(
        field_slice,
        [mtf.Dimension("bs", 1), x_dim_nosplit, y_dim_nosplit, z_dim])
    # field_slice = field

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

  with mtf.utils.outside_all_rewrites():
    return tpu_estimator.TPUEstimatorSpec(
        mode, predictions={'field': tf_field})
def transformer_moe_layer_v1(inputs,
                             output_dim,
                             hparams,
                             train,
                             variable_dtype,
                             layout=None,
                             mesh_shape=None,
                             nonpadding=None):
  """Local mixture of experts that works well on TPU.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py .  Once this code moves out of "research", we should
  pass the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  The number of parameters in the gating network is:
    (input_dim.size * hparams.num_experts) +

  The number of parameters in the experts themselves is:
    (hparams.num_experts * (input_dim.size + output_dim.size) *
     hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim],
  consisting of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert choices
  and the combination weights are determined by a learned gating function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding) that each
    sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups".  For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
    <B>: batch dims
    L: original sequence length
    M: input depth
    N: output depth
    G: number of groups
    S: group size
    E: number of experts
    C: expert capacity
    (u for unsplit dims)

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims...>, length_dim, input_dim]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional Tensor with shape [<batch_dims>, length_dim] and
      the same dtype as inputs, consisting of ones(nonpadding) and
      zeros(padding).

  Returns:
    outputs: a Tensor with shape [<batch_dims...>, length_dim, output_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  # See "Dimensions cheat sheet" above.
  # inputs: <B>LM Tensor
  orig_inputs = inputs
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  experts_dim = mtf.Dimension("experts", hparams.moe_num_experts)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups is a multiple of the mesh dimension
  # over which those groups are split.
  batch_and_length_dims, input_dim = (orig_inputs.shape.dims[:-1],
                                      orig_inputs.shape.dims[-1])
  # Hack: we assume that
  #   "outer_batch" == replication of experts
  #   mesh_dim_size can be derived from mesh_shape and orig_batch_dim
  #
  # We then require num_groups to be a multiple of mesh_dim_size.
  if orig_inputs.shape.dims[0].name == "outer_batch":
    outer_batch_dim, orig_batch_dim = orig_inputs.shape.dims[:2]
  else:
    outer_batch_dim, orig_batch_dim = (mtf.Dimension("outer_batch", 1),
                                       orig_inputs.shape.dims[0])

  # Number of MoE inputs (total number of positions across
  # batch_and_length_dims per replica).
  n = 1
  for d in batch_and_length_dims:
    n *= d.size
  n = n // outer_batch_dim.size

  mesh_dim_size = mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape,
                                                  orig_batch_dim)
  num_groups, group_size = _split_into_groups(n, hparams.moe_group_size,
                                              mesh_dim_size)
  group_size_dim = mtf.Dimension("group", group_size)
  num_groups_dim = mtf.Dimension(orig_batch_dim.name, num_groups)

  moe_input_dims = [
      outer_batch_dim, num_groups_dim, group_size_dim, input_dim
  ]
  # OGSM Tensor
  inputs = mtf.reshape(inputs, moe_input_dims)

  # Each sequence sends expert_capacity positions to each expert.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(
      group_size_dim.size,
      int((group_size_dim.size * capacity_factor) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)
  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", num_groups_dim.size)
  if nonpadding is not None:
    nonpadding = mtf.zeros(
        inputs.mesh, batch_and_length_dims, dtype=inputs.dtype) + nonpadding
    nonpadding = mtf.reshape(nonpadding, moe_input_dims[:-1])

  if hparams.moe_gating == "top_2":
    # dispatch_tensor and combine_tensor are <B>GSEC Tensors
    dispatch_tensor, combine_tensor, loss = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=experts_dim_unsplit,
        expert_capacity_dim=expert_capacity_dim,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=nonpadding)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  expert_inputs = mtf.einsum([inputs, dispatch_tensor],
                             mtf.Shape([
                                 outer_batch_dim, experts_dim_unsplit,
                                 num_groups_dim, expert_capacity_dim,
                                 input_dim
                             ]))
  expert_inputs = mtf.reshape(
      expert_inputs,
      mtf.Shape([
          outer_batch_dim, experts_dim, batch_dim_unsplit,
          expert_capacity_dim, input_dim
      ]))

  # Now feed the expert inputs through the experts.
  h = mtf.layers.dense(
      expert_inputs,
      hidden_dim,
      expert_dims=[experts_dim],
      activation=mtf.relu,
      use_bias=False,
      variable_dtype=variable_dtype,
      name="wi")
  expert_output = mtf.layers.dense(
      h,
      output_dim,
      expert_dims=[experts_dim],
      use_bias=False,
      variable_dtype=variable_dtype,
      name="wo")

  expert_output = mtf.reshape(
      expert_output,
      mtf.Shape([
          outer_batch_dim, experts_dim_unsplit, num_groups_dim,
          expert_capacity_dim, output_dim,
      ]))

  moe_output_dims = moe_input_dims[:-1] + [output_dim]
  output = mtf.einsum([expert_output, combine_tensor],
                      mtf.Shape(moe_output_dims))
  output = mtf.reshape(output, batch_and_length_dims + [output_dim])

  return output, loss * hparams.moe_loss_coef
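# Worked example of the grouping/capacity arithmetic above (illustrative
# numbers, not from the source): with n = 16384 positions per replica,
# hparams.moe_group_size = 1024 and a mesh dimension of size 4,
# _split_into_groups can return num_groups = 16, group_size = 1024
# (16 is a multiple of 4). With capacity_factor = 1.25 and 16 experts:
#
#   expert_capacity = min(1024, int(1024 * 1.25 / 16)) = 80
#
# so each group routes at most 80 positions to each expert, and the
# dispatch/combine tensors get static <B>GSEC shapes.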
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var, targets_embedding_var, softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(inputs,
                    [[0, hparams.batch_size - actual_batch_size],
                     [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(
          x,
          hparams.encoder_layers,
          self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                          self.memory_length_dim.name)
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim, self.kv_dim,
              self.master_dtype, self.slice_dtype, self.activation_dtype)
          k = mtf.einsum([encoder_output, k_var],
                         mtf.Shape(self.batch_dims + [
                             self.heads_dim, self.memory_length_dim,
                             self.kv_dim
                         ]))
          v = mtf.einsum([encoder_output, v_var],
                         mtf.Shape(self.batch_dims + [
                             self.heads_dim, self.memory_length_dim,
                             self.kv_dim
                         ]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    raise ValueError("hparams.model_type = %s not yet supported" %
                     hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        self.batch_dims +
        [self.heads_dim, local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(
        self.batch_dims +
        [beam_dim, self.heads_dim, self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(
        self.batch_dims +
        [beam_dim, self.heads_dim, local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax" else
                   hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier +
          hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
def _layer_stack(self,
                 x,
                 layers,
                 encoder_output=None,
                 self_attention_mask=None,
                 encdec_attention_mask=None,
                 losses=None,
                 step_num=None,
                 encdec_tensors=None,
                 states=None):
  """Encoder or decoder stack.

  Args:
    x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    layers: a list of strings
    encoder_output: an optional mtf.Tensor with shape
      [<batch_dims>, encoder_length_dim, model_dim]
    self_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, memory_length_dim] containing values 0 or -inf.
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
    losses: a list to be appended-to
    step_num: an optional mtf integer Scalar (used in incremental mode)
    encdec_tensors: an optional list of num_layers tuples, each of the form
      (q_var, o_var, k, v), (used in incremental mode)
    states: an optional list of Tensors (used in incremental mode)

  Returns:
    a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  is_incremental = (step_num is not None)

  def layer_prepostprocess_dropout(x):
    if is_incremental:
      return x
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  num_layers = len(layers)
  num_layer_norms = num_layers + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

  def normalize(x):
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

  if is_incremental:
    states = list(states)
    new_states = []
  tf.logging.info("states = %s" % (states,))

  for lnum, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, lnum)):
      if layer_type == "att":
        # Self attention layer
        if is_incremental:
          y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x),
                  None,
                  self_attention_mask,
                  self.kv_dim,
                  self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="att"))
      elif layer_type == "enc_att":
        # Encoder-Decoder attention layer
        if is_incremental:
          q_var, o_var, k, v = encdec_tensors[lnum]
          x += mtf.layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="enc_att")
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x),
                  encoder_output,
                  encdec_attention_mask,
                  self.kv_dim,
                  self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="enc_att"))
      elif layer_type == "local_att":
        if is_incremental:
          y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="local_att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.masked_local_attention_1d(
                  normalize(x),
                  self.kv_dim,
                  self.heads_dim,
                  window_size=hparams.local_attention_window_size,
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  length_per_split=mtf.tensor_dim_to_size_per_split(
                      hparams.layout, hparams.mesh_shape,
                      self.max_length_dim),
                  name="local_att"))
      else:
        if is_incremental:
          # insert length dimension.
          x_shape = x.shape
          shape_with_length = mtf.Shape(
              x_shape.dims[:-1] + [mtf.Dimension("length", 1)] +
              x_shape.dims[-1:])
          x = mtf.reshape(x, shape_with_length)
        # ffn layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), layer_type, losses=losses))
        if is_incremental:
          # remove length dimension
          x = mtf.reshape(x, x_shape)

  x = layer_prepostprocess_dropout(normalize(x))
  assert not layer_norm_vars
  if is_incremental:
    return x, new_states
  else:
    return x
def _mtf_model_fn(self, features, mesh):
  features = copy.copy(features)
  hparams = self._hparams
  targets = tf.to_int32(features["targets"])
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  # pad targets to max_length
  def pad_to_max_length(x):
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x

  targets = pad_to_max_length(targets)
  for key in [
      "targets_segmentation", "targets_position", "inputs_segmentation",
      "inputs_position"
  ]:
    if key in features:
      features[key] = pad_to_max_length(features[key])
  shifted_targets = common_layers.shift_right_2d(targets)

  targets = self._import_to_batch_by_length(targets, "targets", mesh,
                                            hparams)
  shifted_targets = self._import_to_batch_by_length(
      shifted_targets, "shifted_targets", mesh, hparams)

  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation", mesh,
        hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position", mesh, hparams)
    decoder_self_attention_mask = (
        mtf.layers.attention_mask_autoregressive(
            targets_position, dtype=self.activation_dtype) +
        mtf.layers.attention_mask_same_segment(
            targets_segmentation, dtype=self.activation_dtype))
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
        targets_position, dtype=self.activation_dtype)

  def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x,
        keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  extra_losses = []
  (inputs_embedding_var, targets_embedding_var, softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "decoder":
    encoder_output = None
    encoder_decoder_attention_mask = None
  else:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation", mesh,
          hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position", mesh, hparams)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))

    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(
          x,
          hparams.encoder_layers,
          self_attention_mask=encoder_self_attention_mask,
          losses=extra_losses)

  if hparams.transformer_type == "encdec":
    if "inputs_segmentation" in features:
      encoder_decoder_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              targets_segmentation, inputs_segmentation,
              dtype=self.activation_dtype))
    else:
      encoder_decoder_attention_mask = encoder_self_attention_mask
    encoder_output = mtf.rename_dimension(x, self.length_dim.name,
                                          self.memory_length_dim.name)

  if hparams.transformer_type != "encoder":
    # DECODER
    x = (mtf.gather(targets_embedding_var, shifted_targets,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, targets_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
  logits = mtf.matmul(x, softmax_var)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets,
      self.targets_vocab_dim,
      on_value=on_value,
      off_value=off_value,
      dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l

  logits = mtf.to_float(logits)
  # combine batch dims
  if len(self.batch_dims) > 1:
    combined_batch_dim = mtf.Dimension(self.batch_dims[0].name,
                                       mtf.Shape(self.batch_dims).size)
    logits = mtf.reshape(logits,
                         [combined_batch_dim] + logits.shape.dims[-2:])
  return logits, loss
def lpt_init(lr_field,
             hr_field,
             a0,
             kvec_lr,
             kvec_hr,
             halo_size,
             hr_shape,
             lr_shape,
             part_shape,
             antialias=True,
             downsampling_factor=2,
             order=1,
             post_filtering=True,
             cosmology=Planck15):
  a = a0
  batch_dim = hr_field.shape[0]
  lnc = lr_shape[-1].size

  k_dims_lr = [d.shape[0] for d in kvec_lr]
  k_dims_hr = [d.shape[0] for d in kvec_hr]
  k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
  k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

  # Create particles on the high resolution grid
  mstate = mesh_ops.mtf_indices(
      hr_field.mesh, shape=part_shape, dtype=tf.float32)
  X = mtf.einsum([mtf.ones(hr_field.mesh, [batch_dim]), mstate],
                 output_shape=[batch_dim] + mstate.shape[:])

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

  grad_kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(
      lr_kfield, kvec_lr)
  grad_kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(
      hr_kfield, kvec_hr)

  # Reorder the low res FFTs which were transposed to y,z,x
  grad_kfield_lr = [grad_kfield_lr[2], grad_kfield_lr[0], grad_kfield_lr[1]]
  grad_kfield_hr = [grad_kfield_hr[2], grad_kfield_hr[0], grad_kfield_hr[1]]

  displacement = []
  for f, g in zip(grad_kfield_lr, grad_kfield_hr):
    f = mesh_utils.c2r3d(f, lr_shape[-3:])
    f = mtf.slicewise(
        lambda x: tf.expand_dims(
            tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
        [f],
        output_dtype=tf.float32,
        output_shape=mtf.Shape(hr_shape[0:4] + [
            mtf.Dimension('sx_block', lnc // hr_shape[1].size),
            mtf.Dimension('sy_block', lnc // hr_shape[2].size),
            mtf.Dimension('sz_block', lnc // hr_shape[3].size)
        ]),
        name='my_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    for block_size_dim in hr_shape[-3:]:
      f = mtf.pad(f, [
          halo_size // 2**downsampling_factor,
          halo_size // 2**downsampling_factor
      ], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
      f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                               halo_size // 2**downsampling_factor)
    f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
    f = mesh_utils.upsample(f, downsampling_factor)
    f = mtf.reshape(f, f.shape[:-1])

    g = mesh_utils.c2r3d(g, f.shape[-3:])
    high_shape = g.shape
    # And now we remove the large scales
    g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
    _low = mesh_utils.downsample(g, downsampling_factor, antialias=antialias)
    g = g - mtf.reshape(
        mesh_utils.upsample(_low, downsampling_factor), g.shape)
    g = mtf.reshape(g, high_shape)

    d = mesh_utils.cic_readout(f + g, X, halo_size)
    displacement.append(d)

  # Readout to particle positions
  displacement = mtf.stack([d for d in displacement], "ndim", axis=4)

  pt = PerturbationGrowth(cosmology, a=[a], a_normalize=1.0)
  DX = pt.D1(a) * displacement
  P = (a**2 * pt.f1(a) * pt.E(a)) * DX
  F = (a**2 * pt.E(a) * pt.gf(a) / pt.D1(a)) * DX
  # TODO: Implement 2nd order LPT

  # Moves the particles according to displacement
  X = X + DX

  return X, P, F
def force(state,
          lr_shape,
          hr_shape,
          kvec_lr,
          kvec_hr,
          halo_size,
          cosmology=Planck15,
          downsampling_factor=2,
          pm_nc_factor=1,
          antialias=True,
          **kwargs):
  """Estimate force on the particles given a state.

  Parameters:
  -----------
  state: tensor
    Input state tensor of shape (3, batch_size, npart, 3)

  boxsize: float
    Size of the simulation volume (Mpc/h) TODO: check units

  cosmology: astropy.cosmology
    Cosmology object

  pm_nc_factor: int
    TODO: @modichirag please add doc
  """
  X, P, F = state
  # TODO: support different factor
  assert pm_nc_factor == 1
  lnc = lr_shape[-1].size
  part_shape = X.shape

  k_dims_lr = [d.shape[0] for d in kvec_lr]
  k_dims_hr = [d.shape[0] for d in kvec_hr]
  # Reorder the FFTs which were transposed to y,z,x
  k_dims_lr = [k_dims_lr[2], k_dims_lr[0], k_dims_lr[1]]
  k_dims_hr = [k_dims_hr[2], k_dims_hr[0], k_dims_hr[1]]

  # Paint the particles on the high resolution mesh
  field = mtf.zeros(X.mesh, shape=hr_shape)
  for block_size_dim in hr_shape[-3:]:
    field = mtf.pad(field, [halo_size, halo_size], block_size_dim.name)
  field = mesh_utils.cic_paint(field, X, halo_size)
  for blocks_dim, block_size_dim in zip(hr_shape[1:4], field.shape[-3:]):
    field = mesh_ops.halo_reduce(field, blocks_dim, block_size_dim,
                                 halo_size)

  # Split the field into low and high resolution
  field = mtf.reshape(field, field.shape + [mtf.Dimension('h_dim', 1)])
  high = field
  low = mesh_utils.downsample(field, downsampling_factor, antialias=True)
  low = mtf.reshape(low, low.shape[:-1])
  hr_field = mtf.reshape(high, high.shape[:-1])
  for block_size_dim in hr_shape[-3:]:
    low = mtf.slice(low, halo_size // 2**downsampling_factor,
                    block_size_dim.size // 2**downsampling_factor,
                    block_size_dim.name)

  # Hack using a custom reshape because mesh is pretty dumb
  lr_field = mtf.slicewise(
      lambda x: x[:, 0, 0, 0], [low],
      output_dtype=tf.float32,
      output_shape=lr_shape,
      name='my_dumb_reshape',
      splittable_dims=lr_shape[:-1] + hr_shape[:4])

  lr_kfield = mesh_utils.r2c3d(lr_field, k_dims_lr)
  hr_kfield = mesh_utils.r2c3d(hr_field, k_dims_hr)

  kfield_lr = mesh_kernels.apply_longrange_kernel(
      lr_kfield, kvec_lr, r_split=0)
  kfield_lr = mesh_kernels.apply_gradient_laplace_kernel(lr_kfield, kvec_lr)
  kfield_hr = mesh_kernels.apply_longrange_kernel(
      hr_kfield, kvec_hr, r_split=0)
  kfield_hr = mesh_kernels.apply_gradient_laplace_kernel(kfield_hr, kvec_hr)

  # Reorder the low res FFTs which were transposed to y,z,x
  kfield_lr = [kfield_lr[2], kfield_lr[0], kfield_lr[1]]
  kfield_hr = [kfield_hr[2], kfield_hr[0], kfield_hr[1]]

  displacement = []
  for f, g in zip(kfield_lr, kfield_hr):
    f = mesh_utils.c2r3d(f, lr_shape[-3:])
    f = mtf.slicewise(
        lambda x: tf.expand_dims(
            tf.expand_dims(tf.expand_dims(x, axis=1), axis=1), axis=1),
        [f],
        output_dtype=tf.float32,
        output_shape=mtf.Shape(hr_shape[0:4] + [
            mtf.Dimension('sx_block', lnc // hr_shape[1].size),
            mtf.Dimension('sy_block', lnc // hr_shape[2].size),
            mtf.Dimension('sz_block', lnc // hr_shape[3].size)
        ]),
        name='my_reshape',
        splittable_dims=lr_shape[:-1] + hr_shape[1:4] + part_shape[1:3])

    for block_size_dim in hr_shape[-3:]:
      f = mtf.pad(f, [
          halo_size // 2**downsampling_factor,
          halo_size // 2**downsampling_factor
      ], block_size_dim.name)
    for blocks_dim, block_size_dim in zip(hr_shape[1:4], f.shape[-3:]):
      f = mesh_ops.halo_reduce(f, blocks_dim, block_size_dim,
                               halo_size // 2**downsampling_factor)
    f = mtf.reshape(f, f.shape + [mtf.Dimension('h_dim', 1)])
    f = mesh_utils.upsample(f, downsampling_factor)
    f = mtf.reshape(f, f.shape[:-1])

    g = mesh_utils.c2r3d(g, f.shape[-3:])
    high_shape = g.shape
    # And now we remove the large scales
    g = mtf.reshape(g, g.shape + [mtf.Dimension('h_dim', 1)])
    _low = mesh_utils.downsample(g, downsampling_factor, antialias=antialias)
    g = g - mtf.reshape(
        mesh_utils.upsample(_low, downsampling_factor), g.shape)
    g = mtf.reshape(g, high_shape)

    d = mesh_utils.cic_readout(f + g, X, halo_size)
    displacement.append(d)

  # Readout the force to particle positions
  F = mtf.stack([d for d in displacement], "ndim", axis=4)
  F = F * 1.5 * cosmology.Om0
  return X, P, F
def mtf_model_fn(self, features, mesh):
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.activation_type

  # We assume fixed vocab size for targets
  targets = tf.to_int32(features["targets"])

  # Image preprocessing, reshape into a 1D sequence and shift right.
  length = hparams.img_len * hparams.img_len * hparams.num_channels
  targets = tf.reshape(targets, [hparams.batch_size, length])
  shifted_targets = common_layers.shift_right_2d(targets)

  # Declare all the dimensions
  batch_dim = mtf.Dimension("batch", hparams.batch_size)

  def import_to_batch_by_length(x, name):
    return mtf.import_tf_tensor(
        mesh, x, mtf.Shape([batch_dim, self.length_dim]), name=name)

  targets = import_to_batch_by_length(targets, "targets")
  shifted_targets = import_to_batch_by_length(shifted_targets,
                                              "shifted_targets")

  extra_losses = []

  # Create targets content and position embeddings.
  # Create embedding var for targets and positions and do a gather.
  targets_embedding_var = mtf.get_variable(
      mesh,
      "targets_embedding",
      mtf.Shape([self.targets_vocab_dim, self.model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)

  x = mtf.gather(targets_embedding_var, shifted_targets,
                 self.targets_vocab_dim)

  # Add positional embeddings
  x += mtf.reshape(
      self.create_positional_emb_2d(targets),
      [self.length_dim, self.model_dim])

  # If conditional and input is given, add the input embedding to the target.
  # TODO(nikip): Verify conditional.
  if self.has_input and not hparams.unconditional:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = import_to_batch_by_length(inputs, "inputs")

    # Input embeddings
    inputs_embedding_var = mtf.layers.embedding(
        mesh,
        "input_embedding",
        mtf.Shape([self.inputs_vocab_dim, self.model_dim]),
        activation_dtype=activation_dtype)
    inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                            self.inputs_vocab_dim)
    x += inputs_emb

  # Image Transformer Decoder
  # [ self attention - ffn - residual + dropout] x n
  if hparams.attention_type == "local1d_spatial":
    decoder_output = local_attention1d_spatial_decoder(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
  elif hparams.attention_type == "local2d_spatial":
    decoder_output = local_attention2d_spatial_decoder(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
  elif hparams.attention_type == "local1d":
    decoder_output = local_attention1d_masked_decoder(
        x, self.kv_dim, self.heads_dim, self.feedforward_dim, hparams)
  else:
    raise ValueError("Invalid attention type.")

  # Calculate the logits and loss.
  logits = mtf.layers.dense(
      decoder_output, self.outputs_vocab_dim, name="logits")
  # Need a reshape for logits
  logits = mtf.reshape(
      logits,
      mtf.Shape([batch_dim, self.length_dim, self.outputs_vocab_dim]))
  soft_targets = mtf.one_hot(
      targets, self.outputs_vocab_dim, dtype=activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.outputs_vocab_dim)
  loss = mtf.reduce_mean(loss)
  for l in extra_losses:
    loss += l

  # Reshape logits to original target shape.
  logits = mtf.reshape(
      logits,
      mtf.Shape([
          batch_dim, self.rows_dim, self.orig_cols_dim, self.channels_dim,
          self.outputs_vocab_dim
      ]))

  return logits, loss
def _sample(self, features, mesh):
  hparams = self._hparams
  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
  if hparams.transformer_type == "encdec":
    inputs = features["inputs"]
    while len(inputs.shape.as_list()) > 2:
      inputs = tf.squeeze(inputs, axis=2)
    actual_batch_size = tf.shape(inputs)[0]
    actual_length = tf.shape(inputs)[1]
    inputs = tf.pad(
        inputs, [[0, hparams.batch_size - actual_batch_size],
                 [0, hparams.max_length - actual_length]])
    inputs = self._import_to_batch_by_length(
        inputs, "inputs", mesh, hparams)
    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.reshape(positional_embedding_var,
                     mtf.Shape([self.length_dim, self.model_dim])))
    encoder_attention_mask = (
        mtf.layers.attention_mask_ignore_padding(
            inputs, dtype=self.activation_dtype))
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_attention_mask)
    encoder_output = mtf.rename_dimension(
        x, self.length_dim.name, self.memory_length_dim.name)
    encdec_tensors = []
    for layer_num, layer_type in enumerate(hparams.decoder_layers):
      if layer_type == "enc_att":
        with tf.variable_scope("decoder/enc_att_%d/enc_att" % layer_num):
          q_var, k_var, v_var, o_var = mtf.layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.master_dtype, self.slice_dtype,
              self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  self.batch_dims +
                  [self.heads_dim, self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      else:
        encdec_tensors.append(None)
    partial_targets = None
  elif hparams.transformer_type == "decoder":
    encdec_tensors = None
    encoder_output = None
    encoder_attention_mask = None
    # Prepare partial targets.
    # In either features["inputs"] or features["targets"].
    # We force the outputs to begin with these sequences.
    partial_targets = features.get("inputs", None)
    if partial_targets is None:
      partial_targets = features.get("targets", None)
    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int32(partial_targets)
      partial_targets_batch = tf.shape(partial_targets)[0]
      partial_targets_length = tf.shape(partial_targets)[1]
      partial_targets = tf.pad(
          partial_targets,
          [[0, hparams.batch_size - partial_targets_batch],
           [0, hparams.max_length - partial_targets_length]])
      partial_targets = self._import_to_batch_by_length(
          partial_targets, "partial_targets", mesh, hparams)
  else:
    raise ValueError(
        "hparams.transformer_type = %s not yet supported"
        % hparams.transformer_type)

  local_attention_window = mtf.Dimension(
      "local_attention_window", hparams.local_attention_window_size)
  if hparams.beam_size == 1:
    ids_shape = mtf.Shape(self.batch_dims + [self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [self.heads_dim,
                                local_attention_window, self.kv_dim])
  else:
    beam_dim = mtf.Dimension("beam", hparams.beam_size)
    ids_shape = mtf.Shape(self.batch_dims + [beam_dim, self.length_dim])
    kv_shape = mtf.Shape(self.batch_dims +
                         [beam_dim, self.heads_dim,
                          self.memory_length_dim, self.kv_dim])
    local_kv_shape = mtf.Shape(self.batch_dims +
                               [beam_dim, self.heads_dim,
                                local_attention_window, self.kv_dim])

  initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
  initial_states = []
  for layer in hparams.decoder_layers:
    if layer == "att":
      initial_states.extend(
          [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)] * 2)
    elif layer == "local_att":
      initial_states.extend(
          [mtf.zeros(mesh, local_kv_shape, dtype=self.activation_dtype)] * 2)

  def logits_fn(step_num, ids, states):
    """Produce logits for this step, and new states."""
    ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
    x = (mtf.gather(targets_embedding_var, ids_this_step,
                    self.targets_vocab_dim) +
         mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
    with tf.variable_scope("decoder"):
      x, new_states = self._layer_stack(
          x,
          hparams.decoder_layers,
          encdec_attention_mask=encoder_attention_mask,
          step_num=step_num,
          encdec_tensors=encdec_tensors,
          states=states)
    logits = mtf.matmul(x, softmax_var)
    return logits, new_states

  if hparams.beam_size == 1:
    temperature = (0.0 if hparams.sampling_method == "argmax"
                   else hparams.sampling_temp)
    return mtf.beam_search.greedy_decode(
        logits_fn,
        initial_ids,
        temperature=temperature,
        initial_states=initial_states,
        forced_ids=partial_targets,
        use_tpu=hparams.use_tpu)
  else:
    if hparams.transformer_type == "encdec":
      input_length = mtf.reduce_sum(
          mtf.to_float(mtf.cast(inputs, tf.bool)),
          reduced_dim=self.length_dim)
      max_input_length = mtf.reduce_max(input_length)
      decode_length = mtf.cast(
          max_input_length * hparams.decode_length_multiplier
          + hparams.decode_length_constant, tf.int32)
    else:
      decode_length = None
    beams, unused_scores = mtf.beam_search.beam_search(
        logits_fn,
        initial_ids,
        hparams.alpha,
        states=initial_states,
        decode_length=decode_length,
        use_tpu=hparams.use_tpu,
        dtype=self.activation_dtype)
    return mtf.gather(
        beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
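# Illustrative sketch (not from the original code): for beam search over an
# encoder-decoder model, the decode length above is derived from the longest
# non-padding input in the batch. The same arithmetic in plain Python, where
# the multiplier and constant stand in for hparams.decode_length_multiplier
# and hparams.decode_length_constant:
def _decode_length_sketch(input_lengths, multiplier, constant):
  """_decode_length_sketch([12, 30, 25], 2.0, 10) -> 70"""
  return int(max(input_lengths) * multiplier + constant)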
def _layer_stack(self,
                 x,
                 layers,
                 encoder_output=None,
                 self_attention_mask=None,
                 encdec_attention_mask=None,
                 losses=None,
                 step_num=None,
                 encdec_tensors=None,
                 states=None):
  """Encoder or decoder stack.

  Args:
    x: a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
    layers: a list of strings
    encoder_output: an optional mtf.Tensor with shape
      [<batch_dims>, encoder_length_dim, model_dim]
    self_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, memory_length_dim] containing values 0 or -inf.
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.
    losses: a list to be appended-to
    step_num: an optional mtf integer Scalar (used in incremental mode)
    encdec_tensors: an optional list of num_layers tuples, each of the form
      (q_var, o_var, k, v), (used in incremental mode)
    states: an optional list of Tensors (used in incremental mode)
  Returns:
    a mtf.Tensor with shape [<batch_dims>, length_dim, model_dim]
  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  is_incremental = (step_num is not None)

  def layer_prepostprocess_dropout(x):
    if is_incremental:
      return x
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  num_layers = len(layers)
  num_layer_norms = num_layers + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

  def normalize(x):
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

  if is_incremental:
    states = list(states)
    new_states = []
    tf.logging.info("states = %s" % (states,))

  for lnum, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, lnum)):
      if layer_type == "att":
        # Self attention layer
        if is_incremental:
          y, new_k, new_v = mtf.layers.multihead_self_attention_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x), None,
                  self_attention_mask, self.kv_dim, self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="att"))
      elif layer_type == "enc_att":
        # Encoder-Decoder attention layer
        if is_incremental:
          q_var, o_var, k, v = encdec_tensors[lnum]
          x += mtf.layers.multihead_encdec_attention_incremental(
              normalize(x),
              q_var, o_var, k, v,
              encdec_attention_mask,
              name="enc_att")
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_attention(
                  normalize(x),
                  encoder_output,
                  encdec_attention_mask, self.kv_dim, self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="enc_att"))
      elif layer_type == "local_att":
        if is_incremental:
          y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
              normalize(x),
              prev_k=states.pop(0),
              prev_v=states.pop(0),
              step_num=step_num,
              master_dtype=self.master_dtype,
              slice_dtype=self.slice_dtype,
              name="local_att")
          new_states.append(new_k)
          new_states.append(new_v)
          x += y
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.masked_local_attention_1d(
                  normalize(x),
                  self.kv_dim, self.heads_dim,
                  window_size=hparams.local_attention_window_size,
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  length_per_split=mtf.tensor_dim_to_size_per_split(
                      hparams.layout, hparams.mesh_shape,
                      self.max_length_dim),
                  name="local_att"))
      elif layer_type == "compressed_att":
        if is_incremental:
          raise ValueError("compressed_att incremental not implemented")
        else:
          x += layer_prepostprocess_dropout(
              mtf.layers.multihead_self_attention_memory_compressed(
                  normalize(x),
                  mask_right=True,
                  compression_factor=hparams.compression_factor,
                  kv_channels=self.kv_dim,
                  heads=self.heads_dim,
                  dropout=hparams.attention_dropout,
                  dropout_broadcast_dims=[self.length_dim],
                  master_dtype=self.master_dtype,
                  slice_dtype=self.slice_dtype,
                  name="compressed_att"))
      else:
        if is_incremental:
          # insert length dimension.
          x_shape = x.shape
          shape_with_length = mtf.Shape(
              x_shape.dims[:-1] + [mtf.Dimension("length", 1)]
              + x_shape.dims[-1:])
          x = mtf.reshape(x, shape_with_length)
        # ffn layer
        x += layer_prepostprocess_dropout(
            self._feedforward_layer(normalize(x), layer_type, losses=losses))
        if is_incremental:
          # remove length dimension
          x = mtf.reshape(x, x_shape)

  x = layer_prepostprocess_dropout(normalize(x))
  assert not layer_norm_vars
  if is_incremental:
    return x, new_states
  else:
    return x
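# Illustrative sketch (not from the original code): in incremental mode the
# states list must be consumed in exactly the order _sample() built it --
# two tensors (k, v) per "att" or "local_att" layer, none for other layers.
# A toy check of that bookkeeping, with strings standing in for tensors:
def _check_state_order_sketch(layers):
  states = []
  for layer in layers:
    if layer in ("att", "local_att"):
      states.extend(["%s_k" % layer, "%s_v" % layer])
  consumed = list(states)
  for layer in layers:
    if layer in ("att", "local_att"):
      consumed.pop(0)  # prev_k
      consumed.pop(0)  # prev_v
  assert not consumed  # every initial state is used exactly once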
def transformer_moe_layer_v2(inputs, output_dim, hparams, train,
                             variable_dtype, layout=None, mesh_shape=None,
                             nonpadding=None):
  """2-level mixture of experts.

  Adapted from the paper https://arxiv.org/abs/1701.06538

  Note: until the algorithm and interface solidify, we pass in a
  hyperparameters dictionary in order not to complicate the interface in
  mtf_transformer.py.  Once this code moves out of "research", we should pass
  the hyperparameters separately.

  Hyperparameters used:
    hparams.moe_num_experts: number of experts
    hparams.moe_hidden_size: size of hidden layer in each expert
    hparams.moe_group_size: size of each "group" for gating purposes
    hparams.moe_capacity_factor_train: a float
    hparams.moe_capacity_factor_eval: a float
    hparams.moe_capacity_factor_second_level: a float
    hparams.moe_gating: a string
    + all hyperparameters used by _top_2_gating()

  There is one set of gating parameters at the first level and a different
  set of gating parameters per first-level expert at the second level.

  The number of parameters in the gating network is:
    (input_dim.size * hparams.moe_num_experts[0]) +
    (input_dim.size * hparams.moe_num_experts[0] * hparams.moe_num_experts[1])

  The number of parameters in the experts themselves is:
    (hparams.moe_num_experts[0] * hparams.moe_num_experts[1] *
     (input_dim.size + output_dim.size) * hparams.moe_hidden_size)

  The input is n-dimensional: [<batch_and_length_dims>, input_dim], consisting
  of the representations of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-3 experts.  The expert
  choices and the combination weights are determined by a learned gating
  function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding) that each
    sequence sends the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    but due to our hacked-up gather-by-matmul implementation, we need to
    divide the batch into "groups".  For each group, the same number of
    elements are sent to each expert.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.

  Dimensions cheat sheet:
    a, b: batch size
    l: original sequence length
    m: input depth
    n: output depth
    g, h: number of groups
    s, t: group size
    x, y: number of experts
    c, d: expert capacity

    input: [a0, b1, l, m]
    input: [a0, g1, s, m]
    dispatch_tensor_x: [a0, g1, s, x, c]
    expert_input: [a0, g1, x, c, m]
    alltoall: [a0, g, x1, c, m]
    transpose: [x1, a0, g, c, m]
    reshape: [x1, h0, s, m]
    assignment2: [x1, h0, t, y, d]
    expert_input2: [x1, h0, y, d, m]
    alltoall: [x1, h, y0, d, m]
    ...
    reverse of that

    gating params 0: [m, x]
    gating params 1: [x1, m, y]

    expert params:
      [x1, y0, m, hidden]
      [x1, y0, hidden, n]

  Args:
    inputs: a mtf.Tensor with shape [a, b, l, m]
    output_dim: a mtf.Dimension (for Transformer, this is input_dim)
    hparams: model hyperparameters
    train: a boolean
    variable_dtype: a mtf.VariableDType
    layout: optional - an input to mtf.convert_to_layout_rules
    mesh_shape: optional - an input to mtf.convert_to_shape
    nonpadding: an optional mtf.Tensor with shape [a, b, l] and the same dtype
      as inputs, consisting of ones(nonpadding) and zeros(padding).

  Returns:
    outputs: a Tensor with shape [a, b, l, n]
    loss: a mtf scalar

  Raises:
    ValueError: on unrecognized hparams.moe_gating
  """
  if nonpadding is not None:
    nonpadding = mtf.zeros(inputs.mesh, inputs.shape.dims[:-1],
                           dtype=inputs.dtype) + nonpadding
  insert_outer_batch_dim = (len(inputs.shape.dims) == 3)
  if insert_outer_batch_dim:
    inputs = mtf.reshape(
        inputs, [mtf.Dimension("outer_batch", 1)] + inputs.shape.dims)

  assert len(hparams.moe_num_experts) == 2
  a0, b1, l, m = inputs.shape.dims
  hidden_dim = mtf.Dimension("expert_hidden", hparams.moe_hidden_size)
  x1 = mtf.Dimension("expert_x", hparams.moe_num_experts[0])
  y0 = mtf.Dimension("expert_y", hparams.moe_num_experts[1])
  x = mtf.Dimension("expert_x_unsplit", hparams.moe_num_experts[0])
  y = mtf.Dimension("expert_y_unsplit", hparams.moe_num_experts[1])
  n = output_dim

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (g.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      b1.size * l.size, hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, b1))
  g1 = mtf.Dimension(b1.name, num_groups)
  g = mtf.Dimension(b1.name + "_unsplit", g1.size)
  s = mtf.Dimension("group_size_x", group_size)

  # Each sequence sends at most expert_capacity positions to each expert.
  # Static expert_capacity dimension is needed for expert batch sizes.
  if train:
    capacity_factor = hparams.moe_capacity_factor_train
  else:
    capacity_factor = hparams.moe_capacity_factor_eval
  expert_capacity = min(s.size, int((s.size * capacity_factor) / x.size))
  expert_capacity = max(expert_capacity, 4)
  c = mtf.Dimension("expert_capacity_x", expert_capacity)

  # We "cheat" here and look at the mesh shape and layout. This is to ensure
  # that the number of groups (h.size) is a multiple of the mesh dimension
  # over which those groups are split.
  num_groups, group_size = _split_into_groups(
      a0.size * g.size * c.size,
      hparams.moe_group_size,
      mtf.tensor_dim_to_mesh_dim_size(layout, mesh_shape, a0))
  t = mtf.Dimension("group_size_y", group_size)
  h0 = mtf.Dimension(a0.name, num_groups)
  h = mtf.Dimension(a0.name + "_unsplit", h0.size)

  expert_capacity = min(
      t.size,
      int((t.size * hparams.moe_capacity_factor_second_level) / y.size))
  expert_capacity = max(expert_capacity, 4)
  d = mtf.Dimension("expert_capacity_y", expert_capacity)

  # First level of expert routing.
  # Reshape the inner batch size to a multiple of group_dim g1 and
  # group_size_dim s.
  inputs = mtf.reshape(inputs, [a0, g1, s, m])
  if nonpadding is not None:
    nonpadding = mtf.reshape(nonpadding, [a0, g1, s])

  # Get the assignments for the first level.
  # dispatch_tensor_x has shape [a0, g1, s, x, c]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_x, combine_tensor_x, loss_outer = _top_2_gating(
        inputs=inputs,
        outer_expert_dims=None,
        experts_dim=x,
        expert_capacity_dim=c,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        name="outer_gating",
        importance=nonpadding)
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # Put the num_experts dimension first to make the split easier in alltoall.
  expert_inputs_x = mtf.einsum([inputs, dispatch_tensor_x],
                               [x, a0, g1, c, m])

  # We construct an "importance" Tensor for the inputs to the second-level
  # gating.  The importance of an input is 1.0 if it represents the
  # first-choice expert-group and 0.5 if it represents the second-choice
  # expert group.  This is used by the second-level gating.
  importance = mtf.reduce_sum(combine_tensor_x, output_shape=[x, a0, g1, c])
  importance = 0.5 * (
      mtf.to_float(mtf.greater(importance, 0.5)) +
      mtf.to_float(mtf.greater(importance, 0.0)))

  # First level, all to all. Here we change the split dimension from g1 to x1.
  expert_inputs_x = mtf.reshape(expert_inputs_x, mtf.Shape(
      [x1, a0, g, c, m]))
  importance = mtf.reshape(importance, [x1, a0, g, c])

  # Second level of expert routing.
  # Reshape the expert_inputs outer batch dim to be a multiple of group_dim h0
  # and group_size_dim t.
  inputs_y = mtf.reshape(expert_inputs_x, [x1, h0, t, m])
  importance = mtf.reshape(importance, [x1, h0, t])

  # Get the assignments for the second level.
  # dispatch_tensor_y has shape [x1, h0, t, y, d]
  if hparams.moe_gating == "top_2":
    dispatch_tensor_y, combine_tensor_y, loss_inner = _top_2_gating(
        inputs=inputs_y,
        outer_expert_dims=[x1],
        experts_dim=y,
        expert_capacity_dim=d,
        hparams=hparams,
        train=train,
        variable_dtype=variable_dtype,
        importance=importance,
        name="inner_gating")
  else:
    raise ValueError("unknown hparams.moe_gating=%s" % hparams.moe_gating)

  # Now create expert_inputs based on the assignments.
  # Put the num_experts dimension first to make the split easier in alltoall.
  expert_inputs_y = mtf.einsum([inputs_y, dispatch_tensor_y],
                               [y, x1, h0, d, m])

  # Second level, all to all. Here we change the split dimension from h0 to y0.
  expert_inputs_y = mtf.reshape(expert_inputs_y, mtf.Shape(
      [y0, x1, h, d, m]))

  hidden_output = mtf.layers.dense(
      expert_inputs_y, hidden_dim, expert_dims=[y0, x1],
      activation=mtf.relu, use_bias=False,
      variable_dtype=variable_dtype, name="wi")
  expert_output = mtf.layers.dense(
      hidden_output, output_dim, expert_dims=[y0, x1],
      use_bias=False, variable_dtype=variable_dtype,
      name="wo")

  # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
  # expert_output has shape [y0, x1, h, d, n]

  # alltoall
  expert_output = mtf.reshape(expert_output, mtf.Shape(
      [y, x1, h0, d, n]))

  # combine results from inner level
  output_y = mtf.einsum([expert_output, combine_tensor_y],
                        [x1, h0, t, n])

  # Reshape the combined tensor from the inner level to now contain
  # outer_batch_dim a0 and group_dim g.
  output = mtf.reshape(output_y, [x1, a0, g, c, n])

  # alltoall from expert_dim x to group_dim g1
  expert_output_x = mtf.reshape(output, mtf.Shape([x, a0, g1, c, n]))

  # combine results from outer level
  output_x = mtf.einsum([expert_output_x, combine_tensor_x],
                        [a0, g1, s, n])

  # Reshape the combined tensor to now contain inner_batch_dim b1 and the
  # original sequence length.
  output = mtf.reshape(output_x, [a0, b1, l, n])
  if insert_outer_batch_dim:
    output = mtf.reshape(output, [b1, l, n])
  return output, (loss_outer + loss_inner) * hparams.moe_loss_coef
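# Illustrative sketch (not from the original code): the expert-capacity
# computation used at both routing levels above, in plain Python. Each group
# of group_size positions may send at most `capacity` positions to any single
# expert; the floor of 4 appears to keep expert batches from degenerating
# when the capacity factor is small.
def _expert_capacity_sketch(group_size, capacity_factor, num_experts):
  """_expert_capacity_sketch(1024, 1.25, 16) -> min(1024, 80) -> 80"""
  capacity = min(group_size,
                 int((group_size * capacity_factor) / num_experts))
  return max(capacity, 4)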
def _mtf_model_fn(self, features, mesh):
  self._original_features = features
  features = copy.copy(features)
  hparams = self._hparams
  extra_losses = []
  targets = tf.to_int32(features["targets"])
  if len(targets.get_shape()) > 2:
    tf.logging.info("targets = %s" % targets)
    targets = tf.squeeze(targets, [2, 3])

  # pad targets to max_length
  def pad_to_max_length(x):
    extra_length = hparams.max_length - tf.shape(x)[1]
    x = tf.pad(x, [[0, 0], [0, extra_length]])
    x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
    return x

  targets = pad_to_max_length(targets)
  targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
  for key in ["targets_segmentation", "targets_position",
              "inputs_segmentation", "inputs_position"]:
    if key in features:
      features[key] = pad_to_max_length(features[key])

  if hparams.decoder_type == "autoregressive":
    shifted_targets = mtf.shift(
        targets, offset=1, dim=self.length_dim, wrap=False)
  elif hparams.decoder_type == "denoising":
    shifted_targets = self._noisy_targets(targets, extra_losses)
  else:
    raise ValueError(
        "unknown hparams.decoder_type = %s" % hparams.decoder_type)

  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = self._import_to_batch_by_length(
        features["targets_segmentation"], "targets_segmentation",
        mesh, hparams)
    targets_position = self._import_to_batch_by_length(
        features["targets_position"], "targets_position",
        mesh, hparams)
    decoder_self_attention_mask = mtf.layers.attention_mask_same_segment(
        targets_segmentation, dtype=self.activation_dtype)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask += mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
  else:
    targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
    if hparams.decoder_type == "autoregressive":
      decoder_self_attention_mask = mtf.layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)
    else:
      decoder_self_attention_mask = None

  def layer_prepostprocess_dropout(x):
    return mtf.dropout(
        x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
        noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

  (inputs_embedding_var,
   targets_embedding_var,
   softmax_var,
   positional_embedding_var) = self._embedding_and_softmax_vars(mesh)

  if hparams.transformer_type == "decoder":
    encoder_output = None
    encoder_decoder_attention_mask = None
  else:
    inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
    inputs = pad_to_max_length(inputs)
    inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = self._import_to_batch_by_length(
          features["inputs_segmentation"], "inputs_segmentation",
          mesh, hparams)
      inputs_position = self._import_to_batch_by_length(
          features["inputs_position"], "inputs_position",
          mesh, hparams)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_same_segment(
              inputs_segmentation, dtype=self.activation_dtype))
    else:
      inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      encoder_self_attention_mask = (
          mtf.layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))

    x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
         mtf.gather(positional_embedding_var, inputs_position,
                    self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("encoder"):
      x = self._layer_stack(x,
                            hparams.encoder_layers,
                            self_attention_mask=encoder_self_attention_mask,
                            losses=extra_losses)

    if hparams.transformer_type == "encdec":
      if "inputs_segmentation" in features:
        encoder_decoder_attention_mask = (
            mtf.layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        encoder_decoder_attention_mask = encoder_self_attention_mask
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)

  if hparams.transformer_type != "encoder":
    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)

  if (hparams.reshape_logits_hack and
      hparams.mode == tf.estimator.ModeKeys.TRAIN):
    # For some reason, the logits computation is extremely slow on TPU
    # in some cases where the batch size per core is 1.  Reshape the logits
    # and the targets to double the batch size and halve the length.
    # TODO(noam): file a bug.
    old_dims = self.batch_dims + [self.length_dim]
    new_dims = self.batch_dims[:-1] + [
        mtf.Dimension(self.batch_dims[-1].name,
                      self.batch_dims[-1].size * 2),
        mtf.Dimension(self.length_dim.name, self.length_dim.size // 2)]
    x = mtf.reshape(x, new_dims + [self.model_dim])
    targets = mtf.reshape(targets, new_dims)

  logits = mtf.matmul(x, softmax_var)
  if hparams.mode == tf.estimator.ModeKeys.TRAIN:
    logits = mtf.layers.multiplicative_jitter(logits, epsilon=1e-2)
  off_value = hparams.label_smoothing / self._targets_vocab_size
  on_value = 1.0 - hparams.label_smoothing + off_value
  soft_targets = mtf.one_hot(
      targets, self.targets_vocab_dim, on_value=on_value,
      off_value=off_value, dtype=self.activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, self.targets_vocab_dim)
  weights = mtf.layers.weights_nonzero(targets, dtype=self.activation_dtype)
  loss = mtf.reduce_mean(loss * weights)
  for l in extra_losses:
    loss += l
  if (hparams.reshape_logits_hack and
      hparams.mode == tf.estimator.ModeKeys.TRAIN):
    logits = mtf.reshape(logits, old_dims + [self.targets_vocab_dim])
  logits = mtf.to_float(logits)
  return logits, loss
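# Illustrative sketch (not from the original code): the label-smoothing
# arithmetic above keeps each row of soft_targets a valid distribution --
# on_value plus (vocab_size - 1) copies of off_value sums to 1.0:
def _label_smoothing_sketch(vocab_size, label_smoothing):
  """_label_smoothing_sketch(4, 0.1) -> (0.925, 0.025); 0.925 + 3*0.025 == 1.0"""
  off_value = label_smoothing / vocab_size
  on_value = 1.0 - label_smoothing + off_value
  return on_value, off_value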
def mtf_model_fn(self, features, mesh):
  features = copy.copy(features)
  tf.logging.info("features = %s" % features)
  hparams = self._hparams
  activation_dtype = self.set_activation_type()
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  # Declare all the dimensions
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
  filter_h_dim = mtf.Dimension("filter_height", 7)
  filter_w_dim = mtf.Dimension("filter_width", 7)
  filters = mtf.Dimension("filters", hparams.filter_sizes[0])
  rows_dim = mtf.Dimension("rows_size", hparams.rows_size)
  cols_dim = mtf.Dimension("cols_size", hparams.cols_size)
  row_blocks_dim = mtf.Dimension("row_blocks", hparams.row_blocks)
  col_blocks_dim = mtf.Dimension("col_blocks", hparams.col_blocks)
  classes_dim = mtf.Dimension("classes", 10)
  channels_dim = mtf.Dimension("channels", 3)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  inputs = features["inputs"]
  x = mtf.import_tf_tensor(
      mesh, tf.reshape(inputs, [
          hparams.batch_size,
          hparams.row_blocks,
          hparams.rows_size // hparams.row_blocks,
          hparams.col_blocks,
          hparams.num_channels*hparams.cols_size // hparams.col_blocks,
          hparams.num_channels]),
      mtf.Shape(
          [batch_dim, row_blocks_dim, rows_dim,
           col_blocks_dim, cols_dim, channels_dim]))
  x = mtf.transpose(x, [batch_dim, row_blocks_dim, col_blocks_dim,
                        rows_dim, cols_dim, channels_dim])

  x = mtf.to_float(x)
  initial_filters = mtf.get_variable(
      mesh, "init_filters",
      mtf.Shape([filter_h_dim, filter_w_dim, channels_dim, filters]))
  x = mtf.conv2d_with_blocks(
      x,
      initial_filters,
      strides=[1, 1, 1, 1],
      padding="SAME",
      h_blocks_dim=None, w_blocks_dim=col_blocks_dim)

  x = batch_norm_relu(x, is_training)

  # Conv blocks
  # [block - strided block layer - strided block layer] x n
  for layer in range(hparams.num_layers):
    layer_name = "block_layer_%d" % layer
    with tf.variable_scope(layer_name):
      # Residual block layer
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[0],
          blocks=hparams.layer_sizes[0],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer1",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[1],
          blocks=hparams.layer_sizes[1],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer2",
          row_blocks_dim=None,
          col_blocks_dim=None)
      x = block_layer(
          inputs=x,
          filters=hparams.filter_sizes[2],
          blocks=hparams.layer_sizes[2],
          strides=[1, 1, 1, 1],
          is_training=is_training,
          name="block_layer3",
          row_blocks_dim=None,
          col_blocks_dim=None)

  # Calculate the logits and loss.
  out = x
  outputs = mtf.layers.dense(
      out, hidden_dim,
      reduced_dims=out.shape.dims[-5:],
      activation=mtf.relu,
      name="dense")

  # We assume fixed vocab size for targets
  labels = tf.squeeze(tf.to_int32(features["targets"]), [2, 3])
  labels = mtf.import_tf_tensor(
      mesh, tf.reshape(labels, [hparams.batch_size]), mtf.Shape([batch_dim]))

  logits = mtf.layers.dense(outputs, classes_dim, name="logits")
  soft_targets = mtf.one_hot(labels, classes_dim, dtype=activation_dtype)
  loss = mtf.layers.softmax_cross_entropy_with_logits(
      logits, soft_targets, classes_dim)

  # Reshape logits so it doesn't break inside t2t.
  logits = mtf.reshape(
      logits,
      mtf.Shape([batch_dim, one_channel_dim, classes_dim]))
  loss = mtf.reduce_mean(loss)
  return logits, loss
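# Illustrative sketch (not from the original code): the per-example quantity
# that mtf.layers.softmax_cross_entropy_with_logits reduces in the model
# functions above, written in plain Python for a single row of logits and a
# (possibly smoothed) target distribution:
def _softmax_xent_sketch(logits, soft_targets):
  """_softmax_xent_sketch([2.0, 0.5, -1.0], [1, 0, 0]) -> ~0.241"""
  import math
  log_z = math.log(sum(math.exp(l) for l in logits))
  # Cross-entropy = -sum_i target_i * log_softmax(logits)_i
  return -sum(t * (l - log_z) for t, l in zip(soft_targets, logits))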