Example #1
import tensorflow as tf
from tensor2tensor.layers import common_layers

ModeKeys = tf.estimator.ModeKeys  # pylint: disable=invalid-name

def restore_pad(x, ref_x, pad_remover, mode):
  """Drop the leading unit axis, re-insert padding, and reshape like ref_x."""
  x = tf.squeeze(x, axis=0)
  if mode != ModeKeys.PREDICT:  # padding is only restored outside PREDICT mode
    x = pad_remover.restore(x)
  x = common_layers.reshape_like(x, ref_x)
  return x
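The pad_remover round trip that restore_pad undoes is easiest to see in plain TensorFlow. Below is a minimal illustrative sketch of the same idea (not tensor2tensor's PadRemover itself; the shapes and the 1-marks-padding convention are assumptions): gather the non-padded rows out of a flattened tensor, do the per-token work on the smaller tensor, then scatter the results back with zeros at the padded positions.

import tensorflow as tf

pad_mask = tf.constant([[0., 1., 0.],
                        [0., 0., 1.]])           # [batch, length]; 1 marks padding
x = tf.random.normal([2, 3, 8])                  # [batch, length, depth]

x_flat = tf.reshape(x, [-1, 8])                  # [batch * length, depth]
mask_flat = tf.reshape(pad_mask, [-1])
nonpad_ids = tf.where(tf.equal(mask_flat, 0))    # [num_real, 1] indices, int64

removed = tf.gather_nd(x_flat, nonpad_ids)       # "remove": keep real tokens only
# ... per-token computation would run on the smaller `removed` tensor here ...
restored = tf.scatter_nd(nonpad_ids, removed,
                         tf.shape(x_flat, out_type=tf.int64))  # "restore": zeros at pads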
Example #2
import six
import tensorflow as tf

from tensor2tensor.layers import common_layers

# noisy_top_k_gating, cv_squared, SparseDispatcher, Parallelism,
# DEFAULT_DEV_STRING, and flatten_all_but_last are helpers defined in the
# same module (tensor2tensor.utils.expert_utils).

def local_moe(x,
              train,
              expert_fn,
              num_experts,
              k=1,
              loss_coef=1e-2,
              hparams=None,
              pass_x=True,
              pass_gates=False,
              additional_dispatch_params=None,
              name=None):
    """Call a local mixture of experts.
  Args:
    x: a tensors with shape [... , input_size]
    train: a boolean scalar.
    expert_fn: a function.
    num_experts: an integer - number of experts
    k: an integer - how many experts to use for each batch element
    loss_coef: a scalar - multiplier on load-balancing losses
    hparams: optional hparams for vq gating
    pass_x: a boolean. If true, x will also be dispatched to the experts.
    pass_gates: a boolean. If true, gates will be passed to experts. Might be
      necessary when dealing with sparse encoder-encoder decoder attention
    additional_dispatch_params: The extra tensors that need to be sent to each
      expert. Examples include batch batch coordinates (see
      common_attention.local_expert_attention)
    name: a string
  Returns:
    y: a tensor.  Has the same shape as x, except for the last dimension,
      which is output_size.
    extra_training_loss: a scalar.  This should be added into the overall
      training loss of the model.  The backpropagation of this loss
      encourages all experts to be approximately equally used across a batch.
  """
    with tf.variable_scope(name, default_name="local_moe"):
        # centroids is only set by the VQ gating path, which this example omits.
        centroids = None
        x_flat = flatten_all_but_last(x)
        tf.logging.info("Using noisy top_k with k = {}".format(k))
        # The gates indicate which batch elements go to which experts;
        # load approximates how many examples each expert receives.
        gates, load = noisy_top_k_gating(
            x_flat,
            num_experts,
            train,
            k,
            initializer=tf.zeros_initializer(),
            noisy_gating=True,
            noise_epsilon=1e-2)
        importance = tf.reduce_sum(gates, 0)
        # cv_squared is the squared coefficient of variation (variance / mean**2);
        # penalizing it pushes expert importance and load toward uniformity.
        loss = (cv_squared(importance) + cv_squared(load))
        loss *= loss_coef
        # Shuffle data between datashards and experts.
        dispatcher = SparseDispatcher(num_experts, gates)
        # Set up expert_fn arguments
        expert_kwargs = {}
        if pass_x:
            expert_kwargs["x"] = dispatcher.dispatch(x_flat)
        if pass_gates:
            expert_kwargs["gates"] = dispatcher.expert_to_gates()
        for key, val in six.iteritems(additional_dispatch_params or {}):
            val = flatten_all_but_last(val)
            expert_kwargs[key] = dispatcher.dispatch(val)

        ep = Parallelism([DEFAULT_DEV_STRING] * num_experts, reuse=None)
        expert_outputs = ep(expert_fn, **expert_kwargs)

        y_flat = dispatcher.combine(expert_outputs)
        if centroids is not None:
            centroids = tf.squeeze(centroids, axis=[1, 2])
            y_flat += centroids
        y = common_layers.reshape_like(y_flat, x)
        return y, loss
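A usage sketch (assuming the tensor2tensor layout, where local_moe and the ffn_expert_fn expert builder live in tensor2tensor.utils.expert_utils; the shapes and hyperparameters below are illustrative, not from the original):

import tensorflow as tf
from tensor2tensor.utils import expert_utils

x = tf.random.normal([8, 16, 64])          # [batch, length, input_size]
expert_fn = expert_utils.ffn_expert_fn(    # each expert is a small feed-forward net
    input_size=64, hidden_sizes=[128], output_size=64)

y, extra_loss = expert_utils.local_moe(
    x,
    train=True,
    expert_fn=expert_fn,
    num_experts=4,
    k=2,
    loss_coef=1e-2,
    name="moe_layer")
# y has shape [8, 16, 64]; extra_loss should be added to the total training
# loss so the gating network learns to balance load across the four experts.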