예제 #1
0
 def _rearrange_sentinels(self, logits):
   """Reorder along the vocab dim so the last few tokens don't share gates."""
   if not self._extra_ids:
     return logits
   sentinels, nonsentinels = mtf.split(
       logits, self._vocab_dim,
       [self._extra_ids, self._vocab_dim.size - self._extra_ids])
   return mtf.concat([nonsentinels, sentinels], self._vocab_dim.name)
예제 #2
0
 def _sigmoid_tree(self, tensor):
   """Turn pre-gate activations into a four-way distribution over gates.

   The sigmoid of `tensor` is split along `self._pre_gates_dim` into one
   slice per entry; the first three slices act as the decisions of a
   two-level binary tree. Slice 0 chooses between the two subtrees, slice 1
   chooses inside the first subtree and slice 2 inside the second. Each of
   the four leaves is the product of the decisions on its path, so the four
   values sum to one.

   Args:
     tensor: tensor that has `self._pre_gates_dim` as one of its dimensions.

   Returns:
     The four leaf probabilities concatenated along a dimension named
     `self._gates_dim.name`.
   """
   pre_gates_dim = self._pre_gates_dim
   decisions = mtf.split(
       mtf.sigmoid(tensor), pre_gates_dim, pre_gates_dim.size)
   root, first, second = decisions[0], decisions[1], decisions[2]

   # Leaves are built in the same order as they are concatenated.
   leaf_a = root * first
   leaf_b = root * (1 - first)
   leaf_c = (1 - root) * second
   leaf_d = (1 - root) * (1 - second)
   return mtf.concat([leaf_a, leaf_b, leaf_c, leaf_d], self._gates_dim.name)
예제 #3
0
def mlp_glu(x, scope, n_state, *, variable_dtype, params):
    """Feed-forward block with a GELU-gated linear unit (GLU).

    The input is projected to `n_state` features, that projection is split
    into two equal halves along its last dimension, and the first half is
    multiplied element-wise by the GELU of the second half. The gated result
    is projected back to the size of the last dimension of `x`.

    Args:
      x: input tensor; its last dimension sets the output feature size.
      scope: name of the tf variable scope that wraps both projections.
      n_state: output size of the first projection. Because that projection
        is split in two, the gated hidden width is half of it.
      variable_dtype: variable dtype forwarded to both `linear` calls.
      params: dict; reads `params["mode"]` and `params["res_dropout"]` here
        and is also forwarded to `linear`.

    Returns:
      Tensor with the same last dimension as `x`. Dropout with rate
      `params["res_dropout"]` is applied when `params["mode"] == "train"`
      and that rate is positive.
    """
    with tf.variable_scope(scope):
        nx = x.shape[-1]
        # Forward the caller's variable_dtype here as well; previously only the
        # output projection received it, so the two layers could disagree.
        h = linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)

        # First half carries the values, second half the gate.
        h, gate = mtf.split(h, h.shape[-1], 2)
        h *= mtf.gelu(gate)

        h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True)
        if params["mode"] == "train" and params["res_dropout"] > 0:
            h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout")
        return h2
예제 #4
0
    def _get_decoder_inputs(self, context):
        """Computes the inputs to the decoder when using transparent attention.

    We must cache on the context in order to ensure that we are not replicating
    variables when the layer's call function is called in different tf variable
    scopes.

    Args:
      context: a Context

    Returns:
      a list containing `self.num_decoder_modules` of tensors with shape
        [<batch_dims>, length_dim, output_vocab_dim]
    """
        if hasattr(context, "decoder_layers_per_module"):
            return context.decoder_layers_per_module

        encoder_layer_outputs = [
            mtf.layers.rename_length_to_memory_length(output)
            for output in context.encoder_layer_outputs
        ]

        layers_per_module = self.layers_per_encoder_module
        encoder_module_outputs_dim = mtf.Dimension(
            "encoder_module_outputs", size=self.encoder_num_modules + 1)
        decoder_module_inputs_dim = mtf.Dimension(
            "decoder_module_inputs", size=self.decoder_num_modules)
        encoder_module_outputs = mtf.stack(
            [encoder_layer_outputs[0]] +
            encoder_layer_outputs[layers_per_module::layers_per_module],
            dim_name="encoder_module_outputs")
        w = mtf.get_variable(
            context.mesh,
            "w",
            mtf.Shape([encoder_module_outputs_dim, decoder_module_inputs_dim]),
            initializer=tf.random_normal_initializer(
                stddev=(encoder_module_outputs_dim.size *
                        decoder_module_inputs_dim.size)**-0.5),
            dtype=context.variable_dtype)
        if context.train and self.dropout_rate != 0.0:
            w = mtf.dropout(w, 1.0 - self.dropout_rate)
        s = mtf.softmax(w, reduced_dim=encoder_module_outputs_dim)
        z = mtf.einsum([s, encoder_module_outputs],
                       reduced_dims=[encoder_module_outputs_dim])
        input_per_decoder = mtf.split(
            z,
            split_dim=decoder_module_inputs_dim,
            num_or_size_splits=decoder_module_inputs_dim.size)
        context.decoder_layers_per_module = [
            mtf.reshape(inpt, z.shape.dims[1:]) for inpt in input_per_decoder
        ]
        return context.decoder_layers_per_module
예제 #5
0
def model_fn(features, labels, mode, params):
    """Model function for TPUEstimator that predicts a single field example.

    Builds the field with `nbody_model` on a SIMD mesh configured from
    `FLAGS.mesh_shape` and `FLAGS.layout`, keeps the first example of the
    batch, lowers the graph to TensorFlow and returns the result as the
    `'field'` prediction.

    Args:
      features: unused.
      labels: unused.
      mode: estimator mode, passed through to the returned spec.
      params: dict whose `'context'` entry supplies `num_hosts`,
        `tpu_host_placement_function` and `device_assignment`.

    Returns:
      A `tpu_estimator.TPUEstimatorSpec` with `predictions={'field': ...}`.
    """
    del features, labels

    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

    tpu_context = params['context']
    host_count = tpu_context.num_hosts
    place_on_host = tpu_context.tpu_host_placement_function
    device_list = [place_on_host(host_id=host) for host in range(host_count)]
    tf.logging.info('device_list = %s', device_list)

    mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
        mesh_shape,
        layout_rules,
        [''] * mesh_shape.size,
        tpu_context.device_assignment)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "fft_mesh")

    with mtf.utils.outside_all_rewrites():
        field = nbody_model(mesh)
        batch_dim, _, _, z_dim = field.shape
        unsplit_x = mtf.Dimension("nx_nosplit", FLAGS.cube_size)
        unsplit_y = mtf.Dimension("ny_nosplit", FLAGS.cube_size)

        # Distributed outputs are not implemented yet, so only the first
        # example of the batch is returned.
        first_example, _ = mtf.split(
            field, batch_dim, [1, FLAGS.batch_size - 1])
        single_batch_dim = mtf.Dimension("bs", 1)
        field_slice = mtf.reshape(
            first_example, [single_batch_dim, unsplit_x, unsplit_y, z_dim])

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    tf_field = tf.to_float(lowering.export_to_tf_tensor(field_slice))

    with mtf.utils.outside_all_rewrites():
        predictions = {'field': tf_field}
        return tpu_estimator.TPUEstimatorSpec(mode, predictions=predictions)
def attention(x, dim_head, dim_features_head, scope='attn', causal=False):
    """Multi-head self-attention over a rank-3 `[batch, seq, dim]` tensor.

    Args:
      x: input tensor; its shape is unpacked as `(batch, seq, dim)`.
      dim_head: dimension indexing the attention heads.
      dim_features_head: dimension holding the features of one head.
      scope: name of the tf variable scope wrapping the layer.
      causal: when true, add -1e10 to the scores of every memory position
        `j` for which `i + memory_length - seq_length < j`, where `i` is the
        query position.

    Returns:
      The attention output projected back to `dim` by the
      `'combine_output'` linear layer.
    """
    with tf.variable_scope(scope):
        mesh = x.mesh
        batch, seq, dim = x.shape

        dim_heads = mtf.Dimension(
            'dim_heads', dim_head.size * dim_features_head.size)
        dim_intermediate = mtf.Dimension('qkv_dimension', dim_heads.size * 3)
        qkv = linear(x, dim_intermediate, bias=False, scope='to_qkv')

        seq_major = [batch, seq, dim_head, dim_features_head]
        head_major = [batch, dim_head, seq, dim_features_head]

        # Split the fused projection, then expose the head dimension.
        q, k, v = mtf.split(qkv, dim_intermediate, 3)
        q, k, v = [mtf.reshape(part, seq_major) for part in (q, k, v)]
        q, k, v = [mtf.transpose(part, head_major) for part in (q, k, v)]

        # Keys and values get their own length dimension name.
        k, v = [
            mtf.rename_dimension(part, seq.name, 'memory_length')
            for part in (k, v)
        ]
        mem_len_dim = v.shape[-2]

        dots = mtf.layers.us_einsum(
            [q, k], [batch, dim_head, seq, mem_len_dim])

        if causal:
            query_pos = mtf.range(mesh, seq, tf.int32)
            memory_pos = mtf.range(mesh, mem_len_dim, tf.int32)
            query_pos, memory_pos = [
                mtf.broadcast(pos, [seq, mem_len_dim])
                for pos in (query_pos, memory_pos)
            ]
            is_future = mtf.less(
                query_pos + mem_len_dim.size - seq.size, memory_pos)
            penalty = mtf.cast(is_future, tf.float32) * -1e10
            dots += penalty

        attn = mtf.softmax(dots, mem_len_dim)
        out = mtf.einsum([attn, v], head_major)

        # Merge the heads back into a single feature dimension.
        out = mtf.transpose(out, seq_major)
        out = mtf.reshape(out, [batch, seq, dim_heads])

        return linear(out, dim, scope='combine_output')
예제 #7
0
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Compute a random top-1 gating.

  Every token is assigned to exactly one expert, either the expert with the
  highest gate probability or one sampled from the gate logits, depending on
  the policy. Tokens that do not fit within the expert's capacity are
  dropped.

  Args:
    inputs: an mtf.Tensor. Its second-to-last dimension (`inputs.shape[-2]`)
      is used as the group-size dimension.
    outer_expert_dims: forwarded to `mtf.layers.dense` as `expert_dims`.
    experts_dim: the dimension the gate softmax is taken over.
    expert_capacity_dim: dimension whose size is the number of token slots
      kept per expert.
    hparams: reads `moe_rand_1_policy_train` or `moe_rand_1_policy_eval`,
      and, depending on the policy, `moe_rand_1_dropout`,
      `moe_rand_1_jitter` or `moe_rand_1_temperature`.
    train: truthy during training; selects the policy and enables the input
      perturbations and the summaries.
    variable_dtype: forwarded to `mtf.layers.dense`.
    importance: optional tensor. Where it is not equal to 1.0 the expert
      mask, the expert gate and the density proxy are multiplied by zero.
    name: name of the gating dense layer and prefix of the entropy summary.
    num_microbatches: optional; when greater than 1 the loss is divided by
      it.

  Returns:
    A tuple `(dispatch_tensor, combine_tensor, loss)`, all cast to
    `inputs.dtype`. `dispatch_tensor` is `combine_tensor` cast to bool and
    back, so it is 1 wherever `combine_tensor` is non-zero.

  Raises:
    ValueError: if the selected policy is not one of "argmax",
      "input_dropout", "input_jitter" or "sample".
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  gate_inputs = mtf.to_float(inputs)

  # Input perturbations (training only).
  if train and policy == "input_dropout":
    gate_inputs = mtf.dropout(gate_inputs, 1.0 - hparams.moe_rand_1_dropout)
  elif train and policy == "input_jitter":
    gate_inputs = mtf.layers.multiplicative_jitter(gate_inputs,
                                                   hparams.moe_rand_1_jitter)

  # Bias-free projection onto experts_dim, then a softmax over the experts.
  gate_logits = mtf.layers.dense(
      gate_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  if policy == "argmax" or policy == "input_dropout" or policy == "input_jitter":
    # Deterministic choice: the expert with the largest gate and that gate.
    expert_gate, expert_index = mtf.top_1(raw_gates, reduced_dim=experts_dim)
  elif policy == "sample":
    # Stochastic choice: sample an expert index from the logits, then look up
    # the gate probability of the sampled expert.
    expert_index = mtf.sample_with_temperature(
        gate_logits, experts_dim, temperature=hparams.moe_rand_1_temperature)
    expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  # density_1 is the mean of the one-hot assignment over the group, i.e. the
  # fraction of tokens sent to each expert; density_1_proxy is the mean gate
  # probability per expert over the group.
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    # NOTE(review): density_1 was computed before this mask is applied, and
    # density_1_proxy has already been reduced over group_size_dim when it is
    # multiplied by the mask - confirm both are intended.
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  # Mean of the product of the two densities, scaled by the squared number
  # of experts.
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    # Mean entropy of the gate distribution; 1e-9 keeps the log finite.
    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
                             reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Share of all routed tokens that went to each expert, one summary per
    # expert.
    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      # The tensor name is reused as the summary tag with ":" replaced by "/"
      # (presumably because ":" is not wanted in summary tags).
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Experts have a limited capacity, ensure we do not exceed it. Construct
  # the batch indices, to each expert, with position_in_expert
  # The exclusive cumsum along the group counts, for each token, how many
  # earlier tokens in the group chose the same expert; multiplying by the mask
  # zeroes the entries of the experts the token did not choose.
  position_in_expert = mtf.cumsum(
      expert_mask, group_size_dim, exclusive=True) * expert_mask
  position_in_expert = mtf.cast(position_in_expert, dtype=raw_gates.dtype)
  # Keep only tokens that fit within expert_capacity.
  expert_capacity_float = float(expert_capacity_dim.size)
  expert_mask *= mtf.cast(
      mtf.less(position_in_expert, expert_capacity_float),
      dtype=raw_gates.dtype)
  # Summing the (at most one-hot) mask over experts gives 1 for tokens that
  # are still routed and 0 for dropped ones.
  expert_mask_flat = mtf.reduce_sum(expert_mask, reduced_dim=experts_dim)

  # Mask out the experts that have overflowed expert capacity. Sparsify the
  # expert_gate.
  expert_gate *= expert_mask_flat

  # Gate value placed at [expert, slot] for each routed token. The flat mask
  # is applied a second time here; as a 0/1 value that does not change the
  # result.
  combine_tensor = (
      expert_gate * expert_mask_flat *
      mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
      mtf.one_hot(
          mtf.to_int32(position_in_expert),
          expert_capacity_dim,
          dtype=raw_gates.dtype))

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Cast through bool so every non-zero combine weight becomes exactly 1.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
예제 #8
0
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-1 gating with no-token-left behind behavior.

  The top `hparams.moe_switch_top_k` experts are computed for every token.
  Tokens are first routed to their top choice; tokens that did not fit within
  that expert's capacity are then offered to their next choice, and so on.

  Args:
    inputs: an mtf.Tensor. `inputs.shape[0]`, `inputs.shape[1]` and
      `inputs.shape[-2]` are used as the outer-batch, batch and group-size
      dimensions (assumes a rank-4 tensor - TODO confirm).
    outer_expert_dims: forwarded to `mtf.layers.dense` as `expert_dims`.
    experts_dim: the dimension the gate softmax is taken over.
    expert_capacity_dim: dimension whose size is the number of token slots
      kept per expert.
    hparams: reads `moe_rand_1_policy_train` or `moe_rand_1_policy_eval`,
      `moe_switch_top_k`, and `moe_rand_1_jitter` when the policy is
      "input_jitter".
    train: truthy during training; selects the policy and enables the input
      jitter and the summaries.
    variable_dtype: forwarded to `mtf.layers.dense`.
    importance: optional tensor. Where it is not equal to 1.0 the expert
      mask, the expert gate and the density proxy are multiplied by zero.
    name: name of the gating dense layer and prefix of the entropy summary.
    num_microbatches: optional; when greater than 1 the loss is divided by
      it.

  Returns:
    A tuple `(dispatch_tensor, combine_tensor, loss)`, all cast to
    `inputs.dtype`. `dispatch_tensor` is `combine_tensor` cast to bool and
    back, so it is 1 wherever `combine_tensor` is non-zero.
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations (training only; jitter is the only one handled here).
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  # Bias-free projection onto experts_dim, then a softmax over the experts.
  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  # NOTE(review): the cast happens after the dense layer and the softmax, so
  # those two presumably still run in the dtype of `inputs` - confirm.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation: gate values and expert indices of the k best experts,
  # laid out along the new "k" dimension.
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  # density_1 is the mean of the one-hot assignment over the group;
  # density_1_proxy is the mean gate probability per expert over the group.
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    # NOTE(review): density_1 was computed before this mask is applied, and
    # density_1_proxy has already been reduced over group_size_dim when it is
    # multiplied by the mask - confirm both are intended.
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  # Mean of the product of the two densities, scaled by the squared number
  # of experts.
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    # Mean entropy of the gate distribution; 1e-9 keeps the log finite.
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Share of all mask entries that belong to each expert, one summary per
    # expert.
    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      # The tensor name is reused as the summary tag with ":" replaced by "/"
      # (presumably because ":" is not wanted in summary tags).
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before then trying top-(i+1).
  # One slice per rank along k_dim, in the order returned by top_k.
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors cumulative values over the iterative process.
  # combine_tensor accumulates the per-iteration combine weights, cum_tokens
  # the number of tokens accepted by each expert so far, and
  # tokens_left_to_route starts at 1 per token and is reduced once the token
  # has been accepted by an expert.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    # Drop the k dimension that the split left on the slice.
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    # Tokens accepted in earlier iterations plus the running count along the
    # group in this one.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    # NOTE: despite its name, expert_overflow is 1.0 where the running count
    # is still within capacity (less_equal) and 0.0 where it exceeds it.
    expert_overflow = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * expert_overflow

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    # The running count is 1-based; subtract 1 to obtain the slot index.
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Cast through bool so every non-zero combine weight becomes exactly 1.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss