def body(self, features):
  """Model-parallel decoder-only transformer body (bypasses symbol modality).

  Splits the model across `hparams.num_model_shards` shards placed round-robin
  on `self._ps_devices`, embeds the targets with a *different* embedding on
  each shard, runs `_super_stack`, then computes per-shard logits and
  all-reduces them so every shard ends up with identical summed logits.

  Args:
    features: dict of tensors; reads "targets_raw" and, for packed datasets,
      "targets_segmentation" / "targets_position".
      # NOTE(review): "targets_raw" is squeezed on axes 3 then 2, so it is
      # assumed to be [batch, length, 1, 1] int ids — TODO confirm with caller.

  Returns:
    A (logits, losses) pair:
      logits: shard 0's logits expanded back to [batch, length, 1, 1, vocab].
      losses: {"training": scalar loss, optionally "extra": extra_loss}.
        The "training" entry overrides external loss computation.
  """
  # TODO: Remove dropout if not training.
  # NOTE(review): dropout is applied unconditionally here; presumably the
  # surrounding framework zeroes the dropout hparams at eval time — confirm.
  hparams = self._hparams
  ps_devices = self._ps_devices
  # Shards must divide evenly over the available parameter-server devices.
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  # Device for each model shard, round-robin in contiguous groups.
  model_devices = [
      ps_devices[i // shards_per_device]
      for i in range(hparams.num_model_shards)
  ]
  print("model_devices = %s" % model_devices)
  # reuse=False: each shard creates its own (distinct) variables.
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = features["targets_raw"]
  targets = tf.squeeze(targets, 3)
  targets = tf.squeeze(targets, 2)
  # Shift right so position i predicts token i (teacher forcing).
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  decoder_input = mp(
      common_layers.embedding, shifted_targets, vocab_size,
      hparams.hidden_size,
      multiplier=hparams.hidden_size**0.5,
      symbol_dropout_rate=hparams.symbol_dropout)
  # Causal (lower-triangular) self-attention mask, one copy per shard.
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle,
      tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
  else:
    targets_position = None
  if hparams.pos == "timing":
    if targets_position is None:
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
    else:
      # Packed examples: timing signal must follow per-example positions.
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
  decoder_input = mp(
      tf.nn.dropout, decoder_input,
      1.0 - hparams.layer_prepostprocess_dropout)
  decoder_output, extra_loss = _super_stack(
      decoder_input, decoder_self_attention_bias, hparams, mp)
  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
  logits = expert_utils.all_reduce_ring(logits, mp)
  # Rescale by 1/sqrt(n) after summing n shard contributions.
  logits = mp(tf.multiply, logits, mp.n**-0.5)
  # We now have identical logits on all shards.
  # Shard 0 gets returned to the estimator.
  logits_shard_0 = logits[0]
  # Restore the (height, channels) singleton dims squeezed out of targets.
  logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
  logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
  # On each device, we compute the loss for a part of the batch.
  # This is faster than computing the whole loss on one shard.
  mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
  def _loss_for_shard(logits, targets, shard):
    # Each device takes its slice of the batch; no split needed for n == 1.
    if mp.n > 1:
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
    return common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
  num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
  # override training loss so that it is not computed externally.
  losses = {"training": tf.add_n(num) / tf.add_n(denom)}
  if extra_loss is not None:
    losses["extra"] = extra_loss
  return logits_shard_0, losses
def body(self, features):
  """Model-parallel encoder-decoder transformer body with shared embeddings.

  Creates one embedding variable per shard and reuses it three ways: target
  embedding, (optionally) input embedding, and the output softmax projection.
  On a single device the per-shard embeddings are concatenated so the lookup
  is done once and split; with multiple devices each shard embeds separately.

  Args:
    features: dict of tensors; reads "targets_raw", optionally "inputs_raw",
      and the packed-dataset fields "targets_segmentation",
      "targets_position", "inputs_segmentation", "inputs_position".
      # NOTE(review): "*_raw" tensors are squeezed on axes [2, 3], so they are
      # assumed to be [batch, length, 1, 1] int ids — TODO confirm.

  Returns:
    A (logits, losses) pair:
      logits: shard 0's logits expanded back to rank 5 for the estimator.
      losses: {"training": scalar}; overrides external loss computation.
  """
  hparams = self._hparams
  ps_devices = self._ps_devices
  single_device = (len(ps_devices) == 1)
  # Shards must divide evenly over the available parameter-server devices.
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  model_devices = [ps_devices[i // shards_per_device]
                   for i in range(hparams.num_model_shards)]
  print("model_devices = %s" % model_devices)
  # reuse=False: each shard creates its own (distinct) variables.
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = tf.squeeze(features["targets_raw"], [2, 3])
  # One [vocab, hidden] embedding variable per shard; stddev 1/sqrt(hidden).
  targets_embedding_var = mp(
      tf.get_variable, "embedding",
      [[targets_vocab_size, hparams.hidden_size]] * mp.n,
      initializer=tf.random_normal_initializer(
          0.0, hparams.hidden_size**-0.5))
  # Shift right so position i predicts token i (teacher forcing).
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  if single_device:
    # All shards on one device: concat embeddings, look up once, split.
    targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
    decoder_input_combined = common_layers.embedding(
        shifted_targets, targets_vocab_size,
        hparams.hidden_size * mp.n,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var_combined,
    )
    decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
  else:
    targets_embedding_var_combined = None
    decoder_input = mp(
        common_layers.embedding, shifted_targets, targets_vocab_size,
        hparams.hidden_size,
        multiplier=hparams.hidden_size**0.5,
        embedding_var=targets_embedding_var,
    )
  # Causal (lower-triangular) self-attention mask, one copy per shard.
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle,
      tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
    # Timing signal must follow per-example positions in packed data.
    decoder_input = mp(
        common_attention.add_timing_signal_1d_given_position,
        decoder_input, targets_position)
  else:
    targets_position = None
    # NOTE(review): this recomputes the same lower-triangle bias already
    # assigned above — looks redundant but harmless; confirm before removing.
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
  if self.has_input:
    inputs = tf.squeeze(features["inputs_raw"], [2, 3])
    inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
    # share everything for now
    share_inputs_and_targets_embedding = True
    if share_inputs_and_targets_embedding:
      assert inputs_vocab_size == targets_vocab_size
      inputs_embedding_var = targets_embedding_var
      inputs_embedding_var_combined = targets_embedding_var_combined
    # NOTE(review): if sharing were disabled, the inputs_embedding_var names
    # below would be unbound — the flag is effectively always True here.
    if single_device:
      encoder_input_combined = common_layers.embedding(
          inputs, inputs_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var_combined,
      )
      encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
    else:
      encoder_input = mp(
          common_layers.embedding, inputs, inputs_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=inputs_embedding_var,
      )
    if "inputs_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      inputs_segmentation = features["inputs_segmentation"]
      inputs_position = features["inputs_position"]
      encoder_self_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          inputs_segmentation, inputs_segmentation)
      # NOTE(review): uses targets_segmentation, which is only bound when
      # "targets_segmentation" is also in features — presumably packed
      # datasets always provide both; confirm against the input pipeline.
      encoder_decoder_attention_bias = mp(
          common_attention.attention_bias_same_segment,
          targets_segmentation, inputs_segmentation)
      encoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          encoder_input, inputs_position)
    else:
      # Unpacked: mask attention to padding (id 0) positions instead.
      encoder_padding = tf.to_float(tf.equal(inputs, 0))
      ignore_padding = common_attention.attention_bias_ignore_padding(
          encoder_padding)
      encoder_self_attention_bias = ignore_padding
      encoder_decoder_attention_bias = ignore_padding
      inputs_position = None
      encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)
    # encoder stack here
    with tf.variable_scope("encoder"):
      encoder_input = mp(
          tf.nn.dropout, encoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      encoder_output = _layer_stack(
          mp, encoder_input, encoder_self_attention_bias,
          hparams.encoder_layers, hparams)
  else:
    # Decoder-only problem: no encoder, no cross-attention.
    encoder_decoder_attention_bias = None
    encoder_output = None
  with tf.variable_scope("decoder"):
    decoder_input = mp(
        tf.nn.dropout, decoder_input,
        1.0 - hparams.layer_prepostprocess_dropout)
    decoder_output = _layer_stack(
        mp, decoder_input, decoder_self_attention_bias,
        layers=hparams.decoder_layers, hparams=hparams,
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias)
  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  # Share the weights with the target embedding.
  output_var = targets_embedding_var
  output_var_combined = targets_embedding_var_combined
  if single_device:
    decoder_output = tf.concat(decoder_output, 2)
    # Contract the hidden axis of the output with axis 1 of the embedding:
    # [batch, length, hidden*n] x [vocab, hidden*n] -> [batch, length, vocab].
    logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
    num, denom = common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
    training_loss = num / denom
  else:
    logits = mp(
        tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
    # Sum per-shard partial logits so every shard has the full logits.
    logits = expert_utils.all_reduce_ring(logits, mp)
    # On each device, we compute the loss for a part of the batch.
    # This is faster than computing the whole loss on one shard.
    mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
    def _loss_for_shard(logits, targets, shard):
      # Each device takes its slice of the batch along axis 0.
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
      return common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
    num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
    training_loss = tf.add_n(num) / tf.add_n(denom)
  # Shard/device 0's logits go back to the estimator, with the squeezed
  # (height, channels) singleton dims restored.
  logits = logits[0]
  logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
  # override training loss so that it is not computed externally.
  losses = {"training": training_loss}
  return logits, losses
def body(self, features):
  """Model-parallel decoder-only transformer body (bypasses symbol modality).

  Identical in structure to the other `_super_stack` body in this file:
  shards the model over `hparams.num_model_shards` devices, embeds targets
  with a different embedding per shard, runs `_super_stack`, all-reduces the
  per-shard logits, and computes the loss data-parallel across devices.

  Args:
    features: dict of tensors; reads "targets_raw" and, for packed datasets,
      "targets_segmentation" / "targets_position".
      # NOTE(review): "targets_raw" is squeezed on axes 3 then 2, so it is
      # assumed to be [batch, length, 1, 1] int ids — TODO confirm with caller.

  Returns:
    A (logits, losses) pair:
      logits: shard 0's logits expanded back to [batch, length, 1, 1, vocab].
      losses: {"training": scalar loss, optionally "extra": extra_loss}.
        The "training" entry overrides external loss computation.
  """
  # TODO: Remove dropout if not training.
  # NOTE(review): dropout is applied unconditionally here; presumably the
  # surrounding framework zeroes the dropout hparams at eval time — confirm.
  hparams = self._hparams
  ps_devices = self._ps_devices
  # Shards must divide evenly over the available parameter-server devices.
  assert hparams.num_model_shards % len(ps_devices) == 0
  shards_per_device = hparams.num_model_shards // len(ps_devices)
  # Device for each model shard, round-robin in contiguous groups.
  model_devices = [ps_devices[i // shards_per_device]
                   for i in range(hparams.num_model_shards)]
  print("model_devices = %s" % model_devices)
  # reuse=False: each shard creates its own (distinct) variables.
  mp = expert_utils.Parallelism(model_devices, reuse=False)
  vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
  # squeeze out channels, heights
  targets = features["targets_raw"]
  targets = tf.squeeze(targets, 3)
  targets = tf.squeeze(targets, 2)
  # Shift right so position i predicts token i (teacher forcing).
  shifted_targets = common_layers.shift_right_2d(targets)
  # Bypass the symbol modality and use a different embedding on each shard.
  decoder_input = mp(
      common_layers.embedding, shifted_targets, vocab_size,
      hparams.hidden_size,
      multiplier=hparams.hidden_size**0.5,
      symbol_dropout_rate=hparams.symbol_dropout)
  # Causal (lower-triangular) self-attention mask, one copy per shard.
  decoder_self_attention_bias = mp(
      common_attention.attention_bias_lower_triangle,
      tf.shape(targets)[1])
  if "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias = mp(
        tf.add, decoder_self_attention_bias,
        mp(common_attention.attention_bias_same_segment,
           targets_segmentation, targets_segmentation))
  else:
    targets_position = None
  if hparams.pos == "timing":
    if targets_position is None:
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)
    else:
      # Packed examples: timing signal must follow per-example positions.
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
  decoder_input = mp(
      tf.nn.dropout, decoder_input,
      1.0 - hparams.layer_prepostprocess_dropout)
  decoder_output, extra_loss = _super_stack(
      decoder_input, decoder_self_attention_bias, hparams, mp)
  # Bypass the symbol modality and compute logits directly.
  # We compute a different set of logits on each shard, and sum them.
  logits = mp(tf.layers.dense, decoder_output, vocab_size, name="logits")
  logits = expert_utils.all_reduce_ring(logits, mp)
  # Rescale by 1/sqrt(n) after summing n shard contributions.
  logits = mp(tf.multiply, logits, mp.n ** -0.5)
  # We now have identical logits on all shards.
  # Shard 0 gets returned to the estimator.
  logits_shard_0 = logits[0]
  # Restore the (height, channels) singleton dims squeezed out of targets.
  logits_shard_0 = tf.expand_dims(logits_shard_0, 2)
  logits_shard_0 = tf.expand_dims(logits_shard_0, 3)
  # On each device, we compute the loss for a part of the batch.
  # This is faster than computing the whole loss on one shard.
  mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
  def _loss_for_shard(logits, targets, shard):
    # Each device takes its slice of the batch; no split needed for n == 1.
    if mp.n > 1:
      logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
      targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
    return common_layers.padded_cross_entropy(
        logits, targets, hparams.label_smoothing)
  num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
  # override training loss so that it is not computed externally.
  losses = {"training": tf.add_n(num) / tf.add_n(denom)}
  if extra_loss is not None:
    losses["extra"] = extra_loss
  return logits_shard_0, losses