def transformer_layers_sharded(dp,
                               ps_devices,
                               inputs,
                               num_layers,
                               hparams,
                               self_attention_bias=None,
                               enc_output=None,
                               attention_type=AttentionType.GLOBAL,
                               name="transformer"):
  """Multi layer transformer, sharded by the data parallelism dp."""
  x = inputs
  extra_loss = tf.constant(0.0)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  for layer in range(num_layers):
    with tf.variable_scope("%s_layer_%d" % (name, layer)):
      # self-attention
      if attention_type == AttentionType.LOCAL_2D:
        y = dp(local_attention_2d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="masked_local_attention_2d"))
      elif attention_type == AttentionType.LOCAL_1D:
        y = dp(local_attention_1d(common_layers.layer_preprocess(x, hparams),
                                  hparams,
                                  attention_type="local_mask_right",
                                  q_padding="LEFT",
                                  kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOCAL:
        y = dp(local_global_attention(
            common_layers.layer_preprocess(x, hparams), self_attention_bias,
            hparams, q_padding="LEFT", kv_padding="LEFT"))
      elif attention_type == AttentionType.GLOBAL:
        self_attention_bias = dp(get_self_attention_bias(x))
        y = dp(full_self_attention(common_layers.layer_preprocess(x, hparams),
                                   self_attention_bias, hparams,
                                   q_padding="LEFT", kv_padding="LEFT"))
      x = common_layers.layer_postprocess(x, y, hparams)
      if enc_output is not None:
        y = dp(encdec_attention_1d(common_layers.layer_preprocess(x, hparams),
                                   enc_output, None, hparams))
        x = dp(common_layers.layer_postprocess, x, y, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              ps_devices,
              common_layers.layer_preprocess(x, hparams),
              hparams.mode == tf.estimator.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
          x = dp(common_layers.layer_postprocess, x, y, hparams)
        else:
          y = dp(ffn_layer, common_layers.layer_preprocess(x, hparams),
                 hparams)
          x = dp(common_layers.layer_postprocess, x, y, hparams)
  return dp(common_layers.layer_preprocess, x, hparams), extra_loss
def build_a_layer(layer_input, layer_type, is_training, config):
  """Build a single layer."""
  batch_size = tf.shape(layer_input)[0]
  net = common_layers.layer_norm(layer_input)
  if layer_type == 'att':
    net = common_attention.multihead_attention(
        query_antecedent=net,
        memory_antecedent=None,
        bias=None,
        total_key_depth=config['hidden_size'],
        total_value_depth=config['hidden_size'],
        output_depth=config['hidden_size'],
        num_heads=config['attention_heads'],
        dropout_rate=config['attention_dropout'],
        attention_type=config['attention_type'])
  elif layer_type == 'conv':
    if config['preactivate']:
      if config['activation'] == 'relu':
        net = tf.nn.relu(net)
      elif config['activation'] == 'swish':
        net = swish(net)
      elif config['activation'] == 'prelu':
        net = parametric_relu(net)
      net = separable_conv(net, config['kernel_size'], activation=None)
    else:
      if config['activation'] == 'relu':
        net = separable_conv(net, config['kernel_size'], activation=tf.nn.relu)
      elif config['activation'] == 'swish':
        net = separable_conv(net, config['kernel_size'], activation=swish)
      elif config['activation'] == 'prelu':
        net = separable_conv(
            net, config['kernel_size'], activation=parametric_relu)
  elif layer_type == 'ffn':
    net = tf.reshape(net, [-1, config['hidden_size']])
    net = expert_utils.ffn_expert_fn(
        input_size=config['hidden_size'],
        hidden_sizes=[config['ffn_hs_factor'] * config['hidden_size']],
        output_size=config['hidden_size'],
    )(net)
    net = tf.reshape(net, [batch_size, -1, config['hidden_size']])
  else:
    raise ValueError('Unknown layer type %s' % layer_type)
  if config['layer_output_dropout'] > 0.0 and is_training:
    net = tf.nn.dropout(net, 1.0 - config['layer_output_dropout'])
  return net
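# A hedged usage sketch (hypothetical, not from the original file): it shows
# the config keys that build_a_layer actually reads and how a single 'att'
# layer would be built from a [batch, length, hidden_size] input. The helper
# name and all values are illustrative assumptions.
def _example_build_a_layer(layer_input, is_training=True):
  example_config = {
      'hidden_size': 512,            # model width
      'attention_heads': 8,
      'attention_dropout': 0.1,
      'attention_type': 'dot_product',
      'preactivate': True,           # only used by 'conv' layers
      'activation': 'relu',          # 'relu', 'swish' or 'prelu'
      'kernel_size': 3,              # only used by 'conv' layers
      'ffn_hs_factor': 4,            # only used by 'ffn' layers
      'layer_output_dropout': 0.1,
  }
  return build_a_layer(layer_input, 'att', is_training, example_config)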
def conv_experts(xs, hparams, dp, ps, padding, mask, layer_id):
  """Convolutions + Mixture-of-Experts layer."""
  del layer_id  # Unused.
  train = hparams.mode == tf.estimator.ModeKeys.TRAIN
  conv_out = dp(conv_res_step, xs, hparams, padding, mask)
  loss = 0.0
  moe_hidden_sizes = [hparams.filter_size]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  moe_out, loss = expert_utils.distributed_moe(
      dp,
      ps,
      xs,
      train,
      input_size=hparams.hidden_size,
      expert_fn=expert_fn,
      num_experts=hparams.moe_num_experts,
      k=hparams.moe_k,
      loss_coef=1.0)
  return dp(residual_fn3, xs, moe_out, conv_out, hparams), loss
def body_sharded(self, sharded_features):
  # Remove dropout if not training
  hparams = self._hparams
  dp = self._data_parallelism
  if hparams.use_inputs:
    decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
    decoder_self_attention_bias = None
  else:
    targets = sharded_features["targets"]
    targets = dp(tf.squeeze, targets, 2)
    (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
        attention_lm_moe_prepare_decoder, targets, hparams)

  def preprocess(x):
    return dp(common_layers.layer_preprocess, x, hparams)

  def postprocess(x, y):
    return dp(common_layers.layer_postprocess, x, y, hparams)

  x = dp(tf.nn.dropout, decoder_input,
         1.0 - hparams.layer_prepostprocess_dropout)
  extra_loss = 0.0
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes

    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

  if not hparams.use_inputs:
    # As preprocess and postprocess are called with a batch of size one (all
    # batches concatenated), we just make sure that batch_norm is not used
    # (it should not be either way).
    assert hparams.norm_type != "batch"

    tf.logging.info("Applying Padding Remover for the attention experts")

    dp_remove_pad = functools.partial(
        dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
    dp_restore_pad = functools.partial(
        dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
  else:
    # Using identity function: No effect
    dp_remove_pad = lambda x: x
    dp_restore_pad = lambda x: x

  if hparams.attention_exp_factor != 0:
    tf.logging.info("Expand/compress tokens before sending them to experts")
    dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
        expand_batch_coordinates,
        x,
        hparams.attention_exp_factor)
    dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
        common_attention.deconv_elems_1d,
        x,
        hparams.attention_exp_factor,
        hparams.attention_exp_inputdim)
    dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
        common_attention.conv_elems_1d,
        x,
        hparams.attention_exp_factor,
        l)
  else:
    dp_expand_bc = lambda x: x
    dp_expand_x = lambda x: x
    dp_compress_x = lambda x, l: x

  def print_shape(x, suffix, debug=False):
    # To help debugging, print the input/output shapes at inference and eval.
    # Inference for long sequences can take a long time, so this helps to
    # see the progression of the generation.
    if not debug and hparams.mode == ModeKeys.TRAIN:
      return x
    return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

  with tf.name_scope("batch_coordinate_preprocess"):
    batch_coordinate = dp(get_batch_coordinate, x)
    batch_coordinate = dp_remove_pad(batch_coordinate)
    batch_coordinate = dp_expand_bc(batch_coordinate)
    batch_order = dp(get_batch_coordinate, x, axis=-1)
    batch_order = dp_remove_pad(batch_order)
    batch_order = dp_expand_bc(batch_order)

  x = dp(print_shape, x, "in")

  assert hparams.batch_size >= hparams.max_length

  num_hidden_layers = (
      len(hparams.attention_layers) or hparams.num_hidden_layers)
  for layer in xrange(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer):

      # Use the layer type defined in attention_layers
      if hparams.attention_layers:
        attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
      else:
        attention_type = hparams.attention_type

      with tf.variable_scope("attention_{}".format(attention_type)):
        if attention_type in [
            AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL]:
          attention_dot_type = (
              "local_mask_right" if hparams.attention_local else
              "dot_product")
          if attention_type == AttentionType.MULTIHEAD_FULL:
            attention_dot_type = "dot_product"
          y = dp(
              common_attention.multihead_attention,
              preprocess(x),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=attention_dot_type,
              block_length=hparams.attention_block_length,
              name="decoder_self_attention")
        elif attention_type == AttentionType.SPARSE_MULTIHEAD:
          x_in = preprocess(x)
          x_in = dp_remove_pad(x_in)
          y, loss_experts = dp(
              common_attention.multihead_attention_sparse_dot_prod,
              x_in,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              # Additional parameters
              bi=[common_attention.BatchInfo(
                  coordinates=batch_coordinate[i],
                  order=batch_order[i],  # No future mask
              ) for i in range(dp.n)],
              use_map_fn=hparams.lsh_use_map_fn,
              experts_params=dict(
                  nb_hyperplanes=hparams.lsh_num_hyperplanes,
              ),
          )
          y = dp_restore_pad(y)
          # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
          extra_loss += tf.add_n(loss_experts) / dp.n
        elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
          x_in = preprocess(x)
          y, loss_experts = dp(
              common_attention.multihead_attention_sparse_truncated,
              x_in,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              # Additional parameters
              bi=[common_attention.BatchInfo(
                  coordinates=batch_coordinate[i],
                  order=batch_order[i],  # No future mask
              ) for i in range(dp.n)],
              mask_right=True,
              experts_params=dict(
                  nb_hyperplanes=hparams.lsh_num_hyperplanes,
              ),
          )
          # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
          extra_loss += tf.add_n(loss_experts) / dp.n
        elif attention_type == AttentionType.MEMORY_EFFICIENT:
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(
              common_attention.multihead_self_attention_memory_efficient,
              x,
              decoder_self_attention_bias,
              hparams.num_heads,
              name="decoder_self_attention")
        elif attention_type == AttentionType.MULTIHEAD_REDUCED:
          y = dp(
              common_attention.multihead_self_attention_reduced,
              preprocess(x),
              factor=hparams.attention_red_factor,
              reduction_type=hparams.attention_reduction_type,
              nonlinearity=hparams.attention_nonlinearity,
              multihead_params=dict(
                  total_key_depth=
                  hparams.attention_key_channels or hparams.hidden_size,
                  total_value_depth=
                  hparams.attention_value_channels or hparams.hidden_size,
                  num_heads=hparams.num_heads,
                  dropout_rate=hparams.attention_dropout,
              ))
        elif attention_type == AttentionType.LOCAL_EXPERTS:
          x_in = preprocess(x)
          x_in = dp_remove_pad(x_in)
          x_in = dp_expand_x(x_in)
          y, loss = dp(
              common_attention.local_expert_attention,
              x_in,
              k=hparams.attention_moe_k,
              loss_coef=hparams.attention_load_balance,
              attention_num_experts=hparams.attention_num_experts,
              train=hparams.mode == ModeKeys.TRAIN,
              batch_coordinate=batch_coordinate,
              mask_right=not hparams.use_inputs,
              split_batch=bool(hparams.attention_split_batch),
              attention_num_head=hparams.attention_num_head,
              attention_kq_size=hparams.attention_kq_size,
              attention_v_size=hparams.attention_v_size)
          y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
          y = dp_restore_pad(y)
          # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
          extra_loss += tf.add_n(loss) / dp.n
        else:
          raise ValueError("Only {} supported for now.".format(
              AttentionType.get_choices()))
        x = postprocess(x, y)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              preprocess(x),
              hparams.mode == ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
        elif hparams.memory_efficient_ffn:
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(
              common_layers.conv_hidden_relu_memory_efficient,
              x,
              hparams.filter_size)
        else:
          additional_conv_params = dict()
          if hparams.use_sepconv:
            additional_conv_params = dict(
                padding="LEFT",
                # Parameters copied from the transformer model
                kernel_size=(3, 1),
                second_kernel_size=(31, 1),
            )
          y = dp(
              common_layers.conv_hidden_relu,
              preprocess(x),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout,
              **additional_conv_params
          )
        x = postprocess(x, y)
  x = preprocess(x)
  decoder_output = dp(tf.expand_dims, x, 2)
  return decoder_output, extra_loss
def body_sharded(self, sharded_features):
  train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
  dp = self._data_parallelism
  hparams = self._hparams

  def project_to_hidden(inputs):
    return common_layers.conv_block(
        inputs,
        hparams.hidden_size, [((1, 1), (3, 3))],
        first_relu=False,
        padding="SAME",
        force2d=True)

  def flatten(inputs):
    return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)

  # Project to hidden size if necessary
  if (sharded_features["inputs"][0].get_shape().as_list()[-1] !=
      hparams.hidden_size):
    inputs = dp(project_to_hidden, sharded_features["inputs"])
    inputs = dp(flatten, inputs)
  inputs_pad = dp(slicenet.embedding_to_padding, inputs)
  inputs_mask = dp(lambda x: 1.0 - x, inputs_pad)
  inputs_encoded = dp(common_layers.add_timing_signal, inputs)
  expert_loss = 0.0
  for i in xrange(hparams.num_hidden_layers):
    with tf.variable_scope("enc_layer_%d" % i):
      inputs_encoded, moe_loss = conv_experts(
          inputs_encoded, hparams, dp, self._ps_devices, "SAME", inputs_mask,
          i)
      expert_loss += tf.reduce_mean(moe_loss) * hparams.moe_loss_coef

  # If we're just predicting a class, there is no use for a decoder, return.
  if isinstance(hparams.problems[self._problem_idx].target_modality,
                modalities.ClassLabelModality):
    return inputs_encoded, tf.reduce_mean(expert_loss)

  # Decoder.
  inputs3d = dp(tf.squeeze, inputs, 2)
  inputs_encoded3d = dp(tf.squeeze, inputs_encoded, 2)
  encoder_padding = dp(common_attention.embedding_to_padding, inputs3d)
  encoder_attention_bias = dp(common_attention.attention_bias_ignore_padding,
                              encoder_padding)
  targets = dp(common_layers.flatten4d3d, sharded_features["targets"])
  target_space_emb = dp(slicenet.embed_target_space,
                        sharded_features["target_space_id"],
                        hparams.hidden_size)
  (decoder_input, decoder_self_attention_bias) = dp(prepare_decoder, targets,
                                                    target_space_emb)
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.dropout)
  for layer in xrange(hparams.num_hidden_layers):
    with tf.variable_scope("dec_layer_%d" % layer):
      with tf.variable_scope("attention"):
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            decoder_self_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="decoder_self_attention")
        z = dp(
            common_attention.multihead_attention,
            y,
            inputs_encoded3d,
            encoder_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="encdec_attention")
        x = dp(residual_fn3, x, y, z, hparams)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers.split(","):
          y, moe_loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              x,
              train,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          expert_loss += tf.reduce_mean(moe_loss)
        else:
          y = dp(
              common_layers.conv_hidden_relu,
              x,
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.dropout)
        x = dp(residual_fn2, x, y, hparams)
  x = dp(tf.expand_dims, x, 2)
  return x, tf.reduce_mean(expert_loss)
def body_sharded(self, sharded_features):
  # Remove dropout if not training
  hparams = self._hparams
  dp = self._data_parallelism
  x = dp(tf.squeeze, sharded_features["inputs"], 2)

  def preprocess(x):
    return dp(common_layers.layer_preprocess, x, hparams)

  def postprocess(x, y):
    return dp(common_layers.layer_postprocess, x, y, hparams)

  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  extra_loss = 0.0
  ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.mask_right:

    def _bias(x):
      return common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(x)[1])

    bias = dp(_bias, x)
  else:
    bias = tf.zeros([1, 1, 1, 1])
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes

    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

  batch_coordinate = dp(get_batch_coordinate, x)

  layers = hparams.layers.strip(",").split(",")
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      if _should_preprocess(layer_type):
        x = preprocess(x)
      if layer_type == "timing":
        y = dp(common_attention.add_timing_signal_nd, x)
      elif layer_type == "pos_emb":
        y = dp(
            common_attention.add_positional_embedding_nd,
            x,
            hparams.max_length,
            name="pos_emb")
      elif layer_type == "att":
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            bias,  # bias
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
      elif layer_type == "att_grouped":
        multiplicative_overhead = (
            hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
            else hparams.multiplicative_overhead_eval)
        y, loss = dp(
            common_attention.grouped_attention_multihead,
            x,
            x,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            num_groups=hparams.attention_num_groups,
            memory_target_density=hparams.memory_target_density,
            multiplicative_overhead=multiplicative_overhead,
            make_image_summary=hparams.attention_image_summary,
            mask_right=hparams.mask_right,
        )
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "att_memory_efficient":
        assert hparams.layer_preprocess_sequence == "n"
        y = dp(common_attention.multihead_self_attention_memory_efficient, x,
               bias, hparams.num_heads)
      elif layer_type == "att_local":
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            None,  # bias
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=("local_mask_right"
                            if hparams.mask_right else "local_unmasked"),
            block_length=hparams.local_attention_window,
            block_width=hparams.local_attention_window)
      elif layer_type == "att_pseudolocal":
        # This is an inefficient implementation of local attention, for the
        # purpose of testing model quality.

        def _pseudolocal_bias(x):
          return common_attention.attention_bias_local(
              common_layers.shape_list(x)[1], hparams.local_attention_window,
              0 if hparams.mask_right else hparams.local_attention_window)

        pseudolocal_bias = dp(_pseudolocal_bias, x)
        y = dp(common_attention.multihead_attention, x, None,
               pseudolocal_bias,
               hparams.attention_key_channels or hparams.hidden_size,
               hparams.attention_value_channels or hparams.hidden_size,
               hparams.hidden_size, hparams.num_heads,
               hparams.attention_dropout)
      elif layer_type == "att_local_expert":
        y, loss = dp(
            common_attention.local_expert_attention,
            x,
            k=hparams.attention_moe_k,
            loss_coef=hparams.attention_load_balance,
            attention_num_experts=hparams.attention_num_experts,
            train=hparams.mode == ModeKeys.TRAIN,
            batch_coordinate=batch_coordinate,
            mask_right=hparams.mask_right,
            split_batch=bool(hparams.attention_split_batch),
            attention_kq_size=hparams.attention_kq_size,
            attention_v_size=hparams.attention_v_size)
        # TODO(avaswani, epot, noam): Do we need to divide by num shards ?
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "att_lsh":
        if hparams.lsh_truncated:
          attention_fn = common_attention.multihead_attention_sparse_truncated
        else:
          attention_fn = common_attention.multihead_attention_sparse_dot_prod
        y, loss = dp(
            attention_fn,
            x,
            None,
            None,  # Bias is computed inside
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            # Additional parameters
            bi=[
                common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=None,  # No future mask
                ) for i in range(dp.n)
            ],
            use_map_fn=False,
            experts_params=dict(nb_hyperplanes=4,))
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "moe":
        y, loss = expert_utils.distributed_moe(
            dp,
            self._ps_devices,
            x,
            hparams.mode == ModeKeys.TRAIN,
            input_size=hparams.hidden_size,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_loss += loss
      elif layer_type == "ffn":
        y = dp(
            expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
                                       hparams.hidden_size),
            dp(expert_utils.flatten_all_but_last, x))
        y = dp(expert_utils.reshape_like, y, x)
      elif layer_type == "conv":
        y = dp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding="SAME",
        )
      else:
        assert False, "unknown sublayer %s" % layer_type
      if _should_postprocess(layer_type):
        x = postprocess(x, y)
      else:
        x = y
  x = preprocess(x)
  decoder_output = dp(tf.expand_dims, x, 2)
  return decoder_output, extra_loss
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
  """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a list of Tensors
    extra_loss: an optional scalar
  """
  layers = hparams.layers.strip(",").split(",")
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes

    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

  # scaled_dot_product_attention_with_projections uses a 3d attention bias
  # (no heads), where multihead_attention uses 4d attention bias.
  attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  extra_losses = []
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":  # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":  # normalize
        x = mp(common_layers.apply_norm, x, hparams.norm_type,
               hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":  # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":  # mix across shards

        def _split(t):
          return tuple(
              tf.split(t, [mix_size, hparams.hidden_size - mix_size], 2))

        to_mix, to_keep = mp(_split, x)
        mixed = expert_utils.all_reduce_ring(to_mix, mp)
        mixed = mp(tf.multiply, mixed, mp.n**-0.5)
        x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":  # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(common_attention.scaled_dot_product_attention_simple,
               q, x, x, attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":  # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(common_layers.dense_relu_dense, x,
               hparams.filter_size, hparams.hidden_size)
      elif layer_type == "conv":  # convolution
        x = mp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding=padding,
        )
      elif layer_type == "moe":
        # mixture of experts - each model shard has its own local MoE.
        x, loss = mp(
            expert_utils.local_moe,
            x,
            train=hparams.mode == tf_estimator.ModeKeys.TRAIN,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_losses.extend(loss)
      else:
        assert False, "unknown sublayer %s" % layer_type
  if extra_losses:
    extra_loss = tf.add_n(extra_losses)
  else:
    extra_loss = None
  return x, extra_loss
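# Hedged sketch (hypothetical values, not from the original file): _super_stack
# is driven entirely by the comma-separated hparams.layers string, parsed with
# strip(",").split(",") exactly as above. One block of normalize / attention /
# mix / dropout / accumulate followed by a normalize / MoE block could be
# spelled like this:
example_layers_spec = "n,att,m,d,a,n,moe,m,d,a"
example_sublayers = example_layers_spec.strip(",").split(",")
# -> ['n', 'att', 'm', 'd', 'a', 'n', 'moe', 'm', 'd', 'a']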
def model_fn_body_sharded(self, sharded_features):
  hparams = self._hparams
  dp = self._data_parallelism
  targets = sharded_features["targets"]
  inputs = sharded_features["inputs"]
  target_space = sharded_features["target_space_id"]

  inputs = dp(common_layers.flatten4d3d, inputs)
  targets = dp(common_layers.flatten4d3d, targets)

  def preprocess(x):
    return dp(common_layers.layer_preprocess, x, hparams)

  def postprocess(x, y):
    return dp(common_layers.layer_postprocess, x, y, hparams)

  (encoder_input, encoder_self_attention_bias,
   encoder_decoder_attention_bias) = dp(
       transformer.transformer_prepare_encoder, inputs, target_space, hparams)
  (decoder_input, decoder_self_attention_bias) = dp(
      transformer.transformer_prepare_decoder, targets, hparams)
  encoder_input = dp(tf.nn.dropout, encoder_input,
                     1.0 - hparams.layer_prepostprocess_dropout)
  decoder_input = dp(tf.nn.dropout, decoder_input,
                     1.0 - hparams.layer_prepostprocess_dropout)
  extra_loss = 0
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  expert_fn = expert_utils.ffn_expert_fn(
      hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  x = encoder_input
  for layer in xrange(hparams.num_hidden_layers):
    with tf.variable_scope("encoder_layer_%d" % layer):
      with tf.variable_scope("encoder_self_attention"):
        y = dp(
            common_attention.multihead_attention,
            preprocess(x),
            None,
            encoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        x = postprocess(x, y)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_encoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              preprocess(x),
              hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
        else:
          y = dp(
              common_layers.conv_hidden_relu,
              preprocess(x),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout)
        x = postprocess(x, y)
  encoder_output = preprocess(x)
  x = decoder_input
  for layer in xrange(hparams.num_hidden_layers):
    with tf.variable_scope("decoder_layer_%d" % layer):
      with tf.variable_scope("decoder_self_attention"):
        y = dp(
            common_attention.multihead_attention,
            preprocess(x),
            None,
            decoder_self_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        x = postprocess(x, y)
      with tf.variable_scope("encoder_decoder_attention"):
        y = dp(
            common_attention.multihead_attention,
            preprocess(x),
            encoder_output,
            encoder_decoder_attention_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
        x = postprocess(x, y)
      with tf.variable_scope("ffn"):
        if str(layer) in hparams.moe_layers_decoder.split(","):
          y, loss = expert_utils.distributed_moe(
              dp,
              self._ps_devices,
              preprocess(x),
              hparams.mode == tf.contrib.learn.ModeKeys.TRAIN,
              input_size=hparams.hidden_size,
              expert_fn=expert_fn,
              num_experts=hparams.moe_num_experts,
              k=hparams.moe_k,
              loss_coef=hparams.moe_loss_coef)
          extra_loss += loss
        else:
          y = dp(
              common_layers.conv_hidden_relu,
              preprocess(x),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout)
        x = postprocess(x, y)
  x = preprocess(x)
  decoder_output = dp(tf.expand_dims, x, 2)
  return decoder_output, extra_loss
def transformer_ffn_layer(x,
                          hparams,
                          pad_remover=None,
                          conv_padding="LEFT",
                          nonpadding_mask=None,
                          losses=None,
                          cache=None,
                          decode_loop_step=None,
                          readout_filter_size=0):
  """Feed-forward layer in the transformer.

  Args:
    x: a Tensor of shape [batch_size, length, hparams.hidden_size]
    hparams: hyperparameters for model
    pad_remover: an expert_utils.PadRemover object tracking the padding
      positions. If provided, when using convolutional settings, the padding
      is removed before applying the convolution, and restored afterward. This
      can give a significant speedup.
    conv_padding: a string - either "LEFT" or "SAME".
    nonpadding_mask: an optional Tensor with shape [batch_size, length].
      needed for convolutional layers with "SAME" padding.
      Contains 1.0 in positions corresponding to nonpadding.
    losses: optional list onto which to append extra training losses
    cache: dict, containing tensors which are the results of previous
      attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
      Only used for inference on TPU.
    readout_filter_size: if it's greater than 0, then it will be used instead
      of filter_size

  Returns:
    a Tensor of shape [batch_size, length, hparams.hidden_size]

  Raises:
    ValueError: If losses arg is None, but layer generates extra losses.
  """
  ffn_layer = hparams.ffn_layer
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  if ffn_layer == "conv_hidden_relu":
    # Backwards compatibility
    ffn_layer = "dense_relu_dense"
  if ffn_layer == "dense_relu_dense":
    # In simple convolution mode, use `pad_remover` to speed up processing.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_FILTER_DENSE,
        value={
            "filter_size": hparams.filter_size,
            "use_bias": "True",
            "activation": mlperf_log.RELU
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_FFN_OUTPUT_DENSE,
        value={
            "hidden_size": hparams.hidden_size,
            "use_bias": "True",
        })
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_RELU_DROPOUT, value=hparams.relu_dropout)
    if pad_remover:
      original_shape = common_layers.shape_list(x)
      # Collapse `x` across examples, and remove padding positions.
      x = tf.reshape(x, tf.concat([[-1], original_shape[2:]], axis=0))
      x = tf.expand_dims(pad_remover.remove(x), axis=0)
    conv_output = common_layers.dense_relu_dense(
        x,
        hparams.filter_size,
        hparams.hidden_size,
        dropout=hparams.relu_dropout,
        dropout_broadcast_dims=relu_dropout_broadcast_dims)
    if pad_remover:
      # Restore `conv_output` to the original shape of `x`, including padding.
      conv_output = tf.reshape(
          pad_remover.restore(tf.squeeze(conv_output, axis=0)),
          original_shape)
    return conv_output
  elif ffn_layer == "conv_relu_conv":
    return common_layers.conv_relu_conv(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        first_kernel_size=hparams.conv_first_kernel,
        second_kernel_size=1,
        padding=conv_padding,
        nonpadding_mask=nonpadding_mask,
        dropout=hparams.relu_dropout,
        cache=cache,
        decode_loop_step=decode_loop_step)
  elif ffn_layer == "parameter_attention":
    return common_attention.parameter_attention(
        x, hparams.parameter_attention_key_channels or hparams.hidden_size,
        hparams.parameter_attention_value_channels or hparams.hidden_size,
        hparams.hidden_size, readout_filter_size or hparams.filter_size,
        hparams.num_heads,
        hparams.attention_dropout)
  elif ffn_layer == "conv_hidden_relu_with_sepconv":
    return common_layers.conv_hidden_relu(
        x,
        readout_filter_size or hparams.filter_size,
        hparams.hidden_size,
        kernel_size=(3, 1),
        second_kernel_size=(31, 1),
        padding="LEFT",
        dropout=hparams.relu_dropout)
  elif ffn_layer == "sru":
    return common_layers.sru(x)
  elif ffn_layer == "local_moe_tpu":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe_tpu(
        x,
        hparams.filter_size // 2,
        hparams.hidden_size,
        hparams.moe_num_experts,
        overhead=overhead,
        loss_coef=hparams.moe_loss_coef)
  elif ffn_layer == "local_moe":
    overhead = (
        hparams.moe_overhead_train
        if hparams.mode == tf.estimator.ModeKeys.TRAIN else
        hparams.moe_overhead_eval)
    ret, loss = expert_utils.local_moe(
        x,
        True,
        expert_utils.ffn_expert_fn(hparams.hidden_size, [hparams.filter_size],
                                   hparams.hidden_size),
        hparams.moe_num_experts,
        k=hparams.moe_k,
        hparams=hparams)
    losses.append(loss)
    return ret
  else:
    assert ffn_layer == "none"
    return x
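# Hedged usage sketch (hypothetical helper, not part of the original file):
# for the "local_moe" setting transformer_ffn_layer appends its auxiliary
# load-balancing loss to `losses`, so a list should be passed in; for the
# other ffn_layer settings shown above the list is simply left untouched.
def _example_transformer_ffn(x, hparams):
  extra_losses = []
  y = transformer_ffn_layer(x, hparams, losses=extra_losses)
  return y, extra_losses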
def model_fn_body_sharded(self, sharded_features): # Remove dropout if not training hparams = self._hparams dp = self._data_parallelism targets = sharded_features["targets"] targets = dp(tf.squeeze, targets, 2) def preprocess(x): return dp(common_layers.layer_preprocess, x, hparams) def postprocess(x, y): return dp(common_layers.layer_postprocess, x, y, hparams) (decoder_input, decoder_self_attention_bias) = dp(attention_lm_moe_prepare_decoder, targets, hparams) x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) extra_loss = 0.0 moe_hidden_sizes = [ int(s) for s in hparams.moe_hidden_sizes.split(",") ] if hparams.diet_experts: hsize, = moe_hidden_sizes def _diet_expert(x): return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) expert_fn = _diet_expert else: expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) for layer in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("attention_{}".format( hparams.attention_moe_type)): x = preprocess(x) if hparams.attention_moe_type == AttentionMoeType.NONE: y = dp(common_attention.multihead_attention, x, None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, name="decoder_self_attention") elif hparams.attention_moe_type == AttentionMoeType.LOCAL: y, loss = dp( common_attention.local_expert_attention, x, k=2, loss_coef=1e-2, attention_num_experts=hparams.attention_num_experts, train=hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, mask_right=True, attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss) / dp.n else: raise ValueError("Only {} supported for now.".format( AttentionMoeType.get_choices())) x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers.split(","): y, loss = expert_utils.distributed_moe( dp, self._ps_devices, preprocess(x), hparams.mode == tf.contrib.learn.ModeKeys.TRAIN, input_size=hparams.hidden_size, expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) extra_loss += loss else: y = dp(common_layers.conv_hidden_relu, preprocess(x), hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout) x = postprocess(x, y) x = preprocess(x) decoder_output = dp(tf.expand_dims, x, 2) return decoder_output, extra_loss
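# Which layers get the mixture-of-experts FFN is controlled by the comma-separated
# hparams.moe_layers string tested above with `str(layer) in hparams.moe_layers.split(",")`;
# all other layers fall back to the plain conv_hidden_relu FFN. A small self-contained
# sketch of that selection (the "2,5" value is illustrative, not a library default):
def uses_moe(layer_index, moe_layers="2,5"):
  return str(layer_index) in moe_layers.split(",")


print([layer for layer in range(6) if uses_moe(layer)])  # [2, 5]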
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"): """A stack of super_lm layers. Args: inputs: a list of Tensors attention_bias: list of bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model mp: a Parallelism object padding: a string Returns: y: a list of Tensors """ layers = hparams.layers.strip(",").split(",") ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")] # scaled_dot_product_attention_with_projections uses a 3d attention bias # (no heads), where multihead_attention uses 4d attention bias. mix_size = int(hparams.mix_fraction * hparams.hidden_size) attention_bias_3d = mp(tf.squeeze, attention_bias, 1) accumulator = inputs x = inputs for layer_num, layer_type in enumerate(layers): with tf.variable_scope("%s_%d" % (layer_type, layer_num)): tf.logging.info("%s_%d" % (layer_type, layer_num)) if layer_type == "a": # accumulate accumulator = mp(tf.add, x, accumulator) x = accumulator elif layer_type == "n": # normalize x = mp(common_layers.apply_norm, x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon) elif layer_type == "d": # dropout x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) elif layer_type == "m": # mix across shards def _split(t): return tuple( tf.split(t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) mixed = common_layers.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n**-0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": # single-head attention q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False, name="q_transform") x = mp(common_attention.scaled_dot_product_attention_simple, q, x, x, attention_bias_3d) x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False, name="o_transform") elif layer_type == "multihead-att": # multi-head attention x = mp( common_attention.multihead_attention, x, None, attention_bias, # bias hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout) elif layer_type == "ffn": y = mp( expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes, hparams.hidden_size), mp(expert_utils.flatten_all_but_last, x)) x = mp(expert_utils.reshape_like, y, x) elif layer_type == "conv": # convolution x = mp( common_layers.conv1d, x, hparams.hidden_size, hparams.kernel_height, activation=tf.nn.relu, padding=padding, ) else: assert False, "unknown sublayer %s" % layer_type return x
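# The "m" sublayer above mixes a fraction of each shard's hidden state across all model
# shards: split off `mix_size` channels, all-reduce them over the shards, rescale by
# n**-0.5, and concatenate back with the untouched channels. A NumPy sketch of that data
# flow, assuming common_layers.all_reduce_ring computes a per-position sum over shards:
import numpy as np


def mix_across_shards(xs, mix_fraction=0.5):
  """xs: list of per-shard arrays of shape [batch, length, hidden]."""
  n = len(xs)
  hidden = xs[0].shape[-1]
  mix_size = int(mix_fraction * hidden)
  to_mix = [x[..., :mix_size] for x in xs]    # channels shared across shards
  to_keep = [x[..., mix_size:] for x in xs]   # channels that stay local
  mixed = sum(to_mix) * n ** -0.5             # all-reduce (sum), then rescale
  return [np.concatenate([mixed, keep], axis=-1) for keep in to_keep]


shards = [np.random.randn(2, 3, 8).astype(np.float32) for _ in range(4)]
assert mix_across_shards(shards)[0].shape == (2, 3, 8)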
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"): """A stack of super_lm layers. Args: inputs: a list of Tensors attention_bias: list of bias Tensor for self-attention (see common_attention.attention_bias()) hparams: hyperparameters for model mp: a Parallelism object padding: a string Returns: y: a list of Tensors extra_loss: an optional scalar """ layers = hparams.layers.strip(",").split(",") moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")] if hparams.diet_experts: hsize, = moe_hidden_sizes def _diet_expert(x): return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) expert_fn = _diet_expert else: expert_fn = expert_utils.ffn_expert_fn( hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) # scaled_dot_product_attention_with_projections uses a 3d attention bias # (no heads), where multihead_attention uses 4d attention bias. attention_bias_3d = mp(tf.squeeze, attention_bias, 1) mix_size = int(hparams.mix_fraction * hparams.hidden_size) accumulator = inputs x = inputs extra_losses = [] for layer_num, layer_type in enumerate(layers): with tf.variable_scope("%s_%d" % (layer_type, layer_num)): tf.logging.info("%s_%d" % (layer_type, layer_num)) if layer_type == "a": # accumulate accumulator = mp(tf.add, x, accumulator) x = accumulator elif layer_type == "n": # normalize x = mp(common_layers.apply_norm, x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon) elif layer_type == "d": # dropout x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout) elif layer_type == "m": # mix across shards def _split(t): return tuple(tf.split( t, [mix_size, hparams.hidden_size - mix_size], 2)) to_mix, to_keep = mp(_split, x) mixed = common_layers.all_reduce_ring(to_mix, mp) mixed = mp(tf.multiply, mixed, mp.n ** -0.5) x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep) elif layer_type == "att": # single-head attention q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False, name="q_transform") x = mp( common_attention.scaled_dot_product_attention_simple, q, x, x, attention_bias_3d) x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False, name="o_transform") elif layer_type == "multihead-att": # multi-head attention x = mp( common_attention.multihead_attention, x, None, attention_bias, # bias hparams.multihead_attention_key_channels or hparams.hidden_size, hparams.multihead_attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.multihead_attention_num_heads, hparams.attention_dropout) elif layer_type == "ffn": x = mp( common_layers.dense_relu_dense, x, hparams.filter_size, hparams.hidden_size) elif layer_type == "conv": # convolution x = mp( common_layers.conv1d, x, hparams.hidden_size, hparams.kernel_height, activation=tf.nn.relu, padding=padding, ) elif layer_type == "moe": # mixture of experts - each model shard has its own local MoE. x, loss = mp( expert_utils.local_moe, x, train=hparams.mode == tf.estimator.ModeKeys.TRAIN, expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) extra_losses.extend(loss) else: assert False, "unknown sublayer %s" % layer_type if extra_losses: extra_loss = tf.add_n(extra_losses) else: extra_loss = None return x, extra_loss
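# Both _super_stack variants are driven by the comma-separated hparams.layers string,
# where each token names one sublayer of the stack and unknown tokens trigger the assert.
# A tiny sketch of how such a string expands into a sublayer sequence (the example string
# and the name table are illustrative, not the library's defaults):
def describe_stack(layers="n,att,m,d,a,n,ffn,m,d,a"):
  names = {
      "a": "accumulate (residual)",
      "n": "normalize",
      "d": "dropout",
      "m": "mix across shards",
      "att": "single-head attention",
      "multihead-att": "multi-head attention",
      "ffn": "feed-forward",
      "conv": "convolution",
      "moe": "local mixture of experts",
  }
  return ["%s_%d: %s" % (layer_type, i, names[layer_type])
          for i, layer_type in enumerate(layers.strip(",").split(","))]


for line in describe_stack():
  print(line)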
def model_fn_body_sharded(self, sharded_features): # ========= Prepare the input and target ========= hparams = self._hparams dp = self._data_parallelism targets = sharded_features["targets"] inputs = sharded_features["inputs"] target_space = sharded_features["target_space_id"] inputs = dp(common_layers.flatten4d3d, inputs) targets = dp(common_layers.flatten4d3d, targets) def dp_preprocess(x): return dp(common_layers.layer_preprocess, x, hparams) def dp_postprocess(x, y): return dp(common_layers.layer_postprocess, x, y, hparams) (encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias) = dp( transformer.transformer_prepare_encoder, inputs, target_space, hparams) (decoder_input, decoder_self_attention_bias) = dp( transformer.transformer_prepare_decoder, targets, hparams) encoder_input = dp(tf.nn.dropout, encoder_input, 1.0 - hparams.layer_prepostprocess_dropout) decoder_input = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) cache = dict(extra_loss=0) moe_hidden_sizes = [ int(s) for s in hparams.moe_hidden_sizes.split(",") ] expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) # ========= Define some utils decorators ========= def prepostprocess(fct): """Add pre and post processing.""" # WARNING: Should be applied after dp (pre/post-process use dp and # can be applied to functions which don't use dp) @functools.wraps(fct) def decorated(x, *args, **kwargs): x = dp_preprocess(x) y = fct(x, *args, **kwargs) return dp_postprocess(x, y) return decorated def dp_wrapper(fct): """Encapsulate the function in a data parallelism object.""" @functools.wraps(fct) def decorated(*args, **kwargs): return dp(fct, *args, **kwargs) return decorated def add_kwargs( fct, enco_kwargs=None, deco_kwargs=None, endeco_kwargs=None, # Enco-deco attention: overwrite deco_kwargs ): """Allow different arguments for the encoder and decoder.""" # WARNING: If this decorator is applied before dp_wrapper, the kwargs # may not be correctly dispatched across the devices. 
@functools.wraps(fct) def decorated(*args, **kwargs): current_scope = tf.contrib.framework.get_name_scope() if "/encoder/" in current_scope: kwargs.update(enco_kwargs or {}) elif "/decoder/" in current_scope: kwargs.update(deco_kwargs or {}) if "/att_ende_" in current_scope: kwargs.update(endeco_kwargs or {}) return fct(*args, **kwargs) return decorated def capture_extra_loss(fct, loss_coef=1.0): """Capture the additional loss.""" @functools.wraps(fct) def decorated(*args, **kwargs): y, loss = fct(*args, **kwargs) cache["extra_loss"] += loss * loss_coef return y return decorated def remove_kwargs(fct, extra_params): """Remove some unused parameters.""" @functools.wraps(fct) def decorated(*args, **kwargs): for k in extra_params: # Remove the extra params kwargs.pop(k, None) return fct(*args, **kwargs) return decorated # def pad_remover(fct): # """Remove/restore the padding on the input.""" # @functools.wraps(fct) # def decorated(x, *args, **kwargs): # x = pad_remover.remove(x) # x = fct(x, *args, **kwargs) # x = pad_remover.restore(x) # return x # return decorated # ========= Define the available layers ========= total_key_depth = hparams.attention_key_channels or hparams.hidden_size total_value_depth = hparams.attention_value_channels or hparams.hidden_size # Multi-head full attention layer multihead_attention = partial( common_attention.multihead_attention, total_key_depth=total_key_depth, total_value_depth=total_value_depth, output_depth=hparams.hidden_size, num_heads=hparams.num_heads, dropout_rate=hparams.attention_dropout, ) multihead_attention = dp_wrapper(multihead_attention) multihead_attention = add_kwargs( # After dp to correctly dispatch kwargs multihead_attention, enco_kwargs={"bias": encoder_self_attention_bias}, deco_kwargs={"bias": decoder_self_attention_bias}, endeco_kwargs={"bias": encoder_decoder_attention_bias}, ) multihead_attention = prepostprocess(multihead_attention) # Local attention layer # Reuse same parameters as multihead_attention (dp and pre/post-processing # already applied) # Only works for self attention. Always mask the future. local_attention = partial( multihead_attention, block_length=hparams.attention_loc_block_length, attention_type="local_mask_right", ) # Memory-compressed multihead self attention layer # Only works for self attention. Always mask the future. 
compressed_attention = partial( common_attention.multihead_self_attention_reduced, factor=hparams.attention_red_factor, nonlinearity=hparams.attention_red_nonlinearity, reduction_type=hparams.attention_red_type, multihead_params=dict( total_key_depth=total_key_depth, total_value_depth=total_value_depth, num_heads=hparams.num_heads, dropout_rate=hparams.attention_dropout, )) compressed_attention = remove_kwargs(compressed_attention, ["memory_antecedent"]) compressed_attention = dp_wrapper(compressed_attention) compressed_attention = prepostprocess(compressed_attention) # Mixture of experts layer distributed_moe = partial( expert_utils.distributed_moe, dp, self._ps_devices, train=hparams.mode == tf.estimator.ModeKeys.TRAIN, input_size=hparams.hidden_size, expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) distributed_moe = capture_extra_loss(distributed_moe) distributed_moe = prepostprocess(distributed_moe) # FC layer conv_hidden_relu = partial( common_layers.conv_hidden_relu, hidden_size=hparams.filter_size, output_size=hparams.hidden_size, dropout=hparams.relu_dropout, ) conv_hidden_relu = dp_wrapper(conv_hidden_relu) conv_hidden_relu = prepostprocess(conv_hidden_relu) # Separable convolution layer # Reuse conv_hidden_relu (dp and pre/post-processing already applied) # Mask the future for the decoder only sep_conv_relu = partial( conv_hidden_relu, # Parameters copied from the transformer model, could add hparams kernel_size=(3, 1), second_kernel_size=(31, 1), ) sep_conv_relu = add_kwargs( sep_conv_relu, enco_kwargs={"padding": "SAME"}, deco_kwargs={"padding": "LEFT"}, # Mask future for decoder ) # This dictionary contains the list of all available layers available_layers = dict( # Attention layers a=multihead_attention, # Standard multihead full attention loc=local_attention, # Local attention red=compressed_attention, # Memory-compressed attention mem=None, # Memory efficient # Feed-forward layers moe=distributed_moe, # Mixture of experts layer sep=sep_conv_relu, # Separable convolution fc=conv_hidden_relu, # Fully connected ) def extract_layer_types(layer_types): """Parse the layer string. Args: layer_types (str): String containing the network architecture. See top file comment for examples of format. Returns: list[tuple[str, str]]: Encoder layers: list of (attention, feed-forward) list[tuple[str, str, str]]: Decoder layers: list of (self-attention, enc-dec attention, feed-forward) """ # If the architecture has not explicitly been set, we just construct a # standard transformer with the fallback values if not layer_types: layer_types = SEP_LAYER.join([hparams.default_att] * hparams.num_hidden_layers) # If encoder not explicitly defined, the encoder will have the same # structure as the decoder layer_types = layer_types.split(SEP_ENCODEC) if len(layer_types) == 1: layer_types *= 2 # Some models don't need the encoder (ex: language modeling) # TODO(epot): What are the other conditions (has_input ?) if hparams.prepend_mode != "none": layer_types[0] = "" # Extend the blocks and fill them with the default values if not specified final_layers = ([], []) for i, blocks_str in enumerate(layer_types): for blocks_str in blocks_str.split(SEP_LAYER): if not blocks_str: continue blocks_list = blocks_str.split(SEP_FF) # If needed, use the fallback values for the layer_types. If the # encoder is empty, do not use the enco-deco attention. 
self_att = blocks_list[0] or hparams.default_att ende_att = hparams.default_att if layer_types[0] else "_" ff = hparams.default_ff if len(blocks_list) > 1: ff = blocks_list[-1] if len(blocks_list) == 3: ende_att = blocks_list[1] if i == 0: # Encoder blocks_tuple = (self_att, ff) elif i == 1: # Decoder blocks_tuple = (self_att, ende_att, ff) final_layers[i].append(blocks_tuple) return final_layers # ========= Construct the transformer encoder and decoder ========= encoder_layers, decoder_layers = extract_layer_types( hparams.layer_types) # Display the encoder-decoder architecture def print_layer(name, layers): tf.logging.info("{} architecture:".format(name)) for i, l in enumerate(layers): tf.logging.info(" * Layer {}: {}".format(i, " - ".join(l))) print_layer("Encoder", encoder_layers) print_layer("Decoder", decoder_layers) encoder_outputs = [] x = encoder_input with tf.variable_scope("encoder"): for layer_num, block_types in enumerate(encoder_layers): # Each encoder layer is composed of two blocks: # * self-attention block # * feed-forward block att_type, ff_type = block_types with tf.variable_scope("layer_{}".format(layer_num)): with tf.variable_scope("att_{}".format(att_type)): x = available_layers[att_type]( x, memory_antecedent=None, ) with tf.variable_scope("ff_{}".format(ff_type)): x = available_layers[ff_type](x) encoder_outputs.append(x) if encoder_outputs: encoder_outputs[-1] = dp_preprocess(x) x = decoder_input with tf.variable_scope("decoder"): for layer_num, block_types in enumerate(decoder_layers): # Each decoder layer is composed of three blocks: # * self-attention block # * enco-deco attention block (optional) # * feed-forward block self_att_type, att_ende_type, ff_type = block_types with tf.variable_scope("layer_{}".format(layer_num)): with tf.variable_scope( "self_att_{}".format(self_att_type)): x = available_layers[self_att_type]( x, memory_antecedent=None, ) with tf.variable_scope( "att_ende_{}".format(att_ende_type)): # Only add the enco-deco attention layer if there is an encoder if encoder_outputs: x = available_layers[att_ende_type]( x, memory_antecedent=encoder_outputs[-1], ) with tf.variable_scope("ff_{}".format(ff_type)): x = available_layers[ff_type](x) # If normalization is done in layer_preprocess, then it should also be # done on the output, since the output can grow very large, being the sum # of a whole stack of unnormalized layer outputs. x = dp_preprocess(x) decoder_output = dp(tf.expand_dims, x, 2) return decoder_output, cache["extra_loss"]
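# extract_layer_types above turns the hparams.layer_types string into per-layer block
# tuples that the encoder/decoder construction loops consume. A standalone sketch of the
# same parsing, assuming the separators are SEP_ENCODEC="#", SEP_LAYER="/" and SEP_FF="-"
# and with hparams replaced by plain arguments (defaults and example string are
# illustrative):
def parse_layer_types(layer_types, default_att="a", default_ff="fc",
                      num_layers=2, use_encoder=True,
                      sep_encodec="#", sep_layer="/", sep_ff="-"):
  if not layer_types:
    layer_types = sep_layer.join([default_att] * num_layers)
  parts = layer_types.split(sep_encodec)
  if len(parts) == 1:
    parts *= 2                      # encoder mirrors the decoder when not given
  if not use_encoder:
    parts[0] = ""
  final_layers = ([], [])
  for i, blocks_str in enumerate(parts):
    for block in blocks_str.split(sep_layer):
      if not block:
        continue
      blocks_list = block.split(sep_ff)
      self_att = blocks_list[0] or default_att
      ende_att = default_att if parts[0] else "_"
      ff = blocks_list[-1] if len(blocks_list) > 1 else default_ff
      if len(blocks_list) == 3:
        ende_att = blocks_list[1]
      final_layers[i].append((self_att, ff) if i == 0 else (self_att, ende_att, ff))
  return final_layers


# Encoder: local, compressed, local attention; decoder: full attention, with a
# separable-conv feed-forward in the first decoder layer.
print(parse_layer_types("loc/red/loc#a-sep/a/a"))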
def model_fn_body_sharded(self, sharded_features): # Remove dropout if not training hparams = self._hparams dp = self._data_parallelism if hparams.use_inputs: decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2) decoder_self_attention_bias = None else: targets = sharded_features["targets"] targets = dp(tf.squeeze, targets, 2) (decoder_input, decoder_self_attention_bias, pad_remover) = dp(attention_lm_moe_prepare_decoder, targets, hparams) def preprocess(x): return dp(common_layers.layer_preprocess, x, hparams) def postprocess(x, y): return dp(common_layers.layer_postprocess, x, y, hparams) x = dp(tf.nn.dropout, decoder_input, 1.0 - hparams.layer_prepostprocess_dropout) extra_loss = 0.0 moe_hidden_sizes = [ int(s) for s in hparams.moe_hidden_sizes.split(",") ] if hparams.diet_experts: hsize, = moe_hidden_sizes def _diet_expert(x): return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params()) expert_fn = _diet_expert else: expert_fn = expert_utils.ffn_expert_fn(hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size) if (hparams.attention_type == AttentionType.LOCAL_EXPERTS and not hparams.use_inputs): # As preprocess and postprocess are called with batch of size one (all # batches concatenated), we just make sure that batch_norm is not used ( # it should not be either way) assert hparams.norm_type != "batch" tf.logging.info( "Applying Padding Remover for the attention experts") dp_remove_pad = functools.partial(dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode) dp_restore_pad = functools.partial(dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode) else: # Using identity function: No effect dp_remove_pad = lambda x: x dp_restore_pad = lambda x: x def print_shape(x, suffix, debug=False): # To help debugging, print the input/output shapes at inference and eval # Inference for long sequences can take a long time, so this helps to # see the progression of the generation if not debug and hparams.mode == ModeKeys.TRAIN: return x return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix)) batch_coordinate = dp(get_batch_coordinate, x) batch_coordinate = dp_remove_pad(batch_coordinate) x = dp(print_shape, x, "in") x = dp_remove_pad(x) x = dp(print_shape, x, "in_flat") assert hparams.batch_size >= hparams.max_length for layer in range(hparams.num_hidden_layers): with tf.variable_scope("layer_%d" % layer): with tf.variable_scope("attention_{}".format( hparams.attention_type)): if hparams.attention_type == AttentionType.MULTIHEAD: y = dp(common_attention.multihead_attention, preprocess(x), None, decoder_self_attention_bias, hparams.attention_key_channels or hparams.hidden_size, hparams.attention_value_channels or hparams.hidden_size, hparams.hidden_size, hparams.num_heads, hparams.attention_dropout, attention_type=("local_mask_right" if hparams.attention_local else "dot_product"), name="decoder_self_attention") elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT: assert hparams.layer_preprocess_sequence == "n" y = dp(common_attention.multihead_self_attention_memory_efficient, x, decoder_self_attention_bias, hparams.num_heads, name="decoder_self_attention") elif hparams.attention_type == AttentionType.LOCAL_EXPERTS: y, loss = dp( common_attention.local_expert_attention, preprocess(x), k=hparams.attention_moe_k, loss_coef=hparams.attention_load_balance, attention_num_experts=hparams.
attention_num_experts, train=hparams.mode == ModeKeys.TRAIN, batch_coordinate=batch_coordinate, mask_right=not hparams.use_inputs, split_batch=bool(hparams.attention_split_batch), attention_kq_size=hparams.attention_kq_size, attention_v_size=hparams.attention_v_size) # TODO(avaswani, epot, noam): Do we need to divide by num shards ? extra_loss += tf.add_n(loss) / dp.n else: raise ValueError("Only {} supported for now.".format( AttentionType.get_choices())) x = postprocess(x, y) with tf.variable_scope("ffn"): if str(layer) in hparams.moe_layers.split(","): y, loss = expert_utils.distributed_moe( dp, self._ps_devices, preprocess(x), hparams.mode == ModeKeys.TRAIN, input_size=hparams.hidden_size, expert_fn=expert_fn, num_experts=hparams.moe_num_experts, k=hparams.moe_k, loss_coef=hparams.moe_loss_coef) extra_loss += loss elif hparams.memory_efficient_ffn: assert hparams.layer_preprocess_sequence == "n" y = dp(common_layers.conv_hidden_relu_memory_efficient, x, hparams.filter_size) else: x_in = preprocess(x) additional_conv_params = dict() if hparams.use_sepconv: # Restore padding so sequences don't attend to each other # restore_pad will apply a reshape like x_ref, to restore the # original shape. Here this works because the last dimension is # constant between the output of attention and the original input # but that needn't be the case in general. x_in = dp_restore_pad(x_in) additional_conv_params = dict( padding="LEFT", # Parameters copied from the transformer model kernel_size=(3, 1), second_kernel_size=(31, 1), ) y = dp(common_layers.conv_hidden_relu, x_in, hparams.filter_size, hparams.hidden_size, dropout=hparams.relu_dropout, **additional_conv_params) if hparams.use_sepconv: y = dp_remove_pad(y) x = postprocess(x, y) x = preprocess(x) x = dp_restore_pad(x) decoder_output = dp(tf.expand_dims, x, 2) return decoder_output, extra_loss
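# When the padding remover packs every sequence of a shard into one long "batch of size
# one", the experts still need to tell examples apart, which is why get_batch_coordinate
# is computed and threaded into local_expert_attention above. A NumPy sketch of that idea
# (illustrative only; the library computes the coordinate and the masking differently):
import numpy as np


def batch_coordinates(lengths):
  """One integer coordinate per token after padding removal."""
  return np.concatenate([np.full(length, i) for i, length in enumerate(lengths)])


def same_example_mask(coords):
  # Tokens may only attend to tokens carrying the same batch coordinate.
  return (coords[:, None] == coords[None, :]).astype(np.float32)


coords = batch_coordinates([3, 2])   # two examples of lengths 3 and 2, padding removed
print(coords)                        # [0 0 0 1 1]
print(same_example_mask(coords))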