def model_fn_body_sharded(self, sharded_features):
  """Runs the sublayer stack described by hparams.layers over sharded inputs."""
  # Remove dropout if not training
  hparams = self._hparams
  dp = self._data_parallelism
  x = dp(tf.squeeze, sharded_features["inputs"], 2)

  def preprocess(x):
    return dp(common_layers.layer_preprocess, x, hparams)

  def postprocess(x, y):
    return dp(common_layers.layer_postprocess, x, y, hparams)

  x = dp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
  extra_loss = 0.0
  ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.mask_right:

    def _bias(x):
      return common_attention.attention_bias_lower_triangle(tf.shape(x)[1])

    bias = dp(_bias, x)
  else:
    bias = tf.zeros([1, 1, 1, 1])
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes

    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)

  batch_coordinate = dp(get_batch_coordinate, x)

  layers = hparams.layers.strip(",").split(",")
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      if _should_preprocess(layer_type):
        x = preprocess(x)
      if layer_type == "timing":
        y = dp(common_attention.add_timing_signal_nd, x)
      elif layer_type == "pos_emb":
        y = dp(
            common_attention.add_positional_embedding_nd,
            x,
            hparams.max_length,
            name="pos_emb")
      elif layer_type == "att":
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            bias,  # bias
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
      elif layer_type == "att_grouped":
        multiplicative_overhead = (
            hparams.multiplicative_overhead if hparams.mode == ModeKeys.TRAIN
            else hparams.multiplicative_overhead_eval)
        y, loss = dp(
            common_attention.grouped_attention_multihead,
            x,
            x,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            num_groups=hparams.attention_num_groups,
            memory_target_density=hparams.memory_target_density,
            multiplicative_overhead=multiplicative_overhead,
            make_image_summary=hparams.attention_image_summary,
            mask_right=hparams.mask_right,
        )
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "att_memory_efficient":
        assert hparams.layer_preprocess_sequence == "n"
        y = dp(common_attention.multihead_self_attention_memory_efficient, x,
               bias, hparams.num_heads)
      elif layer_type == "att_local":
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            None,  # bias
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            attention_type=("local_mask_right"
                            if hparams.mask_right else "local_unmasked"),
            block_length=hparams.local_attention_window,
            block_width=hparams.local_attention_window)
      elif layer_type == "att_pseudolocal":
        # This is an inefficient implementation of local attention, for the
        # purpose of testing model quality.
        def _pseudolocal_bias(x):
          return common_attention.attention_bias_local(
              tf.shape(x)[1], hparams.local_attention_window,
              0 if hparams.mask_right else hparams.local_attention_window)

        pseudolocal_bias = dp(_pseudolocal_bias, x)
        y = dp(
            common_attention.multihead_attention,
            x,
            None,
            pseudolocal_bias,
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
      elif layer_type == "att_local_expert":
        y, loss = dp(
            common_attention.local_expert_attention,
            x,
            k=hparams.attention_moe_k,
            loss_coef=hparams.attention_load_balance,
            attention_num_experts=hparams.attention_num_experts,
            train=hparams.mode == ModeKeys.TRAIN,
            batch_coordinate=batch_coordinate,
            mask_right=hparams.mask_right,
            split_batch=bool(hparams.attention_split_batch),
            attention_kq_size=hparams.attention_kq_size,
            attention_v_size=hparams.attention_v_size)
        # TODO(avaswani, epot, noam): Do we need to divide by num shards?
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "att_lsh":
        if hparams.lsh_truncated:
          attention_fn = common_attention.multihead_attention_sparse_truncated
        else:
          attention_fn = common_attention.multihead_attention_sparse_dot_prod
        y, loss = dp(
            attention_fn,
            x,
            None,
            None,  # Bias is computed inside
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            # Additional parameters
            bi=[
                common_attention.BatchInfo(
                    coordinates=batch_coordinate[i],
                    order=None,  # No future mask
                ) for i in range(dp.n)
            ],
            use_map_fn=False,
            experts_params=dict(nb_hyperplanes=4,))
        extra_loss += tf.add_n(loss) / dp.n
      elif layer_type == "moe":
        y, loss = expert_utils.distributed_moe(
            dp,
            self._ps_devices,
            x,
            hparams.mode == ModeKeys.TRAIN,
            input_size=hparams.hidden_size,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_loss += loss
      elif layer_type == "ffn":
        y = dp(
            expert_utils.ffn_expert_fn(hparams.hidden_size, ffn_hidden_sizes,
                                       hparams.hidden_size),
            dp(expert_utils.flatten_all_but_last, x))
        y = dp(expert_utils.reshape_like, y, x)
      elif layer_type == "conv":
        y = dp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding="SAME",
        )
      else:
        assert False, "unknown sublayer %s" % layer_type
      if _should_postprocess(layer_type):
        x = postprocess(x, y)
      else:
        x = y
  x = preprocess(x)

  decoder_output = dp(tf.expand_dims, x, 2)
  return decoder_output, extra_loss
def body_sharded(self, sharded_features):
  """Runs the sharded decoder stack; returns (decoder_output, extra_loss)."""
  # Remove dropout if not training
  hparams = self._hparams
  dp = self._data_parallelism
  if hparams.use_inputs:
    decoder_input = dp(tf.squeeze, sharded_features["inputs"], 2)
    decoder_self_attention_bias = None
  else:
    targets = sharded_features["targets"]
    targets = dp(tf.squeeze, targets, 2)
    (decoder_input, decoder_self_attention_bias, pad_remover) = dp(
        attention_lm_moe_prepare_decoder, targets, hparams)

  def preprocess(x):
    return dp(common_layers.layer_preprocess, x, hparams)

  def postprocess(x, y):
    return dp(common_layers.layer_postprocess, x, y, hparams)

  x = dp(tf.nn.dropout, decoder_input,
         1.0 - hparams.layer_prepostprocess_dropout)
  extra_loss = 0.0

  if not hparams.use_inputs:
    # As preprocess and postprocess are called with a batch of size one (all
    # batches concatenated), we just make sure that batch norm is not used
    # (it should not be either way).
    assert hparams.norm_type != "batch"

    tf.logging.info("Applying Padding Remover for the attention experts")

    dp_remove_pad = functools.partial(
        dp, remove_pad, pad_remover=pad_remover, mode=hparams.mode)
    dp_restore_pad = functools.partial(
        dp, restore_pad, ref_x=x, pad_remover=pad_remover, mode=hparams.mode)
  else:
    # Using identity function: No effect
    dp_remove_pad = lambda x: x
    dp_restore_pad = lambda x: x

  if hparams.attention_exp_factor != 0:
    tf.logging.info("Expand/compress tokens before sending them to experts")
    dp_expand_bc = lambda x: dp(  # pylint: disable=g-long-lambda
        expand_batch_coordinates, x, hparams.attention_exp_factor)
    dp_expand_x = lambda x: dp(  # pylint: disable=g-long-lambda
        common_attention.deconv_elems_1d, x, hparams.attention_exp_factor,
        hparams.attention_exp_inputdim)
    dp_compress_x = lambda x, l: dp(  # pylint: disable=g-long-lambda
        common_attention.conv_elems_1d, x, hparams.attention_exp_factor, l)
  else:
    dp_expand_bc = lambda x: x
    dp_expand_x = lambda x: x
    dp_compress_x = lambda x, l: x

  def print_shape(x, suffix, debug=False):
    # To help debugging, print the input/output shapes at inference and eval.
    # Inference for long sequences can take a long time, so this helps to
    # see the progress of the generation.
    if not debug and hparams.mode == ModeKeys.TRAIN:
      return x
    return tf.Print(x, [tf.shape(x)], "shape_x_{}".format(suffix))

  with tf.name_scope("batch_coordinate_preprocess"):
    batch_coordinate = dp(get_batch_coordinate, x)
    batch_coordinate = dp_remove_pad(batch_coordinate)
    batch_coordinate = dp_expand_bc(batch_coordinate)
    batch_order = dp(get_batch_coordinate, x, axis=-1)
    batch_order = dp_remove_pad(batch_order)
    batch_order = dp_expand_bc(batch_order)

  x = dp(print_shape, x, "in")

  assert hparams.batch_size >= hparams.max_length

  num_hidden_layers = (
      len(hparams.attention_layers) or hparams.num_hidden_layers)
  for layer in range(num_hidden_layers):
    with tf.variable_scope("layer_%d" % layer):

      # Use the layer type defined in attention_layers
      if hparams.attention_layers:
        attention_type = LAYER_SYMBOLS[hparams.attention_layers[layer]]
      else:
        attention_type = hparams.attention_type

      with tf.variable_scope("attention_{}".format(attention_type)):
        if attention_type in [
            AttentionType.MULTIHEAD, AttentionType.MULTIHEAD_FULL
        ]:
          attention_dot_type = (
              "local_mask_right" if hparams.attention_local else "dot_product")
          if attention_type == AttentionType.MULTIHEAD_FULL:
            attention_dot_type = "dot_product"
          y = dp(
              common_attention.multihead_attention,
              preprocess(x),
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=attention_dot_type,
              block_length=hparams.attention_block_length,
              name="decoder_self_attention")
        elif attention_type == AttentionType.SPARSE_MULTIHEAD:
          x_in = preprocess(x)
          x_in = dp_remove_pad(x_in)
          y, loss_experts = dp(
              common_attention.multihead_attention_sparse_dot_prod,
              x_in,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              # Additional parameters
              bi=[
                  common_attention.BatchInfo(
                      coordinates=batch_coordinate[i],
                      order=batch_order[i],  # No future mask
                  ) for i in range(dp.n)
              ],
              use_map_fn=hparams.lsh_use_map_fn,
              experts_params=dict(
                  nb_hyperplanes=hparams.lsh_num_hyperplanes,),
          )
          y = dp_restore_pad(y)

          # TODO(avaswani, epot, noam): Do we need to divide by num shards?
          extra_loss += tf.add_n(loss_experts) / dp.n
        elif attention_type == AttentionType.SPARSE_MULTIHEAD_TRUNCATED:
          x_in = preprocess(x)
          y, loss_experts = dp(
              common_attention.multihead_attention_sparse_truncated,
              x_in,
              None,
              None,  # Bias is computed inside
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              # Additional parameters
              bi=[
                  common_attention.BatchInfo(
                      coordinates=batch_coordinate[i],
                      order=batch_order[i],  # No future mask
                  ) for i in range(dp.n)
              ],
              mask_right=True,
              experts_params=dict(
                  nb_hyperplanes=hparams.lsh_num_hyperplanes,),
          )

          # TODO(avaswani, epot, noam): Do we need to divide by num shards?
          extra_loss += tf.add_n(loss_experts) / dp.n
        elif attention_type == AttentionType.MEMORY_EFFICIENT:
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(
              common_attention.multihead_self_attention_memory_efficient,
              x,
              decoder_self_attention_bias,
              hparams.num_heads,
              name="decoder_self_attention")
        elif attention_type == AttentionType.MULTIHEAD_REDUCED:
          y = dp(
              common_attention.multihead_self_attention_reduced,
              preprocess(x),
              factor=hparams.attention_red_factor,
              reduction_type=hparams.attention_reduction_type,
              nonlinearity=hparams.attention_nonlinearity,
              multihead_params=dict(
                  total_key_depth=hparams.attention_key_channels or
                  hparams.hidden_size,
                  total_value_depth=hparams.attention_value_channels or
                  hparams.hidden_size,
                  num_heads=hparams.num_heads,
                  dropout_rate=hparams.attention_dropout,
              ))
        elif attention_type == AttentionType.LOCAL_EXPERTS:
          x_in = preprocess(x)
          x_in = dp_remove_pad(x_in)
          x_in = dp_expand_x(x_in)
          y, loss = dp(
              common_attention.local_expert_attention,
              x_in,
              k=hparams.attention_moe_k,
              loss_coef=hparams.attention_load_balance,
              attention_num_experts=hparams.attention_num_experts,
              train=hparams.mode == ModeKeys.TRAIN,
              batch_coordinate=batch_coordinate,
              mask_right=not hparams.use_inputs,
              split_batch=bool(hparams.attention_split_batch),
              attention_num_head=hparams.attention_num_head,
              attention_kq_size=hparams.attention_kq_size,
              attention_v_size=hparams.attention_v_size)
          y = dp_compress_x(y, x[0].get_shape().as_list()[-1])
          y = dp_restore_pad(y)

          # TODO(avaswani, epot, noam): Do we need to divide by num shards?
          extra_loss += tf.add_n(loss) / dp.n
        else:
          raise ValueError("Only {} supported for now.".format(
              AttentionType.get_choices()))
        x = postprocess(x, y)
      with tf.variable_scope("ffn"):
        if hparams.memory_efficient_ffn:
          assert hparams.layer_preprocess_sequence == "n"
          y = dp(common_layers.conv_hidden_relu_memory_efficient, x,
                 hparams.filter_size)
        else:
          additional_conv_params = {}
          if hparams.use_sepconv:
            additional_conv_params = dict(
                padding="LEFT",
                # Parameters copied from the transformer model
                kernel_size=(3, 1),
                second_kernel_size=(31, 1),
            )
          y = dp(
              common_layers.conv_hidden_relu,
              preprocess(x),
              hparams.filter_size,
              hparams.hidden_size,
              dropout=hparams.relu_dropout,
              **additional_conv_params)
        x = postprocess(x, y)
  x = preprocess(x)
  decoder_output = dp(tf.expand_dims, x, 2)
  return decoder_output, extra_loss