def body_fn(position, ids, *states):
  """Run a single incremental decoding step.

  Args:
    position: index of the slot being filled in this step. The id stored at
      `position - 1` is what gets fed to the model.
    ids: tensor of ids indexed along `length_dim`.
    *states: state tensors from the previous step, forwarded to the
      incremental `Context` as `states`.

  Returns:
    The list `[new_position, new_ids]` followed by the `new_states` that the
    context accumulated while the model ran.
  """
  context_kwargs = dict(
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      model_dim=self.model_dim,
      variable_dtype=variable_dtype,
      mode="incremental",
      autoregressive=self.autoregressive,
      position=position,
      states=states,
      new_states=[],
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=constant_states,
      shared_params=shared_params,
      layout=self.layout,
      mesh_shape=self.mesh_shape,
      encoder_layer_outputs=encoder_layer_outputs,
  )
  step_context = Context(**context_kwargs)

  # The model consumes the id written by the previous step.
  previous_ids = mtf.gather(ids, position - 1, length_dim)
  with tf.variable_scope(self.name, reuse=True):
    step_logits = self._call_internal(step_context, previous_ids)
  sampled_ids = mtf.sample_with_temperature(
      step_logits, self.output_vocab_dim, temperature)

  next_position = position + 1
  # The sampled id is *added* at `position`, so this presumably relies on
  # that slot of `ids` still being 0 -- TODO confirm with the caller.
  write_mask = mtf.one_hot(position, length_dim, dtype=tf.int32)
  updated_ids = ids + sampled_ids * write_mask
  return [next_position, updated_ids] + step_context.new_states
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
  """Return a corrupted copy of `targets` according to `noising_spec`.

  Args:
    targets: tensor of target token ids to corrupt.
    noising_spec: dict whose "type" entry selects the scheme:
      "mask" - zero out entries; also reads "prob".
      "random_zipfian" - replace entries with ids sampled with probability
        proportional to 1/(id+10); also reads "prob".
      "transformer" - sample from a separately-built noiser model; also
        reads "overrides".
    losses: list that receives the noiser's loss. Required (must not be None)
      for the "transformer" type; unused by the other types.

  Returns:
    The noised targets tensor (for "transformer", the sampled ids).

  Raises:
    NotImplementedError: "transformer" type requested outside TRAIN mode.
    ValueError: unknown "type", or "transformer" requested with
      `losses=None` (previously an AttributeError raised only after the
      noiser graph had already been built).
  """
  noise_type = noising_spec["type"]

  if noise_type == "mask":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
    keep = mtf.greater(
        mtf.random_uniform(targets.mesh, targets.shape),
        noising_spec["prob"])
    return targets * mtf.cast(keep, targets.dtype)

  if noise_type == "random_zipfian":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens.
    # Rather than drawing the replacement tokens uniformly, we sample from
    # a distribution favoring lower token-ids, assuming that the ids have
    # been assigned in frequency order.  The probability of choosing an
    # id is proportional to 1/(id+10)
    logits = mtf.log(1.0 / (mtf.range(
        targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
    logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
    r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    use_noise = mtf.less(
        mtf.random_uniform(targets.mesh, targets.shape),
        noising_spec["prob"])
    return mtf.where(use_noise, r, targets)

  if noise_type == "transformer":
    # Train a small transformer to fill in masked out values, then
    # sample from it.
    hparams = self._hparams
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      raise NotImplementedError("Not implemented")
    # Fail before any graph construction: the noiser's loss has to be handed
    # back through `losses`, so the None default cannot be used here.
    if losses is None:
      raise ValueError(
          "noising spec type \"transformer\" requires a `losses` list to "
          "collect the noiser loss")
    noiser_hparams = copy.copy(self._hparams)
    noiser_hparams.del_hparam("mode")
    noiser_hparams.override_from_dict(noising_spec["overrides"])
    with tf.variable_scope("noiser"):
      noiser = MtfTransformer(
          noiser_hparams,
          mode=hparams.mode,
          problem_hparams=self._problem_hparams)
      logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
          self._original_features, targets.mesh)
      samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    losses.append(loss)
    return samples

  # Wrap in a tuple so a tuple-valued spec cannot break the %-formatting.
  raise ValueError("unknown noising spec %s" % (noising_spec,))
def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
  """Apply the noising scheme described by `noising_spec` to `targets`.

  Args:
    targets: tensor of target token ids to corrupt.
    noising_spec: dict; its "type" entry picks the scheme. "mask" and
      "random_zipfian" additionally read "prob"; "transformer" additionally
      reads "overrides" (hparam overrides for the noiser model).
    losses: list to which the noiser model's loss is appended. Only the
      "transformer" type uses it, and for that type it must not be None.

  Returns:
    The noised targets tensor. For "transformer" these are ids sampled from
    the noiser model's logits.

  Raises:
    NotImplementedError: if "transformer" is requested outside TRAIN mode.
    ValueError: if the type is unknown, or if "transformer" is requested
      while `losses` is None. The latter used to surface as an
      AttributeError on `None.append`, after the noiser had been built.
  """
  kind = noising_spec["type"]
  if kind == "mask":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
    survives = mtf.greater(
        mtf.random_uniform(targets.mesh, targets.shape),
        noising_spec["prob"])
    return targets * mtf.cast(survives, targets.dtype)
  elif kind == "random_zipfian":
    # Replace a randomly-chosen noising_spec["prob"] of input tokens.
    # Rather than drawing the replacement tokens uniformly, we sample from
    # a distribution favoring lower token-ids, assuming that the ids have
    # been assigned in frequency order.  The probability of choosing an
    # id is proportional to 1/(id+10)
    logits = mtf.log(1.0 / (mtf.range(
        targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
    logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
    r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    use_noise = mtf.less(
        mtf.random_uniform(targets.mesh, targets.shape),
        noising_spec["prob"])
    return mtf.where(use_noise, r, targets)
  elif kind == "transformer":
    # Train a small transformer to fill in masked out values, then
    # sample from it.
    hparams = self._hparams
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      raise NotImplementedError("Not implemented")
    if losses is None:
      # Checked after the mode test so the existing NotImplementedError keeps
      # precedence, and before anything is added to the graph.
      raise ValueError(
          "noising spec type \"transformer\" requires a `losses` list to "
          "collect the noiser loss")
    noiser_hparams = copy.copy(self._hparams)
    noiser_hparams.del_hparam("mode")
    noiser_hparams.override_from_dict(noising_spec["overrides"])
    with tf.variable_scope("noiser"):
      noiser = MtfTransformer(
          noiser_hparams,
          mode=hparams.mode,
          problem_hparams=self._problem_hparams)
      logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
          self._original_features, targets.mesh)
      samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
    losses.append(loss)
    return samples
  else:
    # Tuple-wrapped so the message also formats for tuple-valued specs.
    raise ValueError("unknown noising spec %s" % (noising_spec,))
def body_fn(position, ids, *states): """One step in the decode loop.""" nonlocal sampling_keep_top_k context = mtf_transformer.transformer.Context( model=None, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, position_is_default=True, states=states, new_states=[], initial_position=position, sequence_id=None, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=ids, encoder_inputs=encoder_inputs) if not slow_sampling else None with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE): logits, _, _ = gpt2.model({"inputs": ids}, other_features, params, inputs.mesh, variable_dtype=variable_dtype, context=context) # By default, do top_k sampling of 0.9 if sampling_keep_top_k == -2: sampling_keep_top_k = int(logits.shape[-1].size * 0.1) if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError( "sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=other_features["vocab_dim"]) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits) * -1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, other_features["vocab_dim"], temperature) if slow_sampling: ids_this_step = mtf.shift(ids_this_step, offset=1, dim=length_dim, wrap=False) else: ids_this_step = mtf.reshape(ids_this_step, (batch_dims)) one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32) one_new_id = ids_this_step * one_hot new_ids = (1 - one_hot) * ids + one_new_id new_position = position + 1 ret = [new_position, new_ids] if context is not None: ret += context.new_states return ret
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Route each input to a single expert and build the gating tensors.

  Args:
    inputs: tensor to route; its second-to-last dimension is used as the
      group-size dimension.
    outer_expert_dims: passed as `expert_dims` to the gating dense layer.
    experts_dim: dimension enumerating the experts.
    expert_capacity_dim: dimension whose size caps the slots per expert.
    hparams: reads `moe_rand_1_policy_train` / `moe_rand_1_policy_eval` and,
      depending on the policy, `moe_rand_1_dropout`, `moe_rand_1_jitter` or
      `moe_rand_1_temperature`.
    train: truthy selects the train policy and enables input perturbation and
      summaries.
    variable_dtype: dtype spec for the gating layer's variables.
    importance: optional tensor; entries not equal to 1.0 are masked out of
      the expert mask, the gate and the density proxy.
    name: name of the gating dense layer, also used as the entropy summary
      prefix.
    num_microbatches: if greater than 1, the auxiliary loss is divided by it.

  Returns:
    A `(dispatch_tensor, combine_tensor, loss)` tuple. `combine_tensor` and
    `loss` are cast to `inputs.dtype`; `dispatch_tensor` is `combine_tensor`
    passed through a bool cast and back.

  Raises:
    ValueError: if the selected policy is not one of "argmax",
      "input_dropout", "input_jitter" or "sample".
  """
  # SELECT EXPERT
  policy = (hparams.moe_rand_1_policy_train if train
            else hparams.moe_rand_1_policy_eval)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations (training only).
  if train:
    if policy == "input_dropout":
      gate_inputs = mtf.dropout(
          gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
    elif policy == "input_jitter":
      gate_inputs = mtf.layers.multiplicative_jitter(
          gate_inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy in ("argmax", "input_dropout", "input_jitter"):
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  gate_dtype = raw_gates.dtype
  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=gate_dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  group_size_dim = inputs.shape[-2]
  # density_1 is taken from the mask *before* importance masking.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=gate_dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=gate_dtype)
    density_1_proxy *= mtf.cast(mtf.equal(importance, 1.0), dtype=gate_dtype)
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    negated_gates = -raw_gates
    log_gates = mtf.log(raw_gates + 1e-9)
    entropy = mtf.reduce_sum(
        negated_gates * log_gates, reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    mask_count_experts = mtf.reduce_sum(
        expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    # One scalar summary per expert.
    per_expert_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in per_expert_fractions:
      summary_name = "experts/" + fraction.name.replace(":", "/")
      mtf.scalar_summary(summary_name, mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity, ensure we do not exceed it. Construct
  # the batch indices, to each expert, with position_in_expert
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=gate_dtype)

  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  fits_in_capacity = mtf.less(position_in_expert, expert_capacity_float)
  expert_mask *= mtf.cast(fits_in_capacity, dtype=gate_dtype)
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  # Mask out the experts that have overflowed expert capacity. Sparsify the
  # expert_gate.
  expert_gate *= expert_mask_flat

  # combine_tensor = gate * mask * one_hot(expert) * one_hot(slot).
  masked_gate = expert_gate * expert_mask_flat
  expert_one_hot = mtf.one_hot(expert_index, experts_dim, dtype=gate_dtype)
  gate_per_expert = masked_gate * expert_one_hot
  slot_index = mtf.to_int32(position_in_expert)
  slot_one_hot = mtf.one_hot(
      slot_index, expert_capacity_dim, dtype=gate_dtype)
  combine_tensor = gate_per_expert * slot_one_hot

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
def body_fn(position, ids, *states):
  """Run a single incremental decoding step, optionally with attributes.

  Args:
    position: index of the slot being filled. The entries stored at
      `position - 1` are what gets fed to the model.
    ids: tensor of ids indexed along `length_dim`.
    *states: state tensors from the previous step, forwarded to the
      incremental `Context` as `states`.

  Returns:
    The list `[new_position, new_ids]` followed by the `new_states` that the
    context accumulated while the model ran.

  Raises:
    ValueError: if `sampling_keep_top_k` is neither -1 nor positive.
  """
  previous_ids = mtf.gather(ids, position - 1, length_dim)
  if self.attribute_embedding:
    previous_attributes = mtf.gather(attributes, position - 1, length_dim)
  else:
    previous_attributes = None

  step_context = Context(
      model=self,
      mesh=inputs.mesh,
      batch_dims=batch_dims,
      length_dim=length_dim,
      variable_dtype=variable_dtype,
      mode="incremental",
      position=position,
      states=states,
      new_states=[],
      sequence_id=sequence_id,
      encoder_output=encoder_output,
      encoder_sequence_id=encoder_sequence_id,
      constant_states=constant_states,
      shared_params=shared_params,
      encoder_layer_outputs=encoder_layer_outputs,
      write_priority=write_priority,
      read_priority=position,
      inputs=previous_ids,
      encoder_inputs=encoder_inputs)

  with tf.variable_scope(self.name, reuse=True):
    logits = self._call_internal(
        step_context, previous_ids,
        attributes=previous_attributes, z=z)
    if never_end:
      # Push the stop token's logit down by 1e9.
      stop_token = mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32)
      stop_penalty = mtf.one_hot(
          stop_token, self.output_vocab_dim,
          on_value=-1e9, off_value=0.0, dtype=logits.dtype)
      logits += stop_penalty

  # TBD whether this should be before or after never_end:
  # Note for adding top_p sampling in the future, in other code bases, the
  # option to apply temperature is done before the top-k truncation. This
  # implementation does this in the opposite order. For top-k this doesn't
  # matter, but for top_p it will.
  if sampling_keep_top_k != -1:
    if sampling_keep_top_k <= 0:
      raise ValueError(
          "sampling_keep_top_k must either be -1 or positive.")
    threshold = mtf.nth_largest_element(
        logits, n=sampling_keep_top_k,
        reduced_dim=self.output_vocab_dim)
    # Everything at or below the threshold gets a large negative logit.
    below_cutoff = mtf.less_equal(logits, threshold)
    suppressed = mtf.ones_like(logits) * -1e6
    logits = mtf.where(below_cutoff, suppressed, logits)

  sampled_ids = mtf.sample_with_temperature(
      logits, self.output_vocab_dim, temperature)

  next_position = position + 1
  # The sampled id is *added* at `position`, so this presumably relies on
  # that slot of `ids` still being 0 -- TODO confirm with the caller.
  write_mask = mtf.one_hot(position, length_dim, dtype=tf.int32)
  updated_ids = ids + sampled_ids * write_mask
  return [next_position, updated_ids] + step_context.new_states