def restore_pad(x, ref_x, pad_remover, mode):
    """Squeeze ``x``, optionally restore its padding, and reshape it like ``ref_x``.

    Args:
      x: tensor whose axis 0 is squeezed away (``tf.squeeze(x, axis=0)``).
      ref_x: reference tensor handed to ``common_layers.reshape_like``.
      pad_remover: object exposing ``restore(tensor)``. It is only called when
        ``mode`` is not ``ModeKeys.PREDICT``.
      mode: value compared against ``ModeKeys.PREDICT``.

    Returns:
      The result of ``common_layers.reshape_like`` applied to the squeezed
      (and, outside of PREDICT mode, pad-restored) tensor and ``ref_x``.
    """
    squeezed = tf.squeeze(x, axis=0)
    # In PREDICT mode the tensor is passed through without pad restoration.
    restored = (
        pad_remover.restore(squeezed)
        if mode != ModeKeys.PREDICT
        else squeezed
    )
    return common_layers.reshape_like(restored, ref_x)
def local_moe(x,
              train,
              expert_fn,
              num_experts,
              k=1,
              loss_coef=1e-2,
              hparams=None,
              pass_x=True,
              pass_gates=False,
              additional_dispatch_params=None,
              name=None):
    """Call a local mixture of experts.

    Args:
      x: a tensor with shape [... , input_size]
      train: a boolean scalar.
      expert_fn: a function. It is called once per expert (through
        `Parallelism`) with the keyword arguments assembled below.
      num_experts: an integer - number of experts
      k: an integer - how many experts to use for each batch element
      loss_coef: a scalar - multiplier on load-balancing losses
      hparams: optional hparams for vq gating. Currently not read by this
        function; kept so existing callers keep working.
      pass_x: a boolean. If true, x will also be dispatched to the experts
        (as the keyword argument "x").
      pass_gates: a boolean. If true, gates will be passed to experts (as the
        keyword argument "gates"). Might be necessary when dealing with sparse
        encoder-decoder attention
      additional_dispatch_params: The extra tensors that need to be sent to
        each expert, as a dict mapping keyword name to tensor. Each tensor is
        flattened with `flatten_all_but_last` before being dispatched.
        Examples include batch coordinates (see
        common_attention.local_expert_attention)
      name: a string

    Returns:
      y: a tensor. Has the same shape as x, except for the last dimension,
        which is output_size.
      extra_training_loss: a scalar. This should be added into the overall
        training loss of the model. The backpropagation of this loss
        encourages all experts to be approximately equally used across a
        batch.
    """
    with tf.variable_scope(name, default_name="local_moe"):
        x_flat = flatten_all_but_last(x)

        tf.logging.info("Using noisy top_k with k = {}".format(k))
        # The gates indicate which batch elements go to which tensors.
        # load is a measure of approximately how many examples go to each
        # expert
        gates, load = noisy_top_k_gating(
            x_flat,
            num_experts,
            train,
            k,
            initializer=tf.zeros_initializer(),
            noisy_gating=True,
            noise_epsilon=1e-2)
        importance = tf.reduce_sum(gates, 0)
        # Load-balancing loss, scaled by loss_coef.
        loss = (cv_squared(importance) + cv_squared(load))
        loss *= loss_coef

        # Shuffle data between datashards and experts.
        dispatcher = SparseDispatcher(num_experts, gates)

        # Set up expert_fn arguments
        expert_kwargs = {}
        if pass_x:
            expert_kwargs["x"] = dispatcher.dispatch(x_flat)
        if pass_gates:
            expert_kwargs["gates"] = dispatcher.expert_to_gates()
        for key, val in six.iteritems(additional_dispatch_params or {}):
            expert_kwargs[key] = dispatcher.dispatch(flatten_all_but_last(val))

        # One parallelism slot per expert, all on the default device string.
        ep = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=None)
        expert_outputs = ep(expert_fn, **expert_kwargs)

        y_flat = dispatcher.combine(expert_outputs)
        y = common_layers.reshape_like(y_flat, x)
        return y, loss