Example #1
0
    def _mtf_model_fn(self, features, mesh):
        """Build the forward pass on `mesh` and return `(logits, loss)`.

        Imports the "targets" and "targets_segmentation" features, plus
        "targets_position" unless
        `hparams.use_global_position_in_packed_sequence` is set.  For an
        autoregressive model the inputs are derived from the targets; otherwise
        the "inputs" feature is imported.  The model returned by `self.model()`
        is then run through `call_simple` with `compute_loss=True`.

        Args:
          features: feature collection passed to `self._import_feature`; also
            stored on `self._original_features`.
          mesh: the mesh passed to `self._import_feature`.

        Returns:
          The `(logits, loss)` pair produced by `call_simple`.
        """
        self._original_features = features
        hparams = self._hparams

        targets = self._import_feature(features, mesh, "targets")
        segmentation = self._import_feature(features, mesh,
                                            "targets_segmentation")
        position = (None
                    if hparams.use_global_position_in_packed_sequence else
                    self._import_feature(features, mesh, "targets_position"))

        if not self.autoregressive:
            # TODO(noam): options for bert-style masking here?
            inputs = self._import_feature(features, mesh, "inputs")
        else:
            # Decoder inputs are the targets shifted by one along length_dim.
            inputs = mtf.shift(
                targets, offset=1, dim=self.length_dim, wrap=False)
            # We should have a 0 at the beginning of each sequence rather than
            # the shifted EOS (1) from the previous sequence.
            inputs -= mtf.to_int32(mtf.equal(inputs, 1))

        logits, loss = self.model().call_simple(
            inputs=inputs,
            targets=targets,
            compute_loss=True,
            mode=hparams.mode,
            variable_dtype=self.variable_dtype,
            sequence_id=segmentation,
            position=position)
        return logits, loss
Example #2
0
 def call(self, context, x, losses=None):
     """Attend from `x` over the layer's memory antecedent.

     Keys and values are computed from the memory antecedent, except in
     "incremental" mode where they are read back from the context's constant
     state.  In "first_part" mode the freshly computed keys, values and
     memory-length dimension are recorded as constant state.

     Args:
       context: the call context (provides mode, model, sequence ids and
         the constant-state accessors used below).
       x: the query-side input tensor.
       losses: unused in this method.

     Returns:
       The result of `self.attention_internal`.

     Raises:
       NotImplementedError: if the last dimension of the memory antecedent
         differs from `context.model.model_dim`.
     """
     memory = self._get_memory_antecedent(context)
     if memory.shape[-1] != context.model.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     query = self.compute_q(context, x)
     if context.mode != "incremental":
         key = self.compute_k(context, memory)
         value = self.compute_v(context, memory)
         # Exactly one dimension must be named "memory_length"; the
         # single-element unpacking fails otherwise.
         memory_length, = [
             dim for dim in memory.shape.dims if dim.name == "memory_length"
         ]
         if context.mode == "first_part":
             context.record_constant_state((key, value, memory_length))
     else:
         key, value, memory_length = context.get_constant_state()

     # Only build a visibility bias when both sequence ids are present.
     bias = None
     if context.encoder_sequence_id and context.sequence_id:
         same_sequence = mtf.equal(context.sequence_id,
                                   context.encoder_sequence_id)
         bias = attention.visibility_mask_to_attention_bias(
             same_sequence, context.activation_dtype)
     return self.attention_internal(context, x, memory, query, key, value,
                                    memory_length, bias)
Example #3
0
 def cond_fn(position, ids, *unused_states):
   """Loop condition: true while at least one sequence is unfinished.

   A sequence counts as finished once `position` has reached the size of
   `length_dim`, or, when `stop_at_token` is not None, once `ids` contains
   `stop_at_token` anywhere along `length_dim`.
   """
   finished = mtf.greater_equal(position, length_dim.size)
   if stop_at_token is not None:
     eos_mask = mtf.equal(ids, stop_at_token)
     contains_eos = mtf.reduce_any(eos_mask, reduced_dim=length_dim)
     finished = mtf.logical_or(finished, contains_eos)
   return mtf.logical_not(mtf.reduce_all(finished))
Example #4
0
    def cond_fn(position, ids, *unused_states):
        """Loop condition: true while at least one sequence is unfinished.

        A sequence counts as finished when `position` has reached the size of
        `length_dim`, when `max_steps` is set and that many steps have been
        taken since `initial_position`, or when `stop_at_token` is not None
        and `ids` holds more occurrences of it than
        `partial_sequences_eos_count`.
        """
        finished = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            steps_taken = position - initial_position
            out_of_steps = mtf.greater_equal(steps_taken, max_steps)
            finished = mtf.logical_or(finished, out_of_steps)

        if stop_at_token is not None:
            eos_hits = mtf.to_int32(mtf.equal(ids, stop_at_token))
            eos_total = mtf.reduce_sum(eos_hits, reduced_dim=length_dim)
            # Stop tokens already present in the prompt do not count.
            produced_eos = mtf.greater(eos_total, partial_sequences_eos_count)
            finished = mtf.logical_or(finished, produced_eos)
        return mtf.logical_not(mtf.reduce_all(finished))
Example #5
0
 def call(self, context, x, losses=None):
     """Apply masked local 1-D attention to `x`.

     In "incremental" mode the previous key/value states are read from the
     context, one incremental attention step is run, and the new states are
     recorded.  Otherwise full masked local attention is run; in
     "first_part" mode the keys and values it produced are additionally
     reduced over the length dimension into window-sized states, which are
     appended to `context.new_states`.

     Args:
       context: the call context.
       x: the input tensor.
       losses: unused in this method.

     Returns:
       The attention output tensor.
     """
     params = mtf.layers.multihead_attention_params(
         context.mesh, self.heads_dim, context.model_dim, self.kv_dim,
         context.variable_dtype)

     if context.mode == "incremental":
         prev_k, prev_v = context.get_states(2)
         y, new_k, new_v = mtf.layers.masked_local_attention_1d_incremental(
             x, prev_k, prev_v, context.position, params=params)
         context.record_new_states([new_k, new_v])
         return y

     # masked_local_attention_1d fills this list through `return_kv`.
     kv = []
     y = mtf.layers.masked_local_attention_1d(
         x,
         self.kv_dim,
         self.heads_dim,
         self.window_size,
         params=params,
         return_kv=kv)
     if context.mode != "first_part":
         return y

     keys, values = kv[0], kv[1]
     window_dim = mtf.Dimension("window", self.window_size)
     mesh = keys.mesh
     window_pos = mtf.range(mesh, window_dim, tf.int32)
     pos = mtf.range(mesh, context.length_dim, tf.int32)

     # 1 where the length position maps to this window slot
     # (pos mod window_size == slot index).
     select_recent = mtf.cast(
         mtf.equal(window_pos, mtf.mod(pos, self.window_size)), keys.dtype)
     # ... and the position lies before initial_position ...
     before_start = mtf.cast(
         mtf.less(pos, context.initial_position), keys.dtype)
     select_recent *= before_start
     # ... and no more than window_size positions before it.
     inside_window = mtf.cast(
         mtf.greater_equal(pos,
                           context.initial_position - self.window_size),
         keys.dtype)
     select_recent *= inside_window

     state_shape = keys.shape.dims[:-2] + [window_dim, self.kv_dim]
     k_state, v_state = [
         mtf.einsum([tensor, select_recent],
                    output_shape=state_shape,
                    reduced_dims=[context.length_dim])
         for tensor in (keys, values)
     ]
     context.new_states.extend([k_state, v_state])
     return y
def shift_targets_no_offset(targets, bos_id=0, eos_id=1):
    """Transform decoder labels into decoder inputs without shifting them.

    Every occurrence of `eos_id` in `targets` is replaced by 0.  When
    `bos_id` is nonzero, every position that is 0 after that replacement but
    nonzero in `targets` is then set to `bos_id`.

    Args:
      targets: decoder labels.
      bos_id: begin of sequence id, defaults to 0.
      eos_id: end of sequence id, defaults to 1.

    Returns:
      Decoder inputs.
    """
    # Zero out the EOS tokens; all other ids pass through unchanged.
    shifted_targets = targets * mtf.to_int32(mtf.not_equal(targets, eos_id))

    if bos_id:
        # Positions that just became 0 (and were not 0 before) receive BOS.
        newly_zero = mtf.logical_and(mtf.equal(shifted_targets, 0),
                                     mtf.not_equal(targets, 0))
        shifted_targets += mtf.to_int32(newly_zero) * bos_id

    return shifted_targets
Example #7
0
def enc_dec_attention(self_attention_layer, memory_antecedent, context, x,
                      losses):
  """Multi-head attention from `x` over `memory_antecedent`.

  Keys and values are projected from `memory_antecedent` (with one shared
  projection when the layer has `shared_kv` set), except in "incremental"
  mode where they are read from the context's constant state.  In
  "first_part" mode the computed keys, values and memory-length dimension are
  recorded as constant state.

  Args:
    self_attention_layer: the layer supplying parameters, `kv_dim`, attention
      kwargs and the output post-processing.
    memory_antecedent: the tensor attended over.
    context: the call context.
    x: the query-side input tensor.
    losses: passed through to `layer_output_from_attention_output`.

  Returns:
    The result of `layer_output_from_attention_output`.

  Raises:
    NotImplementedError: if the last dimension of `memory_antecedent` differs
      from `context.model.model_dim`.
  """
  layer = self_attention_layer
  if memory_antecedent.shape[-1] != context.model.model_dim:
    raise NotImplementedError(
        "TODO(noam): support different model_dim in encoder and decoder.")
  params = layer.make_params(context)
  query = params.compute_q(x)

  if context.mode != "incremental":
    if not layer.shared_kv:
      key = params.compute_k(memory_antecedent)
      value = params.compute_v(memory_antecedent)
    else:
      key = value = params.compute_kv(memory_antecedent)
    # Exactly one dimension must be named "memory_length".
    memory_length, = [
        dim for dim in memory_antecedent.shape.dims
        if dim.name == "memory_length"
    ]
    if context.mode == "first_part":
      context.record_constant_state((key, value, memory_length))
  else:
    key, value, memory_length = context.get_constant_state()

  # Only build a visibility bias when both sequence ids are present.
  bias = None
  if context.encoder_sequence_id and context.sequence_id:
    same_sequence = mtf.equal(context.sequence_id,
                              context.encoder_sequence_id)
    bias = attention.visibility_mask_to_attention_bias(
        same_sequence, context.activation_dtype)

  attended = attention.attention(
      query, key, value, memory_length, layer.kv_dim, layer.kv_dim, bias,
      **layer.attention_kwargs_from_context(context))
  output_shape = layer.expected_attention_output_shape(x, params)
  projected = params.compute_output(attended, output_shape=output_shape)
  return layer.layer_output_from_attention_output(context, projected, losses)
Example #8
0
 def call(self, context, x, losses=None):
     """Attend from `x` over the layer's memory antecedent.

     Keys and values are projected from the memory antecedent (with one
     shared projection when `self.shared_kv` is set), except in
     "incremental" mode where they are read from the context's constant
     state.  In "first_part" mode the computed keys, values and
     memory-length dimension are recorded as constant state.

     Args:
       context: the call context.
       x: the query-side input tensor; its shape is the output shape.
       losses: unused in this method.

     Returns:
       The attention result projected to the shape of `x`.

     Raises:
       NotImplementedError: if the last dimension of the memory antecedent
         differs from `context.model.model_dim`.
     """
     memory = self._get_memory_antecedent(context)
     if memory.shape[-1] != context.model.model_dim:
         raise NotImplementedError(
             "TODO(noam): support different model_dim in encoder and decoder."
         )
     params = self.make_params(context)
     query = params.compute_q(x)

     if context.mode != "incremental":
         if not self.shared_kv:
             key = params.compute_k(memory)
             value = params.compute_v(memory)
         else:
             key = value = params.compute_kv(memory)
         # Exactly one dimension must be named "memory_length".
         memory_length, = [
             dim for dim in memory.shape.dims if dim.name == "memory_length"
         ]
         if context.mode == "first_part":
             context.record_constant_state((key, value, memory_length))
     else:
         key, value, memory_length = context.get_constant_state()

     # Only build a visibility bias when both sequence ids are present.
     bias = None
     if context.encoder_sequence_id and context.sequence_id:
         same_sequence = mtf.equal(context.sequence_id,
                                   context.encoder_sequence_id)
         bias = attention.visibility_mask_to_attention_bias(
             same_sequence, context.activation_dtype)

     attended = attention.attention(
         query, key, value, memory_length, self.kv_dim, self.kv_dim, bias,
         **self.attention_kwargs_from_context(context))
     return params.compute_output(attended, output_shape=x.shape)
Example #9
0
  def call_simple(
      self,
      inputs,
      targets,
      compute_loss,
      mode=tf.estimator.ModeKeys.TRAIN,
      variable_dtype=mtf.VariableDType(tf.float32),
      encoder_sequence_id=None,
      decoder_sequence_id=None,
      encoder_position=None,
      decoder_position=None):
    """Run the encoder, then the decoder, over all positions in parallel.

    This is called during training and evaluation.

    The decoder is fed `targets` shifted by one along its last dimension,
    with any shifted-in EOS (1) replaced by 0.

    Args:
      inputs: an int32 Tensor with shape [<batch_dims>, length_dim]
      targets: an int32 Tensor with shape [<batch_dims>, length_dim].  Its
        shape is read unconditionally, so it must be provided.
      compute_loss: a boolean
      mode: a tf.estimator.ModeKeys
      variable_dtype: a mtf.VariableDType
      encoder_sequence_id: an optional Tensor
      decoder_sequence_id: an optional Tensor
      encoder_position: an optional Tensor
      decoder_position: an optional Tensor

    Returns:
      logits: the logits returned by the decoder
      loss: the decoder loss, with the encoder loss added when neither of the
        two is None
    """
    encoder_layer_outputs = []
    shared_params = self._shared_params(inputs.mesh, variable_dtype)

    encoder_output, encoder_loss = self.encoder.call_simple(
        inputs,
        None,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=encoder_sequence_id,
        position=encoder_position,
        shared_params=shared_params,
        layer_outputs=encoder_layer_outputs)

    # The decoder sees the encoder side under the memory-length dimension.
    encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output)
    if encoder_sequence_id is not None:
      encoder_sequence_id = mtf.layers.rename_length_to_memory_length(
          encoder_sequence_id)

    decoder_inputs = mtf.shift(
        targets, offset=1, dim=targets.shape.dims[-1], wrap=False)
    # We should have a 0 at the beginning of each sequence rather than the
    # shifted EOS (1) from the previous sequence.
    decoder_inputs -= mtf.to_int32(mtf.equal(decoder_inputs, 1))

    logits, loss = self.decoder.call_simple(
        decoder_inputs,
        targets,
        compute_loss,
        mode=mode,
        variable_dtype=variable_dtype,
        sequence_id=decoder_sequence_id,
        encoder_output=encoder_output,
        encoder_sequence_id=encoder_sequence_id,
        position=decoder_position,
        shared_params=shared_params,
        encoder_layer_outputs=encoder_layer_outputs)

    if not (loss is None or encoder_loss is None):
      loss += encoder_loss
    return logits, loss
    def sample_autoregressive(self,
                              partial_sequences,
                              dst_attributes=None,
                              stop_at_token=1,
                              max_steps=None,
                              temperature=0.0,
                              variable_dtype=mtf.VariableDType(tf.float32),
                              encoder_output=None,
                              encoder_sequence_id=None,
                              encoder_inputs=None,
                              shared_params=None,
                              has_partial_sequences=True,
                              encoder_layer_outputs=None,
                              never_end=False,
                              remove_partial_sequences=False,
                              sampling_keep_top_k=-1,
                              z=None):
        """Sample randomly one token at a time.

        The partial_sequences represent partial sequences to be continued.  The
        first tokens of each sequence are nonzero representing the given partial
        sequences and the last tokens of each sequence are zeros, representing what
        needs to be filled in.

        If there are no partial sequences (you want to sample from the beginning),
        then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
        has_partial_sequences=False (so we can skip computation).

        The dst_attributes represent the destination attributes in which we want
        to generate sequences.

        The model is first run once over the whole (shifted) input in
        "first_part" mode to collect constant and initial states, then a
        while-loop runs it in "incremental" mode, one position per iteration.

        Args:
          partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
          dst_attributes: an int32 Tensor with shape [<batch_dims>, length_dim]
            ([<batch_dims>]).  Passed to the model as `attributes`; when
            `self.attribute_embedding` is set, the value at the previous
            position is gathered for each decode step.
          stop_at_token: an optional integer eos id.  Stop when we produce it.
          max_steps: an optional integer, the max number of steps to decode.
          temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to predicted
            distribution.
          variable_dtype: a mtf.VariableDType
          encoder_output: an optional Tensor
          encoder_sequence_id: an optional Tensor
          encoder_inputs: an optional Tensor
          shared_params: an optional dictionary
          has_partial_sequences: a boolean
          encoder_layer_outputs: optional - readonly list of tensor activations when
            decoding, one per each input layer + the embedding layer
          never_end: a boolean - if set, then avoid generating stop_at_token
          remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
          sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits.
          z: optional value forwarded unchanged to `self._call_internal`.
            NOTE(review): its meaning is not visible in this method - confirm
            against `_call_internal`.

        Returns:
          a Tensor with shape [<batch_dims>, length_dim]

        Raises:
          ValueError: if the model is not autoregressive, or (from the loop
            body) if sampling_keep_top_k is neither -1 nor positive.
        """
        if not self.autoregressive:
            raise ValueError("must be autoregressive")

        inputs = partial_sequences
        attributes = dst_attributes
        batch_dims = inputs.shape.dims[:-1]
        length_dim = inputs.shape.dims[-1]
        # Number of nonzero tokens in each sequence; decoding starts from this
        # position (it is the first while-loop variable below).
        initial_position = mtf.reduce_sum(mtf.to_int32(mtf.not_equal(
            inputs, 0)),
                                          reduced_dim=length_dim)
        # A constant sequence id of 1 is used only when the encoder side
        # supplies sequence ids; otherwise no sequence id is passed.
        sequence_id = 1 if encoder_sequence_id is not None else None

        length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
        if self.input_full_attention:
            # Positions that are not greater than initial_position get
            # priority 0; later positions keep their index as priority.
            read_priority = write_priority = length_range * mtf.to_int32(
                mtf.greater(length_range, initial_position))
        else:
            read_priority = write_priority = length_range

        context_first_part = Context(
            model=self,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=sequence_id,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        # The model consumes the inputs shifted by one along length_dim.
        shifted_inputs = mtf.shift(inputs,
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
        with tf.variable_scope(self.name):
            logits = self._call_internal(context_first_part,
                                         shifted_inputs,
                                         attributes=attributes,
                                         z=z)
        # Only the states collected on context_first_part are used from this
        # pass; its logits are discarded.
        del logits
        constant_states = context_first_part.constant_states
        if not has_partial_sequences:
            # Nothing to continue from: start the loop from all-zero states.
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
            partial_sequences_eos_count = 0
        else:
            initial_states = context_first_part.new_states
            # Stop tokens already present in the prompt; cond_fn only stops
            # once more than this many have appeared.
            partial_sequences_eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
                reduced_dim=length_dim)

        def cond_fn(position, ids, *unused_states):
            """Should we run another loop iteration."""
            past_end = mtf.greater_equal(position, length_dim.size)
            if max_steps:
                past_end = mtf.logical_or(
                    past_end,
                    mtf.greater_equal(position - initial_position, max_steps))

            is_done = past_end
            if stop_at_token is not None:
                eos_count = mtf.reduce_sum(mtf.to_int32(
                    mtf.equal(ids, stop_at_token)),
                                           reduced_dim=length_dim)
                has_additional_eos = mtf.greater(eos_count,
                                                 partial_sequences_eos_count)
                is_done = mtf.logical_or(is_done, has_additional_eos)
            # Keep looping until every sequence is done.
            all_done = mtf.reduce_all(is_done)
            return mtf.logical_not(all_done)

        def body_fn(position, ids, *states):
            """One step in the decode loop."""
            # The model input for this step is the token at the previous
            # position.
            inputs_this_step = mtf.gather(ids, position - 1, length_dim)
            if self.attribute_embedding:
                attributes_this_step = mtf.gather(attributes, position - 1,
                                                  length_dim)
            else:
                attributes_this_step = None
            context_incremental = Context(
                model=self,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                states=states,
                new_states=[],
                sequence_id=sequence_id,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                constant_states=constant_states,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=position,
                inputs=inputs_this_step,
                encoder_inputs=encoder_inputs)

            # reuse=True: the variables were created by the first-part pass.
            with tf.variable_scope(self.name, reuse=True):
                logits = self._call_internal(context_incremental,
                                             inputs_this_step,
                                             attributes=attributes_this_step,
                                             z=z)
                if never_end:
                    # Add -1e9 to the stop_at_token logit so that it is not
                    # selected.
                    logits += mtf.one_hot(mtf.constant(logits.mesh,
                                                       stop_at_token,
                                                       dtype=tf.int32),
                                          self.output_vocab_dim,
                                          on_value=-1e9,
                                          off_value=0.0,
                                          dtype=logits.dtype)

            # TBD whether this should be before or after never_end:
            # Note for adding top_p sampling in the future, in other code bases, the
            # option to apply temperature is done before the top-k truncation. This
            # implementation does this in the opposite order. For top-k this doesn't
            # matter, but for top_p it will.
            if sampling_keep_top_k != -1:
                if sampling_keep_top_k <= 0:
                    raise ValueError(
                        "sampling_keep_top_k must either be -1 or positive.")
                k_largest = mtf.nth_largest_element(
                    logits,
                    n=sampling_keep_top_k,
                    reduced_dim=self.output_vocab_dim)
                # Every logit less than or equal to k_largest is replaced
                # by -1e6.
                logits = mtf.where(mtf.less_equal(logits, k_largest),
                                   mtf.ones_like(logits) * -1e6, logits)

            ids_this_step = mtf.sample_with_temperature(
                logits, self.output_vocab_dim, temperature)
            new_position = position + 1
            # Add the sampled id into `ids` at `position` via a one-hot mask
            # (presumably `ids` is 0 there beforehand - TODO confirm).
            new_ids = ids + ids_this_step * mtf.one_hot(
                position, length_dim, dtype=tf.int32)
            return [new_position, new_ids] + context_incremental.new_states

        # Loop variables: position, ids, then the per-layer states.
        while_loop_inputs = [initial_position, inputs] + initial_states
        final_position, outputs = mtf.while_loop(cond_fn, body_fn,
                                                 while_loop_inputs)[:2]
        del final_position
        if has_partial_sequences and remove_partial_sequences:
            # remove partial sequences from outputs
            partial_length = mtf.reduce_sum(mtf.to_int32(
                mtf.not_equal(partial_sequences, 0)),
                                            reduced_dim=length_dim)
            outputs = mtf.dynamic_shift(outputs,
                                        -partial_length,
                                        length_dim,
                                        wrap=False)
        return outputs
Example #11
0
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

  Note: until the algorithm and interface solidify, we pass in a hyperparameters
  dictionary in order not to complicate the interface in mtf_transformer.py .
  Once this code moves out of "research", we should pass the hyperparameters
  separately.

  Hyperparameters used:
    hparams.moe_use_second_place_loss: a boolean
    hparams.moe_second_policy_train: a string
    hparams.moe_second_policy_eval: a string
    hparams.moe_second_threshold_train: a float
    hparams.moe_second_threshold_eval: a float

  The policy strings accepted are "all", "none", "threshold" and "random".

  The returned forward assignment is a tensor used to map (via einsum) from the
  inputs to the expert_inputs.  Likewise, the returned combine_tensor is
  used to map (via einsum) from the expert outputs to the outputs.  Both the
  forward and backward assignments are mostly zeros.  The shapes of the tensors
  are as follows.

  inputs: [<batch_dims>, group_size_dim, input_dim]
  importance: [<batch_dims>, group_size_dim]
  dispatch_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  expert_inputs:
    [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

  expert_outputs: [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
  combine_tensor:
    [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
  outputs: [<batch_dims>, group_size_dim, output_dim]

  "importance" is an optional tensor with one floating-point value for each
  input vector.  If the importance of an input is 1.0, then we send it to
  up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
  one expert.  If importance == 0.0, then we send it to no experts.

  We use "importance" at the second-level gating function of a hierarchical
  mixture of experts.  Inputs to the first-choice expert-group get importance
  1.0.  Inputs to the second-choice expert group get importance 0.5.
  Inputs that represent padding get importance 0.0.

  Args:
    inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
    outer_expert_dims: an optional list of dimensions.  This is for the case
      where we are at an inner level of a hierarchical MoE.
    experts_dim: a Dimension (the number of experts)
    expert_capacity_dim: a Dimension (number of examples per group per expert)
    hparams: model hyperparameters.
    train: a boolean
    variable_dtype: a mtf.VariableDType
    importance: an optional tensor with shape [<batch_dims>, group_size_dim]
    name: an optional string

  Returns:
    dispatch_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    combine_tensor: a Tensor with shape
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    loss: a mtf scalar

  Raises:
    ValueError: on illegal hyperparameters
  """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    # One gate logit per expert, normalized into a distribution over experts.
    raw_gates = mtf.layers.dense(inputs,
                                 experts_dim,
                                 use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype,
                                 name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        # Only inputs with importance exactly 1.0 keep a first-choice expert.
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    # Zero the first-choice gate so the next top_1 finds the runner-up.
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        # Inputs with importance 0.0 get no second-choice expert either.
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    # Renormalize the two gates to sum to (almost) one; the 1e-9 guards
    # against division by zero.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        # The second-place loss is added at half weight.
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts for all examples.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second choices are placed after all first choices of the same expert,
    # hence the mask_1_count offset.
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    # Remove the second choices that don't fit. [batch, group, experts]
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # [batch, group]
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    # Weight assigned to second expert.  [batch, group]
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    # NOTE(review): gate_1 and gate_2 were already multiplied by their *_flat
    # masks above; the masks are applied a second time here.
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    # Return results in the dtype of the inputs.
    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    # The dispatch tensor is the combine tensor cast to bool and back, i.e.
    # its nonzero pattern without the gate weights.
    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
Example #12
0
 def call(self, context, x, losses=None):
     """Run local self-attention over `x` and project the result.

     The attention computation depends on `context.mode` and on the length
     of the sequence:

     * "incremental" mode: the previous key/value window is read from the
       context states, the slot `context.position mod self.radius` is
       overwritten with the current step's key/value, the updated window is
       recorded as new state, and attention runs over `self.window_dim`.
     * `context.length_dim.size <= max(256, self.radius * 4)`: ordinary full
       attention, masked by `self.compute_bias`.
     * otherwise: `attention.local_attention_1d`.

     In "first_part" mode, the keys (and the values, unless
     `self.shared_kv`) of the `self.radius` positions preceding
     `context.initial_position` are additionally packed into window-shaped
     tensors and appended to `context.new_states`.

     Args:
       context: the layer-call context (mode, mesh, position, states, ...).
       x: a Tensor, the layer input.
       losses: unused by this method.

     Returns:
       The result of `params.compute_output`, requested with
       `output_shape=x.shape`.
     """
     params = self.make_params(context)
     q = params.compute_q(x)
     if self.shared_kv:
         # Keys and values are one and the same tensor.
         k = v = params.compute_kv(x)
     else:
         k = params.compute_k(x)
         v = params.compute_v(x)

     if context.mode == "incremental":
         if self.shared_kv:
             (old_kv,) = context.get_states(1)
         else:
             old_k, old_v = context.get_states(2)
         # True only at the window slot written by the current step.
         is_current_slot = mtf.equal(
             mtf.range(context.mesh, self.window_dim, dtype=tf.int32),
             mtf.mod(context.position, self.radius))

         def _overwrite_slot(fresh, stale):
             # Take `fresh` at the current slot and `stale` everywhere else.
             return mtf.where(is_current_slot, fresh, stale,
                              output_shape=stale.shape)

         if self.shared_kv:
             k = v = _overwrite_slot(k, old_kv)
             context.record_new_states([k])
         else:
             k = _overwrite_slot(params.compute_k(x), old_k)
             v = _overwrite_slot(params.compute_v(x), old_v)
             context.record_new_states([k, v])
         slot_index = mtf.range(context.mesh, self.window_dim, tf.int32)
         # Slots with an index above the current position are masked out.
         bias = attention.visibility_mask_to_attention_bias(
             mtf.greater_equal(context.position, slot_index),
             context.activation_dtype)
         o = attention.attention(
             q, k, v, self.window_dim, self.kv_dim, self.kv_dim, bias,
             **self.attention_kwargs_from_context(context))
     elif context.length_dim.size <= max(256, self.radius * 4):
         # Short sequence: plain full attention plus a mask is enough.
         memory_position = self.rename_length_to_memory_length(
             context.position, context)
         memory_k = self.rename_length_to_memory_length(k, context)
         memory_v = self.rename_length_to_memory_length(v, context)
         memory_dim = self.memory_length(context)
         bias = self.compute_bias(context, memory_position, x)
         o = attention.attention(
             q, memory_k, memory_v, memory_dim, self.kv_dim, self.kv_dim,
             bias, **self.attention_kwargs_from_context(context))
     else:
         # Long sequence: blocked local attention.
         # NOTE(review): confirm that the installed
         # attention.local_attention_1d accepts `autoregressive`; some
         # versions name this parameter `fully_autoregressive`.
         o = attention.local_attention_1d(
             q=q,
             k=k,
             v=None if self.shared_kv else v,
             length_dim=context.length_dim,
             key_dim=self.kv_dim,
             value_dim=self.kv_dim,
             length_dim_num_splits=1,  # TODO(noam): look at the layout
             autoregressive=context.model.fully_autoregressive,
             radius=self.radius,
             sequence_id=context.sequence_id,
             write_priority=context.write_priority,
             read_priority=context.read_priority,
             attention_kwargs=self.attention_kwargs_from_context(context))

     if context.mode == "first_part":
         slot_index = mtf.range(context.mesh, self.window_dim, tf.int32)
         pos = mtf.range(context.mesh, context.length_dim, tf.int32)
         # selector[pos, slot] is 1 when `pos` maps to `slot` and lies in
         # [initial_position - radius, initial_position); 0 otherwise.
         selector = mtf.cast(
             mtf.equal(mtf.mod(pos, self.radius), slot_index), x.dtype)
         selector *= mtf.cast(
             mtf.less(pos, context.initial_position), x.dtype)
         selector *= mtf.cast(
             mtf.greater_equal(pos, context.initial_position - self.radius),
             x.dtype)
         state_shape = (k.shape - [context.length_dim, self.kv_dim] +
                        [self.window_dim, self.kv_dim])
         # One state for shared keys/values, otherwise keys then values.
         state_sources = [k] if self.shared_kv else [k, v]
         for source in state_sources:
             context.new_states.append(
                 mtf.einsum([source, selector],
                            output_shape=state_shape,
                            reduced_dims=[context.length_dim]))
     return params.compute_output(o, output_shape=x.shape)
Example #13
0
    def compute_bias(self, context, memory_position, x):
        """Build the additive attention bias for this layer.

        The bias is the sum of every term that applies: relative-position
        window limits, read/write priority visibility, same-sequence
        visibility, and (optionally) a learned relative-attention bias.
        When `self.relative_attention_type` is None or "bias_shared" the
        result is stored in, and reused from, `context.cache`.

        Args:
          context: a transformer.Context
          memory_position: an int32 tensor containing memory_length
            dimension.
          x: a Tensor - the query antecedent - required for relative
            attention

        Returns:
          a Tensor, or None when no bias term applies.

        Raises:
          ValueError: if `self.relative_attention_type` is set to an
            unrecognized value.
        """
        lower = self.min_relative_position(context)
        upper = self.max_relative_position(context)
        rel_type = self.relative_attention_type
        # The result can often be shared between similar layers.
        cacheable = rel_type is None or rel_type == "bias_shared"
        if cacheable:
            cache_key = ("self_attention_mask", lower, upper, rel_type,
                         self.num_heads)
            if cache_key in context.cache:
                return context.cache[cache_key]

        def _bias_from(mask):
            # Turn a boolean visibility mask into an additive bias term.
            return attention.visibility_mask_to_attention_bias(
                mask, context.activation_dtype)

        terms = []
        relative_position = memory_position - context.position
        if lower is not None:
            terms.append(
                _bias_from(mtf.greater_equal(relative_position, lower)))
        if upper is not None:
            terms.append(
                _bias_from(mtf.less_equal(relative_position, upper)))
        if context.read_priority is not None:
            terms.append(
                _bias_from(
                    mtf.greater_equal(
                        context.read_priority,
                        mtf.layers.rename_length_to_memory_length(
                            context.write_priority))))

        # Subsequence id should only be set if we are in the decoder and
        # have multiple targets per input. This will allow each sub-target
        # to only attend to itself.
        if isinstance(context.subsequence_id, mtf.Tensor):
            segment_id = context.subsequence_id
        elif isinstance(context.sequence_id, mtf.Tensor):
            segment_id = context.sequence_id
        else:
            segment_id = None
        if segment_id is not None and context.length_dim in segment_id.shape:
            same_segment = mtf.equal(
                segment_id,
                self.rename_length_to_memory_length(segment_id, context))
            terms.append(_bias_from(same_segment))

        if rel_type is not None:
            buckets_dim = mtf.Dimension("buckets",
                                        self.relative_attention_num_buckets)
            heads_dim = mtf.Dimension("heads", self.num_heads)
            rp_bucket = _relative_position_bucket(
                relative_position,
                bidirectional=not context.model.fully_autoregressive,
                num_buckets=buckets_dim.size)
            if rel_type in ("bias", "bias_shared"):
                values = mtf.get_variable(context.mesh,
                                          "relative_attention_bias",
                                          [heads_dim, buckets_dim],
                                          dtype=context.variable_dtype)
            elif rel_type == "contextual":
                values = layers.dense(x, [buckets_dim, heads_dim],
                                      variable_dtype=context.variable_dtype,
                                      name="relative_attention_contextual")
            else:
                raise ValueError(
                    "unrecognized relative_attention_type \"%s\"" % rel_type)
            terms.append(mtf.gather(values, rp_bucket, buckets_dim))

        result = mtf.add_n(terms) if terms else None
        if cacheable:
            context.cache[cache_key] = result
        return result
def local_attention_1d(q,
                       k,
                       v,
                       length_dim,
                       key_dim,
                       value_dim,
                       fully_autoregressive=True,
                       length_dim_num_splits=1,
                       radius=128,
                       sequence_id=1,
                       write_priority=None,
                       read_priority=None,
                       attention_kwargs=None):
  """Attention to a neighborhood around the source.

  If fully_autoregressive, then query position p can only see memory positions
  in the range (p - radius, p].

  If not fully_autoregressive, then query position p can only see memory
  positions in the range (p - radius, p + radius].

  In addition, if read_priority is provided (write_priority must then be
  provided as well), attention is limited to position pairs where
  read_priority[query position] >= write_priority[memory position]

  Args:
    q: a Tensor containing length_dim
    k: a Tensor containing length_dim
    v: an optional Tensor containing length_dim.  If None then uses v=k.
    length_dim: a Dimension
    key_dim: a Dimension (the channels dimension of q and k)
    value_dim: a Dimension (the channels dimension of v)
    fully_autoregressive: a boolean
    length_dim_num_splits: an optional integer indicating how many ways the
      length dimension is split
    radius: an integer
    sequence_id: a Tensor or an integer.  None is treated as 1.
    write_priority: an optional Tensor containing length_dim
    read_priority: an optional Tensor containing length_dim
    attention_kwargs: optional dict of keyword arguments for attention().
      None is treated as no extra arguments.

  Returns:
    a Tensor with the shape q.shape - key_dim + value_dim

  Raises:
    ValueError: not raised directly here; presumably propagated from the
      underlying mtf ops when dimensions don't match - TODO confirm.
  """
  # The default is None; `**None` would raise a TypeError at the call below.
  if attention_kwargs is None:
    attention_kwargs = {}
  # Choose a suitable block size.
  # We choose the greatest divisor of length_per_split less than or equal
  # to max(radius, 128)
  length_per_split = length_dim.size // length_dim_num_splits
  block_length = max(radius, 128)
  while length_per_split % block_length != 0:
    block_length -= 1
  query_block_length = mtf.Dimension("query_block_length", block_length)
  memory_block_length = mtf.Dimension("memory_block_length", block_length)
  # The num_blocks dimension gets the same name as the length dimension,
  # so it will be split in the same way.
  num_blocks = mtf.Dimension(length_dim.name, length_dim.size // block_length)
  def _reshape_query(x):
    # Split length_dim into [num_blocks, query_block_length].
    return mtf.replace_dimensions(
        x, length_dim, [num_blocks, query_block_length])
  def _reshape_memory(x):
    # Split length_dim into blocks, then extend each block with `radius`
    # positions from its neighbor(s): left only when fully autoregressive.
    x = mtf.replace_dimensions(
        x, length_dim, [num_blocks, memory_block_length])
    return (mtf.left_halo_exchange if fully_autoregressive
            else mtf.halo_exchange)(
                x, num_blocks, memory_block_length, radius)
  q = _reshape_query(q)
  k = _reshape_memory(k)
  # Test against None explicitly rather than relying on tensor truthiness.
  if v is not None:
    v = _reshape_memory(v)
  else:
    v = k
  if sequence_id is None:
    sequence_id = 1
  if (not isinstance(sequence_id, mtf.Tensor) or
      length_dim not in sequence_id.shape.dims):
    # Broadcast a scalar / length-less id across length_dim.
    sequence_id += mtf.zeros(q.mesh, [length_dim], tf.int32)
  q_sequence_id = _reshape_query(sequence_id)
  m_sequence_id = _reshape_memory(sequence_id)
  pos = mtf.range(q.mesh, length_dim, dtype=tf.int32)
  q_pos = _reshape_query(pos)
  m_pos = _reshape_memory(pos)

  # Memory blocks after the halo exchange: one extra `radius` on the left,
  # plus one on the right when not fully autoregressive.
  padded_memory_block_length = mtf.Dimension(
      "memory_block_length",
      (1 if fully_autoregressive else 2) * radius + block_length)

  relative_position = m_pos - q_pos
  visible = mtf.equal(q_sequence_id, m_sequence_id)
  visible = mtf.logical_and(visible, mtf.greater(relative_position, -radius))
  visible = mtf.logical_and(visible, mtf.less_equal(
      relative_position, 0 if fully_autoregressive else radius))
  if read_priority is not None:
    write_priority = _reshape_memory(write_priority)
    read_priority = _reshape_query(read_priority)
    visible = mtf.logical_and(
        visible, mtf.greater_equal(read_priority, write_priority))

  bias = visibility_mask_to_attention_bias(visible, q.dtype)
  o = attention(q, k, v, padded_memory_block_length,
                key_dim, value_dim, bias, **attention_kwargs)
  return mtf.replace_dimensions(o, [num_blocks, query_block_length], length_dim)
Example #15
0
def _switch_gating(inputs,
                   outer_expert_dims,
                   experts_dim,
                   expert_capacity_dim,
                   hparams,
                   train,
                   variable_dtype,
                   importance=None,
                   name="switch_gating",
                   num_microbatches=None):
  """Compute a switch top-k gating with no-token-left-behind behavior.

  Gate probabilities come from a bias-free dense layer followed by a softmax
  over `experts_dim`. The top `hparams.moe_switch_top_k` experts of each token
  are then tried in order: a token is routed to its i-th choice only if it has
  not been routed yet and that expert still has capacity.

  Args:
    inputs: a Tensor. The code reads `inputs.shape[0]`, `inputs.shape[1]` and
      `inputs.shape[-2]` as the outer-batch, batch and group-size dimensions.
    outer_expert_dims: passed as `expert_dims` to the gating dense layer.
    experts_dim: a Dimension - the experts.
    expert_capacity_dim: a Dimension - slots available per expert.
    hparams: hyperparameters; reads `moe_rand_1_policy_train`,
      `moe_rand_1_policy_eval`, `moe_rand_1_jitter` and `moe_switch_top_k`.
    train: a boolean; selects the policy and enables summaries.
    variable_dtype: dtype specification for the gating dense layer.
    importance: an optional Tensor; positions where it is not equal to 1.0 are
      zeroed out of the expert mask, the expert gate and the density proxy.
    name: name of the gating dense layer and prefix of the entropy summary.
    num_microbatches: optional integer; when > 1 the load-balancing loss is
      divided by it.

  Returns:
    A tuple (dispatch_tensor, combine_tensor, loss). `combine_tensor` and
    `loss` are cast to `inputs.dtype`; `dispatch_tensor` is `combine_tensor`
    cast to bool and back, i.e. nonzero entries become 1.
  """
  # SELECT EXPERT
  if train:
    policy = hparams.moe_rand_1_policy_train
  else:
    policy = hparams.moe_rand_1_policy_eval

  # Input perturbations
  if train and policy == "input_jitter":
    inputs = mtf.layers.multiplicative_jitter(inputs, hparams.moe_rand_1_jitter)

  gate_logits = mtf.layers.dense(
      inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  raw_gates = mtf.softmax(gate_logits, reduced_dim=experts_dim)

  # The internals of this function run in float32.
  #   bfloat16 seems to reduce quality.
  raw_gates = mtf.to_float(raw_gates)

  # Top-k operation
  k_dim = mtf.Dimension("k", hparams.moe_switch_top_k)
  expert_gate, expert_index = mtf.top_k(
      raw_gates, reduced_dim=experts_dim, k_dim=k_dim)
  expert_mask = mtf.one_hot(expert_index, experts_dim)

  # LOAD BALANCING LOSS
  outer_batch_dim = inputs.shape[0]
  batch_dim = inputs.shape[1]
  group_size_dim = inputs.shape[-2]
  # density_1: mean of the one-hot expert mask over the group;
  # density_1_proxy: mean of the raw gate probabilities over the group.
  density_1 = mtf.reduce_mean(expert_mask, reduced_dim=group_size_dim)
  density_1_proxy = mtf.reduce_mean(raw_gates, reduced_dim=group_size_dim)
  if importance is not None:
    # Keep only the positions whose importance is exactly 1.0.
    expert_mask *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
    density_1_proxy *= mtf.cast(
        mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
  # Mean of (proxy * density), scaled by the squared number of experts.
  loss = (
      mtf.reduce_mean(density_1_proxy * density_1) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    # Entropy of the gate distribution; 1e-9 keeps log() away from zero.
    entropy = mtf.reduce_sum(
        -raw_gates * mtf.log(raw_gates + 1e-9), reduced_dim=experts_dim)
    batch_entropy = mtf.reduce_mean(entropy)
    mtf.scalar_summary(name + "/entropy", batch_entropy)

    # Fraction of routed tokens selecting each expert, one summary per expert.
    mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
    total_routed = mtf.reduce_sum(mask_count_experts)
    expert_fraction = mtf.to_float(mask_count_experts / total_routed)
    split_fractions = mtf.split(
        expert_fraction,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size)
    for fraction in split_fractions:
      mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                         mtf.reduce_mean(fraction))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Iteratively route tokens (no-token-left-behind). The idea is to route as
  # many tokens as possible to top-i before then trying top-(i+1).
  # Split along k_dim into one tensor per choice rank.
  top_k_masks = mtf.split(
      expert_mask, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_gates = mtf.split(
      expert_gate, split_dim=k_dim, num_or_size_splits=k_dim.size)
  top_k_indices = mtf.split(
      expert_index, split_dim=k_dim, num_or_size_splits=k_dim.size)

  # Tensors cumulative values over the iterative process.
  combine_tensor = mtf.constant(
      inputs.mesh,
      value=0,
      shape=[outer_batch_dim, batch_dim, experts_dim, expert_capacity_dim])
  # Running count of tokens accepted by each expert.
  cum_tokens = mtf.constant(
      inputs.mesh, value=0, shape=[outer_batch_dim, batch_dim, experts_dim])
  # 1.0 for tokens not yet routed; decremented as tokens are accepted.
  tokens_left_to_route = mtf.constant(
      inputs.mesh, value=1., shape=[outer_batch_dim, batch_dim, group_size_dim])

  expert_capacity_float = float(expert_capacity_dim.size)
  for (top_i_mask, top_i_gate, top_i_index) in zip(top_k_masks, top_k_gates,
                                                   top_k_indices):
    top_i_mask = mtf.reshape(
        top_i_mask,
        new_shape=[outer_batch_dim, batch_dim, group_size_dim, experts_dim])
    # Operate only on the unrouted tokens.
    top_i_mask *= tokens_left_to_route

    # Record cumulative number of tokens to each expert across iterations.
    cumulative_tokens_in_expert = cum_tokens + mtf.cumsum(
        top_i_mask, group_size_dim)

    # NOTE: despite its name, this is 1.0 where the running count is still
    # within the expert's capacity and 0.0 where it is exceeded.
    expert_overflow = mtf.to_float(
        mtf.less_equal(cumulative_tokens_in_expert, expert_capacity_float))
    output_i_tokens = top_i_mask * expert_overflow

    # Update the cumulative tokens routed to each expert.
    cum_tokens += mtf.reduce_sum(output_i_tokens, reduced_dim=group_size_dim)
    tokens_left_to_route -= (
        mtf.reduce_sum(output_i_tokens, reduced_dim=experts_dim))

    # Combine-tensor for this iteration
    output_i_tokens_flat = mtf.reduce_sum(
        output_i_tokens, reduced_dim=experts_dim)
    # Presumably converts the inclusive running count into a zero-based slot
    # index within the expert - TODO confirm mtf.cumsum's default.
    position_in_expert = cumulative_tokens_in_expert - 1
    top_i_combine_tensor = (
        top_i_gate * output_i_tokens_flat *
        mtf.one_hot(top_i_index, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert), expert_capacity_dim))
    combine_tensor += top_i_combine_tensor

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Binarize: nonzero combine weights become 1.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss
Example #16
0
def unet_with_spatial_partition(mesh, mesh_impl, dataset_str, images, labels):
    """Builds the UNet model graph, loss, predictions and eval metrics.

    The network is a down-sampling path of `FLAGS.network_depth` levels
    (each `FLAGS.n_conv_per_block` convolutions, max-pooled between levels),
    followed by an up-sampling path that concatenates each level's stored
    activations, and a final 1x1 convolution onto `FLAGS.label_c` classes.
    The loss mixes a weighted cross-entropy with a dice loss on the lesion
    class (class index 2; class index 1 is treated as liver).

    Args:
      mesh: a MeshTensorflow.mesh object.
      mesh_impl: a mesh implementation, such as SimdMeshImpl and
        PlacementMeshImpl.
      dataset_str: a string of either train or eval. This is used for
        batch_norm and to pick the batch size.
      images: a laid out Tensor with shape [batch, x, y, num_channels]
        or [batch, x, y, z, num_channels].
      labels: a laid out Tensor with shape [batch, x, y, num_classes]
        or [batch, x, y, z, num_classes].

    Returns:
      A tuple (preds, loss, eval_metrics, all_bn_update_ops):
        preds: a list of Tensors - downsampled liver probability, downsampled
          lesion probability, optionally (FLAGS.output_ground_truth) the
          downsampled lesion ground truth, then the per-case `intersect` and
          `area_sum`.
        loss: the scalar training loss.
        eval_metrics: a dict mapping metric names to scalar Tensors.
        all_bn_update_ops: the update ops collected from every call to
          conv_with_spatial_partition.
    """

    is_training = (dataset_str == 'train')
    if dataset_str == 'train':
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_train)
    else:
        assert dataset_str == 'eval'
        batch_dim = mtf.Dimension('batch', FLAGS.batch_size_eval)
    # x/y are split into n{x,y}_block blocks of ct_resolution // n{x,y}_block
    # pixels each; z is kept whole.
    image_nx_dim = mtf.Dimension('image_nx_block', FLAGS.image_nx_block)
    image_ny_dim = mtf.Dimension('image_ny_block', FLAGS.image_ny_block)
    image_sx_dim = mtf.Dimension('image_sx_block',
                                 FLAGS.ct_resolution // FLAGS.image_nx_block)
    image_sy_dim = mtf.Dimension('image_sy_block',
                                 FLAGS.ct_resolution // FLAGS.image_ny_block)
    image_sz_dim = mtf.Dimension('image_sz_block', FLAGS.ct_resolution)
    image_c_dim = mtf.Dimension('image_c', FLAGS.image_c)
    label_c_dim = mtf.Dimension('label_c', FLAGS.label_c)
    mtf_images_shape, mtf_labels_shape = get_input_mtf_shapes(dataset_str)

    # A single dtype is used for all three slots of the VariableDType.
    mtf_dtype = tf.as_dtype(FLAGS.mtf_dtype)
    variable_dtype = mtf.VariableDType(mtf_dtype, mtf_dtype, mtf_dtype)

    # Import input features.
    x = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(images),
                                   mtf_images_shape)
    x = mtf.cast(x, mtf_dtype)

    # Import ground truth labels.
    t = mtf.import_laid_out_tensor(mesh, mesh_impl.LaidOutTensor(labels),
                                   mtf_labels_shape)
    t = mtf.cast(t, mtf_dtype)

    # Transpose the blocks.
    if FLAGS.sampled_2d_slices:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            label_c_dim
        ])
    else:
        x = mtf.transpose(x, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, image_c_dim
        ])

        t = mtf.transpose(t, [
            batch_dim, image_nx_dim, image_ny_dim, image_sx_dim, image_sy_dim,
            image_sz_dim, label_c_dim
        ])

    # Network.
    # levels[d] holds the activations at the end of down-sampling level d,
    # used later as the skip connection for the matching up-sampling level.
    levels = []
    all_bn_update_ops = []
    # add levels with convolution or down-sampling
    for depth in range(FLAGS.network_depth):
        for n_conv in range(FLAGS.n_conv_per_block):
            if depth == 0 and n_conv == 0:
                # no dropout in 1st layer.
                dropout_keep_p = 1.0
            else:
                dropout_keep_p = FLAGS.dropout_keep_p
            # The filter count doubles with each level.
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_down_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)
        levels.append(x)

        # Pool between levels, but not after the deepest one.
        if depth < FLAGS.network_depth - 1:
            if FLAGS.sampled_2d_slices:
                x = mtf.layers.max_pool2d(x, ksize=(2, 2))
            else:
                x = mtf.layers.max_pool3d(x, ksize=(2, 2, 2))

    # add levels with up-convolution or up-sampling
    # Walk the levels from second-deepest back up to level 0.
    for depth in range(FLAGS.network_depth - 1)[::-1]:
        x = deconv_with_spatial_partition(
            x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
            FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
            'conv_{}_{}'.format(depth, FLAGS.n_conv_per_block - 1),
            variable_dtype, 'deconv_{}_0'.format(depth))
        # Skip connection: concatenate with the stored level along the
        # channel dimension named after that level's last convolution.
        x = mtf.concat([x, levels[depth]],
                       concat_dim_name='conv_{}_{}'.format(
                           depth, FLAGS.n_conv_per_block - 1))

        for n_conv in range(FLAGS.n_conv_per_block):
            x, bn_update_ops = conv_with_spatial_partition(
                x, FLAGS.sampled_2d_slices, image_nx_dim, image_ny_dim,
                FLAGS.n_base_filters * (2**depth), FLAGS.dropout_keep_p,
                FLAGS.with_batch_norm, is_training,
                'conv_{}_{}'.format(depth, n_conv), variable_dtype,
                'conv_up_{}_{}'.format(depth, n_conv))
            all_bn_update_ops.extend(bn_update_ops)

    # no dropout in the final layer.
    # 1x1 (or 1x1x1) convolution producing the per-class logits.
    if FLAGS.sampled_2d_slices:
        y = mtf.layers.conv2d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1),
            strides=(1, 1),
            padding='SAME',
            h_blocks_dim=image_nx_dim,
            w_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )
    else:
        y = mtf.layers.conv3d_with_blocks(
            x,
            mtf.Dimension('label_c', FLAGS.label_c),
            filter_size=(1, 1, 1),
            strides=(1, 1, 1),
            padding='SAME',
            d_blocks_dim=image_nx_dim,
            h_blocks_dim=image_ny_dim,
            variable_dtype=variable_dtype,
            name='final_conv_{}'.format(FLAGS.label_c),
        )

    # use mtf.constant to make sure there is no CPU-side constants.
    def scalar(v, dtype):
        return mtf.constant(mesh, v, shape=[], dtype=dtype)

    # Per-pixel class masks from the labels: class 1 = liver, 2 = lesion.
    argmax_t = mtf.argmax(t, label_c_dim)
    liver_t = mtf.cast(mtf.equal(argmax_t, scalar(1, tf.int32)), mtf_dtype)
    lesion_t = mtf.cast(mtf.equal(argmax_t, scalar(2, tf.int32)), mtf_dtype)

    # Hard lesion prediction from the logits.
    argmax_y = mtf.argmax(y, label_c_dim)
    lesion_y = mtf.cast(mtf.equal(argmax_y, scalar(2, tf.int32)), mtf_dtype)

    # summary of class ratios.
    lesion_pred_ratio = mtf.reduce_mean(lesion_y)
    lesion_label_ratio = mtf.reduce_mean(lesion_t)

    # summary of accuracy.
    accuracy = mtf.reduce_mean(
        mtf.cast(mtf.equal(argmax_y, argmax_t), mtf_dtype))

    # Cross-entropy loss. Up-weight the liver region.
    pixel_loss = mtf.layers.softmax_cross_entropy_with_logits(
        y, t, label_c_dim)
    # NOTE(review): liver_t and lesion_t are disjoint, so a lesion pixel gets
    # weight 1 + xen_lesion_weight - xen_liver_weight rather than
    # xen_lesion_weight - confirm this is intended.
    pixel_weight = scalar(1, mtf_dtype) + \
        liver_t * scalar(FLAGS.xen_liver_weight - 1, mtf_dtype) + \
        lesion_t * scalar(FLAGS.xen_lesion_weight - FLAGS.xen_liver_weight,
                          mtf_dtype)
    loss_xen = mtf.reduce_mean(pixel_loss * pixel_weight)

    # Dice loss
    # Soft dice on the lesion class: -2 * |P * T| / (|P| + |T| + epsilon),
    # computed both per case (then averaged) and over the whole batch.
    y_prob = mtf.softmax(y, reduced_dim=label_c_dim)
    lesion_prob = mtf.reduce_sum(mtf.slice(y_prob, 2, 1, 'label_c'),
                                 reduced_dim=mtf.Dimension('label_c', 1))
    prob_intersect = mtf.reduce_sum(lesion_prob * lesion_t,
                                    output_shape=mtf.Shape([batch_dim]))
    prob_area_sum = mtf.reduce_sum(lesion_prob + lesion_t,
                                   output_shape=mtf.Shape([batch_dim]))
    loss_dice_per_case = mtf.reduce_mean(
        scalar(-2, mtf_dtype) * prob_intersect /
        (prob_area_sum + scalar(FLAGS.dice_epsilon, mtf_dtype)))
    loss_dice_global = scalar(-2, mtf_dtype) * mtf.reduce_sum(
        prob_intersect) / (mtf.reduce_sum(prob_area_sum) +
                           scalar(FLAGS.dice_epsilon, mtf_dtype))

    # Average of the per-case and global dice losses.
    loss_dice = (loss_dice_per_case + loss_dice_global) * scalar(
        0.5, mtf_dtype)

    # Final loss: convex combination of dice and cross-entropy.
    loss = scalar(FLAGS.dice_loss_weight, mtf_dtype) * loss_dice + scalar(
        1 - FLAGS.dice_loss_weight, mtf_dtype) * loss_xen

    # Hard (argmax-based) counterparts of prob_intersect / prob_area_sum.
    intersect = mtf.reduce_sum(lesion_y * lesion_t,
                               output_shape=mtf.Shape([batch_dim]))
    area_sum = mtf.reduce_sum(lesion_y + lesion_t,
                              output_shape=mtf.Shape([batch_dim]))
    # summary of dice.
    # The 1e-6 term guards against division by zero on empty masks.
    dice_per_case = mtf.reduce_mean(
        scalar(2, mtf_dtype) * intersect /
        (area_sum + scalar(0.000001, mtf_dtype)))
    dice_global = scalar(2, mtf_dtype) * mtf.reduce_sum(intersect) / (
        mtf.reduce_sum(area_sum) + scalar(0.000001, mtf_dtype))

    eval_metrics = {
        'lesion_pred_ratio': lesion_pred_ratio,
        'lesion_label_ratio': lesion_label_ratio,
        'accuracy_of_all_classes': accuracy,
        'lesion_dice_per_case': dice_per_case,
        'lesion_dice_global': dice_global,
        'loss_xen': loss_xen,
        'loss_dice': loss_dice,
        'loss_dice_per_case': loss_dice_per_case,
        'loss_dice_global': loss_dice_global,
    }

    # Average-pool the probabilities (and optionally the lesion ground truth)
    # by FLAGS.pred_downsample along every spatial axis.
    if FLAGS.sampled_2d_slices:
        y_prob_downsampled = mtf.layers.avg_pool2d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 2)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool2d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 2)
    else:
        y_prob_downsampled = mtf.layers.avg_pool3d(
            y_prob, ksize=(FLAGS.pred_downsample, ) * 3)
        if FLAGS.output_ground_truth:
            lesion_gt_downsampled = mtf.layers.avg_pool3d(
                mtf.slice(t, 2, 1, 'label_c'),
                ksize=(FLAGS.pred_downsample, ) * 3)

    # Slice out classes 1 and 2, then drop the size-1 class dimension.
    liver_prob_downsampled = mtf.slice(y_prob_downsampled, 1, 1, 'label_c')
    lesion_prob_downsampled = mtf.slice(y_prob_downsampled, 2, 1, 'label_c')
    preds = [
        mtf.reduce_sum(liver_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1)),
        mtf.reduce_sum(lesion_prob_downsampled,
                       reduced_dim=mtf.Dimension('label_c', 1))
    ]

    if FLAGS.output_ground_truth:
        preds.append(
            mtf.reduce_sum(lesion_gt_downsampled,
                           reduced_dim=mtf.Dimension('label_c', 1)))

    preds.extend([intersect, area_sum])

    return preds, loss, eval_metrics, all_bn_update_ops
Example #17
0
def sample_autoregressive(
    partial_sequences,
    other_features,
    params,
    stop_at_token=50256,
    max_steps=None,
    temperature=0.9,
    variable_dtype=mtf.VariableDType(tf.float32),
    encoder_output=None,
    encoder_sequence_id=None,
    encoder_inputs=None,
    shared_params=None,
    has_partial_sequences=True,
    encoder_layer_outputs=None,
    never_end=False,
    remove_partial_sequences=False,
    sampling_keep_top_k=-1,
    bos_id=50256,
):
    """Sample randomly one token at a time.

    The partial_sequences represent partial sequences to be continued.  The
    first tokens of each sequence are the given prompt and the remaining
    tokens are padding (params["padding_id"], 0 by default), representing what
    needs to be filled in.

    If there are no partial sequences (you want to sample from the beginning),
    then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and
    has_partial_sequences=False (so we can skip computation).

    Args:
        partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim]
        other_features: a dictionary forwarded to gpt2.model.  It must contain
            "vocab_dim", the dimension that tokens are sampled over.
        params: a dictionary forwarded to gpt2.model.  This function reads
            "padding_id" (default 0) and "slow_sampling" (default False).
        stop_at_token: an optional integer eos id.  Stop when we produce it.
        max_steps: an optional integer, the max number of steps to decode.
        temperature: an optional floating point value between 0.0 and 1.0.
            0.0 means argmax, 1.0 means sample according to predicted
            distribution.
        variable_dtype: a mtf.VariableDType
        encoder_output: an optional Tensor
        encoder_sequence_id: an optional Tensor
        encoder_inputs: an optional Tensor
        shared_params: an optional dictionary
        has_partial_sequences: a boolean
        encoder_layer_outputs: optional - readonly list of tensor activations
            when decoding, one per each input layer + the embedding layer
        never_end: a boolean - if set, then avoid generating stop_at_token
        remove_partial_sequences: a boolean - whether to remove the partial
            sequences from the output
        sampling_keep_top_k: an integer - if not -1, only sample from the top k
            logits.  The sentinel -2 means "keep the top 10% of the
            vocabulary".
        bos_id: beginning of sequence id.  Currently unused by this function.

    Returns:
        a Tensor with shape [<batch_dims>, length_dim]

    Raises:
        ValueError: if sampling_keep_top_k resolves to a value that is neither
            -1 nor positive.
    """

    inputs = partial_sequences  # Partial sequences to fill in
    batch_dims = inputs.shape.dims[:-1]
    length_dim = inputs.shape.dims[-1]
    padding_id = params.get("padding_id", 0)
    slow_sampling = params.get("slow_sampling", False)

    # Number of non-padding tokens per sequence.  Assuming padding only trails
    # the prompt, this is the position at which decoding starts.
    initial_position = mtf.reduce_sum(
        mtf.to_int32(mtf.not_equal(inputs, padding_id)),
        reduced_dim=length_dim)

    length_range = mtf.range(inputs.mesh, length_dim, tf.int32)
    input_full_attention = True  # hardcoded for now
    if input_full_attention:
        # Positions up to and including initial_position get priority 0, later
        # positions keep their index.  Presumably this lets the prompt tokens
        # attend to each other in both directions -- TODO confirm against the
        # attention layers that consume read/write priority.
        read_priority = write_priority = length_range * mtf.to_int32(
            mtf.greater(length_range, initial_position))
    else:
        # Vanilla autoregressive model - each position can see previous
        # positions.
        read_priority = write_priority = length_range

    # Builds context to pass around internally
    # The 'first part' context records initial states of k / v / x

    if not slow_sampling:
        context_first_part = mtf_transformer.transformer.Context(
            model=None,
            mesh=inputs.mesh,
            batch_dims=batch_dims,
            length_dim=length_dim,
            variable_dtype=variable_dtype,
            mode="first_part",
            position=length_range,
            position_is_default=True,
            new_states=[],
            initial_position=initial_position,
            sequence_id=None,
            encoder_output=encoder_output,
            encoder_sequence_id=encoder_sequence_id,
            constant_states=[],
            shared_params=shared_params,
            encoder_layer_outputs=encoder_layer_outputs,
            write_priority=write_priority,
            read_priority=read_priority,
            inputs=inputs,
            encoder_inputs=encoder_inputs)

        # The returned logits are intentionally discarded: only the states
        # recorded on context_first_part are used below.
        with tf.variable_scope("gpt2"):
            gpt2.model({"inputs": inputs},
                       other_features,
                       params,
                       inputs.mesh,
                       variable_dtype=variable_dtype,
                       context=context_first_part)

        if not has_partial_sequences:
            # Nothing was given, so start the loop from all-zero states.
            initial_states = [
                mtf.zeros_like(t) for t in context_first_part.new_states
            ]
        else:
            initial_states = context_first_part.new_states
    else:
        # Slow sampling carries no incremental state through the loop.
        initial_states = []

    # Stop tokens that are already part of the prompt must not end decoding,
    # so cond_fn only reacts to stop tokens beyond this count.
    partial_sequences_eos_count = 0
    if stop_at_token is not None:
        partial_sequences_eos_count = mtf.reduce_sum(
            mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)),
            reduced_dim=length_dim)

    def cond_fn(position, ids, *unused_states):
        """Should we run another loop iteration?"""
        past_end = mtf.greater_equal(position, length_dim.size)
        if max_steps:
            past_end = mtf.logical_or(
                past_end,
                mtf.greater_equal(position - initial_position, max_steps))

        is_done = past_end
        if stop_at_token is not None:
            eos_count = mtf.reduce_sum(
                mtf.to_int32(mtf.equal(ids, stop_at_token)),
                reduced_dim=length_dim)
            has_additional_eos = mtf.greater(eos_count,
                                             partial_sequences_eos_count)
            is_done = mtf.logical_or(is_done, has_additional_eos)
        # Keep looping until every sequence in the batch is done.
        all_done = mtf.reduce_all(is_done)
        return mtf.logical_not(all_done)

    def body_fn(position, ids, *states):
        """One step in the decode loop."""
        if slow_sampling:
            context = None
        else:
            context = mtf_transformer.transformer.Context(
                model=None,
                mesh=inputs.mesh,
                batch_dims=batch_dims,
                length_dim=length_dim,
                variable_dtype=variable_dtype,
                mode="incremental",
                position=position,
                position_is_default=True,
                states=states,
                new_states=[],
                initial_position=position,
                sequence_id=None,
                encoder_output=encoder_output,
                encoder_sequence_id=encoder_sequence_id,
                shared_params=shared_params,
                encoder_layer_outputs=encoder_layer_outputs,
                write_priority=write_priority,
                read_priority=read_priority,
                inputs=ids,
                encoder_inputs=encoder_inputs)

        with tf.variable_scope("gpt2", reuse=tf.AUTO_REUSE):
            logits, _, _ = gpt2.model({"inputs": ids},
                                      other_features,
                                      params,
                                      inputs.mesh,
                                      variable_dtype=variable_dtype,
                                      context=context)

        if never_end and stop_at_token is not None:
            # Push the stop token's logit far down so it is never sampled.
            logits += mtf.one_hot(
                mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32),
                other_features["vocab_dim"],
                on_value=-1e9,
                off_value=0.0,
                dtype=logits.dtype)

        # Resolve the top-k setting locally so the enclosing argument is never
        # mutated.  -2 is a sentinel meaning "keep the top 10% of the vocab".
        keep_top_k = sampling_keep_top_k
        if keep_top_k == -2:
            keep_top_k = int(logits.shape[-1].size * 0.1)

        if keep_top_k != -1:
            if keep_top_k <= 0:
                raise ValueError(
                    "sampling_keep_top_k must either be -1 or positive.")
            k_largest = mtf.nth_largest_element(
                logits,
                n=keep_top_k,
                reduced_dim=other_features["vocab_dim"])
            # Everything at or below the threshold is made very unlikely.
            logits = mtf.where(mtf.less_equal(logits, k_largest),
                               mtf.ones_like(logits) * -1e6, logits)

        ids_this_step = mtf.sample_with_temperature(
            logits, other_features["vocab_dim"], temperature)

        if slow_sampling:
            # Shift by one along length so the sample drawn at index i lands
            # on index i + 1, the slot selected by the one-hot below.
            ids_this_step = mtf.shift(ids_this_step,
                                      offset=1,
                                      dim=length_dim,
                                      wrap=False)
        else:
            ids_this_step = mtf.reshape(ids_this_step, (batch_dims))

        # Write the new token at `position` and leave every other slot as is.
        one_hot = mtf.one_hot(position, length_dim, dtype=tf.int32)
        one_new_id = ids_this_step * one_hot
        new_ids = (1 - one_hot) * ids + one_new_id
        new_position = position + 1

        ret = [new_position, new_ids]
        if context is not None:
            ret += context.new_states
        return ret

    while_loop_inputs = [initial_position, inputs] + initial_states
    # Only the ids are needed; the final position and states are dropped.
    outputs = mtf.while_loop(cond_fn, body_fn, while_loop_inputs)[1]
    if has_partial_sequences and remove_partial_sequences:
        # Remove partial sequences from outputs.  The prompt length is the
        # non-padding count already held in initial_position.
        outputs = mtf.dynamic_shift(outputs,
                                    -initial_position,
                                    length_dim,
                                    wrap=False)
    return outputs
Example #18
0
def _rand_1_gating(
    inputs, outer_expert_dims, experts_dim, expert_capacity_dim,
    hparams, train, variable_dtype, importance=None, name="rand_1_gating",
    num_microbatches=None):
  """Compute a random top-1 gating.

  Each token is routed to exactly one expert: either the arg-max of the gate
  softmax or a sample drawn from the gate logits, depending on the policy
  configured in hparams.

  Args:
    inputs: a Tensor.  Its second-to-last dimension is used as the group-size
      dimension.
    outer_expert_dims: passed to the gating dense layer as expert_dims.
    experts_dim: the Dimension enumerating the experts.
    expert_capacity_dim: a Dimension whose size caps the tokens kept per
      expert.
    hparams: must provide moe_rand_1_policy_train / moe_rand_1_policy_eval
      and, depending on the policy, moe_rand_1_dropout, moe_rand_1_jitter or
      moe_rand_1_temperature.
    train: a boolean.
    variable_dtype: passed to the gating dense layer.
    importance: an optional Tensor.  Tokens whose importance is not 1.0 are
      masked out of the routing.
    name: name of the gating dense layer and prefix of the entropy summary.
    num_microbatches: an optional integer.  When greater than 1 the loss is
      divided by it.

  Returns:
    dispatch_tensor: combine_tensor with every nonzero entry replaced by 1.
    combine_tensor: the gate values laid out over experts_dim and
      expert_capacity_dim, cast to inputs.dtype.
    loss: the load-balancing loss, cast to inputs.dtype.

  Raises:
    ValueError: if the selected policy is not recognised.
  """
  # SELECT EXPERT
  policy = (hparams.moe_rand_1_policy_train if train
            else hparams.moe_rand_1_policy_eval)

  # All gating math is carried out in float32; bfloat16 seems to reduce
  # quality.
  router_inputs = mtf.to_float(inputs)

  # Input perturbations are a training-only feature.
  if train:
    if policy == "input_dropout":
      router_inputs = mtf.dropout(
          router_inputs, 1.0 - hparams.moe_rand_1_dropout)
    elif policy == "input_jitter":
      router_inputs = mtf.layers.multiplicative_jitter(
          router_inputs, hparams.moe_rand_1_jitter)

  router_logits = mtf.layers.dense(
      router_inputs,
      experts_dim,
      use_bias=False,
      expert_dims=outer_expert_dims,
      variable_dtype=variable_dtype,
      name=name)
  router_probs = mtf.softmax(router_logits, reduced_dim=experts_dim)
  gate_dtype = router_probs.dtype

  if policy in ("argmax", "input_dropout", "input_jitter"):
    chosen_gate, chosen_index = mtf.top_1(
        router_probs, reduced_dim=experts_dim)
  elif policy == "sample":
    chosen_index = mtf.sample_with_temperature(
        router_logits, experts_dim,
        temperature=hparams.moe_rand_1_temperature)
    chosen_gate = mtf.gather(router_probs, chosen_index, dim=experts_dim)
  else:
    raise ValueError("Unknown rand_1 policy %s" % policy)

  assignment = mtf.one_hot(chosen_index, experts_dim, dtype=gate_dtype)

  # LOAD BALANCING LOSS
  # TODO(liamfedus): Check entropy loss.
  tokens_dim = inputs.shape[-2]
  routed_fraction = mtf.reduce_mean(assignment, reduced_dim=tokens_dim)
  mean_prob = mtf.reduce_mean(router_probs, reduced_dim=tokens_dim)
  if importance is not None:

    def importance_mask():
      """Build a fresh 1/0 mask of the tokens whose importance is 1.0."""
      return mtf.cast(mtf.equal(importance, 1.0), dtype=gate_dtype)

    assignment = assignment * importance_mask()
    chosen_gate = chosen_gate * importance_mask()
    mean_prob = mean_prob * importance_mask()
  loss = (
      mtf.reduce_mean(mean_prob * routed_fraction) *
      float(experts_dim.size * experts_dim.size))
  if num_microbatches and num_microbatches > 1:
    tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
        num_microbatches))
    loss /= num_microbatches

  # Logging
  if train:
    per_token_entropy = mtf.reduce_sum(
        -router_probs * mtf.log(router_probs + 1e-9),
        reduced_dim=experts_dim)
    mtf.scalar_summary(name + "/entropy", mtf.reduce_mean(per_token_entropy))

    tokens_per_expert = mtf.reduce_sum(assignment, output_shape=[experts_dim])
    tokens_total = mtf.reduce_sum(tokens_per_expert)
    share_per_expert = mtf.to_float(tokens_per_expert / tokens_total)
    # One scalar summary per expert, named after the split tensor.
    for share in mtf.split(
        share_per_expert,
        split_dim=experts_dim,
        num_or_size_splits=experts_dim.size):
      mtf.scalar_summary("experts/" + share.name.replace(":", "/"),
                         mtf.reduce_mean(share))
    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))

  # COMPUTE ASSIGNMENT TO EXPERT
  # Each expert has limited capacity.  The exclusive cumulative sum gives each
  # token its slot within the expert it was routed to.
  slot_in_expert = mtf.cumsum(
      assignment, tokens_dim, exclusive=True) * assignment
  slot_in_expert = mtf.cast(slot_in_expert, dtype=gate_dtype)
  # Drop every token whose slot falls beyond the expert capacity.
  capacity = float(expert_capacity_dim.size)
  within_capacity = mtf.cast(
      mtf.less(slot_in_expert, capacity), dtype=gate_dtype)
  assignment = assignment * within_capacity
  token_kept = mtf.reduce_sum(assignment, reduced_dim=experts_dim)

  # Zero the gate of every token that overflowed its expert.
  chosen_gate = chosen_gate * token_kept

  combine_tensor = chosen_gate * token_kept
  combine_tensor = combine_tensor * mtf.one_hot(
      chosen_index, experts_dim, dtype=gate_dtype)
  combine_tensor = combine_tensor * mtf.one_hot(
      mtf.to_int32(slot_in_expert), expert_capacity_dim, dtype=gate_dtype)

  # Match the inputs dtype.
  combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
  loss = mtf.cast(loss, inputs.dtype)
  # Round-tripping through bool turns every nonzero weight into exactly 1.
  dispatch_tensor = mtf.cast(
      mtf.cast(combine_tensor, tf.bool), combine_tensor.dtype)

  return dispatch_tensor, combine_tensor, loss