# Imports assumed by the excerpts below (tensor2tensor 1.x on TensorFlow 1.x).
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
from tensor2tensor.utils import diet
from tensor2tensor.utils import expert_utils

import tensorflow as tf


def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
  """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a list of Tensors
    extra_loss: an optional scalar
  """
  layers = hparams.layers.strip(",").split(",")
  moe_hidden_sizes = [int(s) for s in hparams.moe_hidden_sizes.split(",")]
  if hparams.diet_experts:
    hsize, = moe_hidden_sizes

    def _diet_expert(x):
      return diet.diet_expert(x, hsize, diet.diet_adam_optimizer_params())

    expert_fn = _diet_expert
  else:
    expert_fn = expert_utils.ffn_expert_fn(
        hparams.hidden_size, moe_hidden_sizes, hparams.hidden_size)
  # scaled_dot_product_attention_with_projections uses a 3d attention bias
  # (no heads), where multihead_attention uses 4d attention bias.
  attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  extra_losses = []
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":  # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":  # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":  # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":  # mix across shards
        def _split(t):
          return tuple(tf.split(
              t, [mix_size, hparams.hidden_size - mix_size], 2))
        to_mix, to_keep = mp(_split, x)
        mixed = common_layers.all_reduce_ring(to_mix, mp)
        mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
        x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":  # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(
            common_attention.scaled_dot_product_attention_simple,
            q, x, x, attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":  # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(
            common_layers.dense_relu_dense, x,
            hparams.filter_size, hparams.hidden_size)
      elif layer_type == "conv":  # convolution
        x = mp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding=padding,
        )
      elif layer_type == "moe":
        # mixture of experts - each model shard has its own local MoE.
        x, loss = mp(
            expert_utils.local_moe,
            x,
            train=hparams.mode == tf.estimator.ModeKeys.TRAIN,
            expert_fn=expert_fn,
            num_experts=hparams.moe_num_experts,
            k=hparams.moe_k,
            loss_coef=hparams.moe_loss_coef)
        extra_losses.extend(loss)
      else:
        assert False, "unknown sublayer %s" % layer_type
  if extra_losses:
    extra_loss = tf.add_n(extra_losses)
  else:
    extra_loss = None
  return x, extra_loss
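# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code above): _super_stack is
# driven entirely by the comma-separated hparams.layers string, one token per
# sublayer. The helper and the example spec below are made up for
# demonstration; only the parsing mirrors the function above.
SUBLAYER_NAMES = {
    "a": "residual accumulate",
    "n": "layer norm",
    "d": "dropout",
    "m": "mix across shards",
    "att": "single-head attention",
    "multihead-att": "multi-head attention",
    "ffn": "feed-forward",
    "conv": "1d convolution",
    "moe": "local mixture of experts",
}


def _describe_super_stack(layer_spec="n,att,m,d,a,n,ffn,m,d,a,"):
  """Names each sublayer of a (made-up) layer spec, mirroring the parsing."""
  layers = layer_spec.strip(",").split(",")
  return ["%s_%d: %s" % (layer_type, i, SUBLAYER_NAMES[layer_type])
          for i, layer_type in enumerate(layers)]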
def body(self, features):
  # Remove dropout if not training
  hparams = self._hparams
  ps_devices = self._ps_devices
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  model_devices = [ps_devices[i // shards_per_device]
                   for i in xrange(hparams.num_model_shards)]
  print("model_devices = %s" % model_devices)
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = features["targets_raw"]
  targets = tf.squeeze(targets, 3)
  targets = tf.squeeze(targets, 2)
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  decoder_input = mp(
      common_layers.embedding, shifted_targets, vocab_size,
      hparams.hidden_size,
      multiplier=hparams.hidden_size**0.5,
      symbol_dropout_rate=hparams.symbol_dropout)
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
  else:
    targets_position = None
  if hparams.pos == "timing":
    if targets_position is None:
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
    else:
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
  decoder_input = mp(
      tf.nn.dropout, decoder_input,
      1.0 - hparams.layer_prepostprocess_dropout)
  decoder_output, extra_loss = _super_stack(
      decoder_input, decoder_self_attention_bias, hparams, mp)
  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
  logits = common_layers.all_reduce_ring(logits, mp)
  logits = mp(tf.multiply, logits, mp.n ** -0.5)
  # We now have identical logits on all shards.
  # Shard 0 gets returned to the estimator.
  logits_shard_0 = logits[0]
  logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
  logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
  # On each device, we compute the loss for a part of the batch.
  # This is faster than computing the whole loss on one shard.
  mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])

  def _loss_for_shard(logits, targets, shard):
    if mp.n > 1:
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
    return common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)

  num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
  # override training loss so that it is not computed externally.
  losses = {"training": tf.add_n(num) / tf.add_n(denom)}
  if extra_loss is not None:
    losses["extra"] = extra_loss
  return logits_shard_0, losses
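# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code above): the loss at the end
# of body() is computed as per-shard (numerator, denominator) pairs and then
# combined as add_n(num) / add_n(denom). The toy function below is a stand-in
# for common_layers.padded_cross_entropy (label smoothing ignored); it shows
# why the sharded aggregate equals the whole-batch loss.
import numpy as np


def _demo_sharded_loss(n_shards=4):
  """Checks that add_n(num) / add_n(denom) equals the whole-batch loss."""

  def toy_padded_xent(logits, targets):
    # Stand-in for padded_cross_entropy: targets == 0 is treated as padding;
    # returns (sum of per-token losses, number of non-padding tokens).
    weights = (targets != 0).astype(np.float64)
    log_probs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))
    token_loss = -np.take_along_axis(log_probs, targets[..., None], -1)[..., 0]
    return (token_loss * weights).sum(), weights.sum()

  rng = np.random.RandomState(0)
  logits = rng.randn(8, 5, 11)            # [batch, length, vocab]
  targets = rng.randint(0, 11, (8, 5))    # 0 = padding

  num, denom = toy_padded_xent(logits, targets)         # whole batch
  nums, denoms = zip(*(toy_padded_xent(l, t)            # one slice per device
                       for l, t in zip(np.array_split(logits, n_shards),
                                       np.array_split(targets, n_shards))))
  assert np.isclose(num / denom, sum(nums) / sum(denoms))
  return num / denom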
def _layer_stack(mp,
                 inputs,
                 self_attention_bias,
                 layers,
                 hparams,
                 encoder_output=None,
                 encoder_decoder_attention_bias=None):
  """A stack of layers.

  Args:
    mp: a Parallelism object
    inputs: a list of Tensors
    self_attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    layers: a string
    hparams: hyperparameters for model
    encoder_output: optional list of tensors
    encoder_decoder_attention_bias: optional list of tensors

  Returns:
    y: a list of Tensors
  """
  layers = layers.strip(",").split(",")
  # scaled_dot_product_attention_with_projections uses a 3d attention bias
  # (no heads), where multihead_attention uses 4d attention bias.
  self_attention_bias_3d = mp(tf.squeeze, self_attention_bias, 1)
  if encoder_decoder_attention_bias is not None:
    encoder_decoder_attention_bias_3d = mp(
        tf.squeeze, encoder_decoder_attention_bias, 1)
  relu_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "relu_dropout_broadcast_dims", "")))
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  accumulator = inputs
  x = inputs
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":  # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":  # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":  # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":
        if mix_size > 0:
          # mix across shards
          def _split(t):
            return tuple(tf.split(
                t, [mix_size, hparams.hidden_size - mix_size], 2))
          to_mix, to_keep = mp(_split, x)
          mixed = common_layers.all_reduce_ring(to_mix, mp)
          mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
          x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":  # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(common_attention.scaled_dot_product_attention_simple,
               q, x, x, self_attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "enc-att":  # single-head attention over encoder
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        assert encoder_output is not None
        x = mp(common_attention.scaled_dot_product_attention_simple,
               q, encoder_output, encoder_output,
               encoder_decoder_attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":  # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            self_attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "enc-multihead-att":  # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            encoder_output,
            encoder_decoder_attention_bias,  # bias
            hparams.multihead_attention_key_channels or hparams.hidden_size,
            hparams.multihead_attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.multihead_attention_num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        x = mp(
            common_layers.dense_relu_dense, x,
            hparams.filter_size,
            hparams.hidden_size,
            dropout=hparams.relu_dropout,
            dropout_broadcast_dims=[relu_dropout_broadcast_dims] * mp.n)
      else:
        assert False, "unknown sublayer %s" % layer_type
  return x
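# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code above): the "m" sublayer in
# _layer_stack splits each shard's features, sums the "mix" slices across
# shards (the effect of all_reduce_ring), rescales by n ** -0.5, and concats
# the shared mixture back onto each shard's private features. Shapes below are
# invented for the demo.
import numpy as np


def _demo_mix_across_shards(n=4, hidden_size=8, mix_size=2):
  """Shows the split / all-reduce / rescale / concat dance of the "m" layer."""
  rng = np.random.RandomState(0)
  x = [rng.randn(3, 5, hidden_size) for _ in range(n)]  # one tensor per shard

  to_mix = [t[..., :mix_size] for t in x]               # part shared across shards
  to_keep = [t[..., mix_size:] for t in x]              # part kept private

  # all_reduce_ring leaves every shard holding the sum of the mixed slices;
  # scaling by n ** -0.5 keeps the variance comparable to the inputs.
  mixed = sum(to_mix) * n ** -0.5

  x = [np.concatenate([mixed, keep], axis=-1) for keep in to_keep]
  assert all(t.shape == (3, 5, hidden_size) for t in x)
  return x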
def body(self, features):
  hparams = self._hparams
  ps_devices = self._ps_devices
  single_device = (len(ps_devices) == 1)
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  model_devices = [ps_devices[i // shards_per_device]
                   for i in range(hparams.num_model_shards)]
  print("model_devices = %s" % model_devices)
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = tf.squeeze(features["targets_raw"], [2, 3])
  targets_embedding_var = mp(
      tf.get_variable, "embedding",
      [[targets_vocab_size, hparams.hidden_size]] * mp.n,
      initializer=tf.random_normal_initializer(
          0.0, hparams.hidden_size**-0.5))
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  if single_device:
    targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
    decoder_input_combined = common_layers.embedding(
        shifted_targets, targets_vocab_size,
        hparams.hidden_size * mp.n,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var_combined,
    )
    decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
  else:
    targets_embedding_var_combined = None
    decoder_input = mp(
        common_layers.embedding, shifted_targets, targets_vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var,
    )
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
    decoder_input = mp(
        common_attention.add_timing_signal_1d_given_position,
        decoder_input, targets_position)
  else:
    targets_position = None
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle, tf.shape(targets)[1])
    decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

  if self.has_input:
    inputs = tf.squeeze(features["inputs_raw"], [2, 3])
    inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
    # share everything for now
    share_inputs_and_targets_embedding = True
    if share_inputs_and_targets_embedding:
      assert inputs_vocab_size == targets_vocab_size
      inputs_embedding_var = targets_embedding_var
      inputs_embedding_var_combined = targets_embedding_var_combined
    if single_device:
      encoder_input_combined = common_layers.embedding(
          inputs, inputs_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var_combined,
      )
      encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
    else:
      encoder_input = mp(
          common_layers.embedding, inputs, inputs_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var,
      )
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = features["inputs_segmentation"]
      inputs_position = features["inputs_position"]
      encoder_self_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          inputs_segmentation, inputs_segmentation)
      encoder_decoder_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          targets_segmentation, inputs_segmentation)
      encoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          encoder_input, inputs_position)
    else:
      encoder_padding = tf.to_float(tf.equal(inputs, 0))
      ignore_padding = common_attention.attention_bias_ignore_padding(
          encoder_padding)
      encoder_self_attention_bias = ignore_padding
      encoder_decoder_attention_bias = ignore_padding
      inputs_position = None
      encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

    # encoder stack here
    with tf.variable_scope("encoder"):
      encoder_input = mp(
          tf.nn.dropout, encoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      encoder_output = _layer_stack(
          mp,
          encoder_input,
          encoder_self_attention_bias,
          hparams.encoder_layers,
          hparams)
  else:
    encoder_decoder_attention_bias = None
    encoder_output = None

  with tf.variable_scope("decoder"):
    decoder_input = mp(
        tf.nn.dropout, decoder_input,
        1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output = _layer_stack(
        mp,
        decoder_input,
        decoder_self_attention_bias,
        layers=hparams.decoder_layers,
        hparams=hparams,
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias)

  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  # Share the weights with the target embedding.
  output_var = targets_embedding_var
  output_var_combined = targets_embedding_var_combined
  if single_device:
    decoder_output = tf.concat(decoder_output, 2)
    logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
    num, denom = common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
    training_loss = num / denom
  else:
    logits = mp(
        tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
    logits = common_layers.all_reduce_ring(logits, mp)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = common_layers.reduce_by_device(mp, logits, lambda l: l[0])

    def _loss_for_shard(logits, targets, shard):
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)

    num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
    training_loss = tf.add_n(num) / tf.add_n(denom)
    logits = logits[0]
  logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
  # override training loss so that it is not computed externally.
  losses = {"training": training_loss}
  return logits, losses
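# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code above): body() ties the
# output projection to the per-shard embedding tables. Summing the per-shard
# tensordots over disjoint hidden slices equals a single tensordot against the
# concatenated table, which is why the single_device and sharded logits paths
# agree. Shapes below are invented for the demo.
import numpy as np


def _demo_tied_logits(n=4, vocab=13, hidden=6):
  """Checks that the sharded and single-device logits computations agree."""
  rng = np.random.RandomState(0)
  decoder_output = [rng.randn(2, 5, hidden) for _ in range(n)]  # per shard
  embedding_var = [rng.randn(vocab, hidden) for _ in range(n)]  # per shard

  # Sharded path: each shard contracts its own hidden slice; all_reduce_ring
  # then leaves the sum of the per-shard logits on every shard.
  sharded = sum(np.tensordot(d, e, axes=[[2], [1]])
                for d, e in zip(decoder_output, embedding_var))

  # Single-device path: concatenate along the hidden axis and contract once.
  combined_output = np.concatenate(decoder_output, axis=2)  # [2, 5, n * hidden]
  combined_var = np.concatenate(embedding_var, axis=1)      # [vocab, n * hidden]
  single = np.tensordot(combined_output, combined_var, axes=[[2], [1]])

  assert np.allclose(sharded, single)
  return single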
def _super_stack(inputs, attention_bias, hparams, mp, padding="LEFT"):
  """A stack of super_lm layers.

  Args:
    inputs: a list of Tensors
    attention_bias: list of bias Tensor for self-attention
      (see common_attention.attention_bias())
    hparams: hyperparameters for model
    mp: a Parallelism object
    padding: a string

  Returns:
    y: a list of Tensors
  """
  layers = hparams.layers.strip(",").split(",")
  ffn_hidden_sizes = [int(s) for s in hparams.ffn_hidden_sizes.split(",")]
  # scaled_dot_product_attention_with_projections uses a 3d attention bias
  # (no heads), where multihead_attention uses 4d attention bias.
  mix_size = int(hparams.mix_fraction * hparams.hidden_size)
  attention_bias_3d = mp(tf.squeeze, attention_bias, 1)
  accumulator = inputs
  x = inputs
  for layer_num, layer_type in enumerate(layers):
    with tf.variable_scope("%s_%d" % (layer_type, layer_num)):
      tf.logging.info("%s_%d" % (layer_type, layer_num))
      if layer_type == "a":  # accumulate
        accumulator = mp(tf.add, x, accumulator)
        x = accumulator
      elif layer_type == "n":  # normalize
        x = mp(common_layers.apply_norm,
               x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)
      elif layer_type == "d":  # dropout
        x = mp(tf.nn.dropout, x, 1.0 - hparams.layer_prepostprocess_dropout)
      elif layer_type == "m":  # mix across shards
        def _split(t):
          return tuple(tf.split(
              t, [mix_size, hparams.hidden_size - mix_size], 2))
        to_mix, to_keep = mp(_split, x)
        mixed = common_layers.all_reduce_ring(to_mix, mp)
        mixed = mp(tf.multiply, mixed, mp.n ** -0.5)
        x = mp(lambda a, b: tf.concat([a, b], 2), mixed, to_keep)
      elif layer_type == "att":  # single-head attention
        q = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="q_transform")
        x = mp(common_attention.scaled_dot_product_attention_simple,
               q, x, x, attention_bias_3d)
        x = mp(tf.layers.dense, x, hparams.hidden_size, use_bias=False,
               name="o_transform")
      elif layer_type == "multihead-att":  # multi-head attention
        x = mp(
            common_attention.multihead_attention,
            x,
            None,
            attention_bias,  # bias
            hparams.attention_key_channels or hparams.hidden_size,
            hparams.attention_value_channels or hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout)
      elif layer_type == "ffn":
        y = mp(
            expert_utils.ffn_expert_fn(
                hparams.hidden_size, ffn_hidden_sizes, hparams.hidden_size),
            mp(expert_utils.flatten_all_but_last, x))
        x = mp(expert_utils.reshape_like, y, x)
      elif layer_type == "conv":  # convolution
        x = mp(
            common_layers.conv1d,
            x,
            hparams.hidden_size,
            hparams.kernel_height,
            activation=tf.nn.relu,
            padding=padding,
        )
      else:
        assert False, "unknown sublayer %s" % layer_type
  return x
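# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the model code above): the "ffn" sublayer
# in this _super_stack flattens all but the last axis, applies a ReLU
# feed-forward whose hidden widths come from the comma-separated
# hparams.ffn_hidden_sizes string, and reshapes back. The widths and weight
# initialization below are made up; this only sketches ffn_expert_fn's expert.
import numpy as np


def _demo_ffn_sublayer(hidden_size=8, ffn_hidden_sizes="32,32"):
  """Per-position ReLU feed-forward, roughly what ffn_expert_fn builds."""
  rng = np.random.RandomState(0)
  sizes = ([hidden_size] + [int(s) for s in ffn_hidden_sizes.split(",")]
           + [hidden_size])
  weights = [rng.randn(a, b) * a ** -0.5 for a, b in zip(sizes, sizes[1:])]

  def ffn(flat):
    # ReLU between layers, linear output.
    for i, w in enumerate(weights):
      flat = flat @ w
      if i < len(weights) - 1:
        flat = np.maximum(flat, 0.0)
    return flat

  x = rng.randn(2, 5, hidden_size)          # [batch, length, hidden]
  flat = x.reshape(-1, hidden_size)         # flatten_all_but_last
  out = ffn(flat).reshape(x.shape)          # reshape_like
  assert out.shape == x.shape
  return out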