Beispiel #1
0
def sample_categorical(x, dim=None):
    """Draw one index along `dim` by comparing a running sum to a uniform draw.

    A cumulative sum of `x` is taken along `dim`, one uniform value in [0, 1)
    is drawn for every position of the remaining dimensions, and the argmax
    over `dim` of the indicator "cumulative sum > draw" is returned.

    NOTE(review): this presumably expects `x` to hold probabilities that sum
    to 1 along `dim` - confirm against callers.

    Args:
      x: a mtf.Tensor
      dim: the mtf.Dimension to sample over; defaults to the last dimension
        of `x`.

    Returns:
      the result of mtf.argmax over `dim` of the int32 indicator tensor.
    """
    if dim is None:
        dim = x.shape[-1]

    cumulative = mtf.cumsum(x, dim)
    # One independent threshold per position of every dimension except `dim`.
    thresholds = mtf.random_uniform(
        x.mesh, x.shape - dim, minval=0, maxval=1)
    exceeded = mtf.greater(cumulative, thresholds)
    return mtf.argmax(mtf.cast(exceeded, tf.int32), dim)
    def _noisy_targets(self, targets, losses=None):
        """Produce the noised targets consumed by denoising models.

        Outside TRAIN mode the eval noising spec is applied.  In TRAIN mode
        the train noising spec is applied and, when
        hparams.noising_use_eval_during_train is positive, the eval-noised
        version is substituted wherever a uniform draw (one per position of
        every dimension except self.length_dim) falls below that value.

        Args:
          targets: a Tensor
          losses: an optional list forwarded to the train-spec noiser, onto
            which it may append training losses
        Returns:
          the Tensor produced by self._noisy_targets_from_spec (or a mix of
          two such Tensors, see above)
        """
        hparams = self._hparams
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            return self._noisy_targets_from_spec(
                targets, hparams.noising_spec_eval)

        noised = self._noisy_targets_from_spec(
            targets, hparams.noising_spec_train, losses=losses)
        eval_fraction = hparams.noising_use_eval_during_train
        if eval_fraction > 0:
            eval_noised = self._noisy_targets_from_spec(
                targets, hparams.noising_spec_eval)
            draws = mtf.random_uniform(
                targets.mesh, targets.shape - self.length_dim)
            pick_eval = mtf.less(draws, eval_fraction)
            noised = mtf.where(pick_eval, eval_noised, noised)
        return noised
def benchmark_model(mesh):
    """Round-trip a random 3D volume through a forward and inverse FFT.

    A [batch, nx, ny, nz] field of uniform noise is cast to complex64,
    transformed with mpm.fft3d into the (tnx, tny, tnz) dimensions,
    transformed back with mpm.ifft3d and cast to float32.

    Args:
      mesh: the mtf mesh on which the tensors are created; sizes come from
        FLAGS.batch_size and FLAGS.cube_size.

    Returns:
      the maximum absolute difference between the original and the
      reconstructed field.
    """
    batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
    cube = FLAGS.cube_size
    spatial_dims = [mtf.Dimension(name, cube) for name in ("nx", "ny", "nz")]
    transformed_dims = [
        mtf.Dimension(name, cube) for name in ("tnx", "tny", "tnz")
    ]

    # Random input volume.
    field = mtf.random_uniform(mesh, [batch_dim] + spatial_dims)

    # Forward transform, then straight back to real space.
    spectrum = mpm.fft3d(mtf.cast(field, tf.complex64), transformed_dims)
    recovered = mtf.cast(mpm.ifft3d(spectrum, spatial_dims), tf.float32)

    # Worst-case reconstruction error over the whole volume.
    return mtf.reduce_max(mtf.abs(field - recovered))
 def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
     """Apply the noising scheme described by `noising_spec` to `targets`.

     The scheme is selected by noising_spec["type"]:
       "mask": multiply targets by a 0/1 mask that keeps a token where a
         uniform draw exceeds noising_spec["prob"].
       "random_zipfian": where a uniform draw is below noising_spec["prob"],
         replace the token with an id sampled with probability proportional
         to 1/(id+10).
       "transformer": return samples from a separately scoped ("noiser")
         MtfTransformer configured with noising_spec["overrides"]; only
         available in TRAIN mode.

     Args:
       targets: a Tensor
       noising_spec: a dict with a "type" key plus the keys listed above
       losses: an optional list; the noiser's training loss is appended to
         it for the "transformer" type.  Ignored when None.
     Returns:
       a Tensor of noised targets
     Raises:
       NotImplementedError: "transformer" noising requested outside TRAIN
       ValueError: unrecognized noising_spec["type"]
     """
     if noising_spec["type"] == "mask":
         # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
         return targets * mtf.cast(
             mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                         noising_spec["prob"]), targets.dtype)
     elif noising_spec["type"] == "random_zipfian":
         # Replace a randomly-chosen noising_spec["prob"] of input tokens.
         # Rather than drawing the replacement tokens uniformly, we sample from
         #   a distribution favoring lower token-ids, assuming that the ids have
         #   been assigned in frequency order.  The probability of choosing an
         #   id is proportional to 1/(id+10)
         logits = mtf.log(1.0 / (mtf.range(
             targets.mesh, self.targets_vocab_dim, dtype=tf.float32) +
                                 10.0))
         logits = mtf.broadcast(logits,
                                new_shape=targets.shape + logits.shape)
         r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
         use_noise = mtf.less(
             mtf.random_uniform(targets.mesh, targets.shape),
             noising_spec["prob"])
         return mtf.where(use_noise, r, targets)
     elif noising_spec["type"] == "transformer":
         # Train a small transformer to fill in masked out values, then
         # sample from it.
         hparams = self._hparams
         if hparams.mode != tf.estimator.ModeKeys.TRAIN:
             raise NotImplementedError("Not implemented")
         noiser_hparams = copy.copy(self._hparams)
         noiser_hparams.del_hparam("mode")
         noiser_hparams.override_from_dict(noising_spec["overrides"])
         with tf.variable_scope("noiser"):
             noiser = MtfTransformer(noiser_hparams,
                                     mode=hparams.mode,
                                     problem_hparams=self._problem_hparams)
             logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
                 self._original_features, targets.mesh)
             samples = mtf.sample_with_temperature(logits,
                                                   self.targets_vocab_dim)
         # `losses` is optional (defaults to None); only record the noiser
         # loss when the caller supplied a list to collect it.
         if losses is not None:
             losses.append(loss)
         return samples
     else:
         raise ValueError("unknown noising spec %s" % noising_spec)
 def _noisy_targets_from_spec(self, targets, noising_spec, losses=None):
   """Apply the noising scheme described by `noising_spec` to `targets`.

   The scheme is selected by noising_spec["type"]:
     "mask": multiply targets by a 0/1 mask that keeps a token where a
       uniform draw exceeds noising_spec["prob"].
     "random_zipfian": where a uniform draw is below noising_spec["prob"],
       replace the token with an id sampled with probability proportional
       to 1/(id+10).
     "transformer": return samples from a separately scoped ("noiser")
       MtfTransformer configured with noising_spec["overrides"]; only
       available in TRAIN mode.

   Args:
     targets: a Tensor
     noising_spec: a dict with a "type" key plus the keys listed above
     losses: an optional list; the noiser's training loss is appended to
       it for the "transformer" type.  Ignored when None.
   Returns:
     a Tensor of noised targets
   Raises:
     NotImplementedError: "transformer" noising requested outside TRAIN
     ValueError: unrecognized noising_spec["type"]
   """
   if noising_spec["type"] == "mask":
     # Replace a randomly-chosen noising_spec["prob"] of input tokens with 0.
     return targets * mtf.cast(
         mtf.greater(mtf.random_uniform(targets.mesh, targets.shape),
                     noising_spec["prob"]), targets.dtype)
   elif noising_spec["type"] == "random_zipfian":
     # Replace a randomly-chosen noising_spec["prob"] of input tokens.
     # Rather than drawing the replacement tokens uniformly, we sample from
     #   a distribution favoring lower token-ids, assuming that the ids have
     #   been assigned in frequency order.  The probability of choosing an
     #   id is proportional to 1/(id+10)
     logits = mtf.log(1.0 / (mtf.range(
         targets.mesh, self.targets_vocab_dim, dtype=tf.float32) + 10.0))
     logits = mtf.broadcast(logits, new_shape=targets.shape + logits.shape)
     r = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
     use_noise = mtf.less(
         mtf.random_uniform(targets.mesh, targets.shape), noising_spec["prob"])
     return mtf.where(use_noise, r, targets)
   elif noising_spec["type"] == "transformer":
     # Train a small transformer to fill in masked out values, then
     # sample from it.
     hparams = self._hparams
     if hparams.mode != tf.estimator.ModeKeys.TRAIN:
       raise NotImplementedError("Not implemented")
     noiser_hparams = copy.copy(self._hparams)
     noiser_hparams.del_hparam("mode")
     noiser_hparams.override_from_dict(noising_spec["overrides"])
     with tf.variable_scope("noiser"):
       noiser = MtfTransformer(
           noiser_hparams,
           mode=hparams.mode,
           problem_hparams=self._problem_hparams)
       logits, loss = noiser._mtf_model_fn(  # pylint: disable=protected-access
           self._original_features, targets.mesh)
       samples = mtf.sample_with_temperature(logits, self.targets_vocab_dim)
     # `losses` is optional (defaults to None); only record the noiser loss
     # when the caller supplied a list to collect it.
     if losses is not None:
       losses.append(loss)
     return samples
   else:
     raise ValueError("unknown noising spec %s" % noising_spec)
    def add_position_timing_signal_func(self, context, x, step):
        """Add n-dimensional embedding as the position (horizontal) timing signal.

        The start index of the signal is chosen by self.position_start_index:
          falsy value: start at 0.
          "random": a random int32 scalar below the size of x's second
            dimension.
          "step": size of x's second dimension * step / num_steps, where
            num_steps is self.act_max_steps when self.recurrence_type is
            "act" and self.num_rec_steps otherwise.

        The signal is then combined with x according to
        self.add_or_concat_timing_signal ("add" or "concat").

        Args:
          context: mtf context
          x: a tensor with shape [batch, length, depth]
          step: step

        Returns:
          In "add" mode, a Tensor with the same shape as x.  In "concat"
          mode, x concatenated with the broadcast signal along the signal's
          last dimension.

        Raises:
          ValueError: if self.position_start_index or
            self.add_or_concat_timing_signal holds an unsupported value
            (previously this surfaced as an UnboundLocalError).
        """
        start_mode = self.position_start_index
        if not start_mode:
            index = 0

        elif start_mode == "random":
            # Shift all positions randomly
            # TODO(dehghani): What would be reasonable for max number of shift?
            index = mtf.random_uniform(context.mesh, [],
                                       maxval=x.shape.dims[1].size,
                                       dtype=tf.int32)

        elif start_mode == "step":
            # Shift positions based on the step
            if self.recurrence_type == "act":
                num_steps = self.act_max_steps
            else:
                num_steps = self.num_rec_steps
            index = mtf.cast(x.shape.dims[1].size * step / num_steps,
                             dtype=tf.int32)

        else:
            # Fail loudly instead of leaving `index` unbound.
            raise ValueError(
                "unknown position_start_index %r" % (start_mode,))

        length = context.length_dim
        channels = context.model.model_dim
        signal = self.get_timing_signal_1d(context,
                                           length,
                                           channels,
                                           start_index=index)

        combine_mode = self.add_or_concat_timing_signal
        if combine_mode == "add":
            x_with_timing = x + mtf.cast(signal, x.dtype)
        elif combine_mode == "concat":
            # NOTE(review): the original source marks this branch
            # "Unimplemented" - confirm it is usable before relying on it.
            batch_dim = x.shape.dims[0]
            out_shape = mtf.Shape([batch_dim] + signal.shape.dims[1:])
            signal_tiled = mtf.broadcast(signal, out_shape)
            x_with_timing = mtf.concat(
                (x, signal_tiled),
                concat_dim_name=signal_tiled.dimension_names[-1])
        else:
            # Fail loudly instead of leaving `x_with_timing` unbound.
            raise ValueError(
                "unknown add_or_concat_timing_signal %r" % (combine_mode,))

        return x_with_timing
Beispiel #7
0
def model_fn(nc=64, batch_size=1):
    """Build an example Mesh TensorFlow graph with one blocked 3D convolution.

    Random input data is laid out as a grid of blocks (4 x 2 x 1 blocks, each
    of size nc//4 x nc//2 x nc//1) with 3 channels, passed through
    mtf.layers.conv3d_with_blocks with 128 output channels, and summed down
    to shape [batch, h].

    Args:
      nc: integer used to derive the per-block sizes (floor-divided by the
        number of blocks along each axis).
      batch_size: size of the batch dimension.

    Returns:
      the mtf tensor holding the reduced convolution output.
    """

    # Create the mesh TF graph
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")

    # Define the named dimensions
    n_block_x = 4
    n_block_y = 2
    n_block_z = 1

    # A stray backtick after `batch_size` here was a syntax error; removed.
    batch_dim = mtf.Dimension("batch", batch_size)

    nx_dim = mtf.Dimension('nx_block', n_block_x)
    ny_dim = mtf.Dimension('ny_block', n_block_y)
    nz_dim = mtf.Dimension('nz_block', n_block_z)

    sx_dim = mtf.Dimension('sx_block', nc//n_block_x)
    sy_dim = mtf.Dimension('sy_block', nc//n_block_y)
    sz_dim = mtf.Dimension('sz_block', nc//n_block_z)

    image_c_dim = mtf.Dimension('image_c', 3)
    hidden_dim = mtf.Dimension('h', 128)

    # Create some input data
    data = mtf.random_uniform(mesh, [batch_dim, nx_dim, ny_dim, nz_dim,
                                     sx_dim, sy_dim, sz_dim,
                                     image_c_dim])

    net = mtf.layers.conv3d_with_blocks(
        data, hidden_dim,
        filter_size=(3, 3, 3), strides=(1, 1, 1), padding='SAME',
        d_blocks_dim=nx_dim, h_blocks_dim=ny_dim)

    # Collapse every dimension except batch and hidden.
    net = mtf.reduce_sum(net, output_shape=[batch_dim, hidden_dim])

    return net
  def _noisy_targets(self, targets, losses=None):
    """Produce the noised targets consumed by denoising models.

    Outside TRAIN mode the eval noising spec is applied.  In TRAIN mode the
    train noising spec is applied and, when
    hparams.noising_use_eval_during_train is positive, the eval-noised
    version is substituted wherever a uniform draw (one per position of
    every dimension except self.length_dim) falls below that value.

    Args:
      targets: a Tensor
      losses: an optional list forwarded to the train-spec noiser, onto
        which it may append training losses
    Returns:
      the Tensor produced by self._noisy_targets_from_spec (or a mix of two
      such Tensors, see above)
    """
    hparams = self._hparams
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      return self._noisy_targets_from_spec(targets, hparams.noising_spec_eval)

    noised = self._noisy_targets_from_spec(
        targets, hparams.noising_spec_train, losses=losses)
    eval_fraction = hparams.noising_use_eval_during_train
    if eval_fraction > 0:
      eval_noised = self._noisy_targets_from_spec(
          targets, hparams.noising_spec_eval)
      draws = mtf.random_uniform(
          targets.mesh, targets.shape - self.length_dim)
      pick_eval = mtf.less(draws, eval_fraction)
      noised = mtf.where(pick_eval, eval_noised, noised)
    return noised
Beispiel #9
0
def _top_2_gating(inputs,
                  outer_expert_dims,
                  experts_dim,
                  expert_capacity_dim,
                  hparams,
                  train,
                  variable_dtype,
                  importance=None,
                  name="top_2_gating"):
    """Compute gating for mixture-of-experts in TensorFlow.

    Note: until the algorithm and interface solidify, we pass in a
    hyperparameters dictionary in order not to complicate the interface in
    mtf_transformer.py .  Once this code moves out of "research", we should
    pass the hyperparameters separately.

    Hyperparameters used:
      hparams.moe_use_second_place_loss: a boolean
      hparams.moe_second_policy_train: a string
      hparams.moe_second_policy_eval: a string
      hparams.moe_second_threshold_train: a float
      hparams.moe_second_threshold_eval: a float

    The returned dispatch_tensor is used to map (via einsum) from the inputs
    to the expert_inputs.  Likewise, the returned combine_tensor is used to
    map (via einsum) from the expert outputs to the outputs.  Both tensors
    are mostly zeros.  The shapes of the tensors are as follows.

    inputs: [<batch_dims>, group_size_dim, input_dim]
    importance: [<batch_dims>, group_size_dim]
    dispatch_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    expert_inputs:
      [<batch_dims>, experts_dim, expert_capacity_dim, input_dim]

    expert_outputs:
      [<batch_dims>, experts_dim, expert_capacity_dim, output_dim]
    combine_tensor:
      [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
    outputs: [<batch_dims>, group_size_dim, output_dim]

    "importance" is an optional tensor with one floating-point value for each
    input vector.  If the importance of an input is 1.0, then we send it to
    up to 2 experts.  If 0.0 < importance < 1.0, then we send it to at most
    one expert.  If importance == 0.0, then we send it to no experts.

    We use "importance" at the second-level gating function of a hierarchical
    mixture of experts.  Inputs to the first-choice expert-group get
    importance 1.0.  Inputs to the second-choice expert group get importance
    0.5.  Inputs that represent padding get importance 0.0.

    Args:
      inputs: a mtf.Tensor with shape [<batch_dims>, group_size_dim, input_dim]
      outer_expert_dims: an optional list of dimensions.  This is for the case
        where we are at an inner level of a hierarchical MoE.
      experts_dim: a Dimension (the number of experts)
      expert_capacity_dim: a Dimension (number of examples per group per
        expert)
      hparams: model hyperparameters.
      train: a boolean; selects the *_train or *_eval second-place policy and
        threshold.
      variable_dtype: a mtf.VariableDType
      importance: an optional tensor with shape [<batch_dims>, group_size_dim]
      name: an optional string

    Returns:
      dispatch_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      combine_tensor: a Tensor with shape
        [<batch_dims>, group_size_dim, experts_dim, expert_capacity_dim]
      loss: a mtf scalar

    Raises:
      ValueError: if the selected second-place policy is not one of "all",
        "none", "threshold" or "random".
    """
    group_size_dim, unused_input_dim = inputs.shape.dims[-2:]

    # Gate logits: one value per expert, turned into a distribution over
    # experts_dim by the softmax below.
    raw_gates = mtf.layers.dense(inputs,
                                 experts_dim,
                                 use_bias=False,
                                 expert_dims=outer_expert_dims,
                                 variable_dtype=variable_dtype,
                                 name=name)
    raw_gates = mtf.softmax(raw_gates, experts_dim)

    # The internals of this function run in float32.
    #   bfloat16 seems to reduce quality.
    raw_gates = mtf.to_float(raw_gates)

    expert_capacity_f = float(expert_capacity_dim.size)

    # FIND TOP 2 EXPERTS PER POSITION
    # Find the top expert for each position. shape=[batch, group]
    index_1, gate_1 = mtf.top_1(raw_gates, experts_dim)
    # [batch, group, experts]
    mask_1 = mtf.one_hot(index_1, experts_dim, dtype=raw_gates.dtype)
    density_1_proxy = raw_gates
    if importance is not None:
        # Only inputs with importance exactly 1.0 keep a first-choice expert.
        mask_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        gate_1 *= mtf.to_float(mtf.equal(importance, 1.0))
        density_1_proxy *= mtf.to_float(mtf.equal(importance, 1.0))
    # Zero out the first-choice expert so the next top_1 picks the runner-up.
    # Where mask_1 was zeroed by importance, nothing is removed here, so
    # index_2 is then the overall top expert.
    gates_without_top_1 = raw_gates * (1.0 - mask_1)
    # [batch, group]
    index_2, gate_2 = mtf.top_1(gates_without_top_1, experts_dim)
    # [batch, group, experts]
    mask_2 = mtf.one_hot(index_2, experts_dim, dtype=raw_gates.dtype)
    if importance is not None:
        # Inputs with importance 0.0 get no second-choice expert either.
        mask_2 *= mtf.to_float(mtf.greater(importance, 0.0))

    # Renormalize the two gate values so they sum to (almost) one; the 1e-9
    # keeps the division defined when both gates are zero.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

    # BALANCING LOSSES
    # shape = [batch, experts]
    # We want to equalize the fraction of the batch assigned to each expert
    density_1 = mtf.reduce_mean(mask_1, reduced_dim=group_size_dim)
    # Something continuous that is correlated with what we want to equalize.
    density_1_proxy = mtf.reduce_mean(density_1_proxy,
                                      reduced_dim=group_size_dim)
    # Mean of the product, scaled by the square of the number of experts.
    loss = (mtf.reduce_mean(density_1_proxy * density_1) *
            float(experts_dim.size * experts_dim.size))

    if hparams.moe_use_second_place_loss:
        # Also add a loss to encourage all experts to be used equally also as the
        # second-place expert.  Experimentally, this seems to be a wash.
        # We want to equalize the fraction of the batch assigned to each expert:
        density_2 = mtf.reduce_mean(mask_2, reduced_dim=group_size_dim)
        # As a proxy for density_2, we renormalize the raw gates after the top one
        # has been removed.
        normalized = gates_without_top_1 / (mtf.reduce_sum(
            gates_without_top_1, reduced_dim=experts_dim) + 1e-9)
        density_2_proxy = mtf.reduce_mean(normalized,
                                          reduced_dim=group_size_dim)
        loss_2 = (mtf.reduce_mean(density_2_proxy * density_2) *
                  float(experts_dim.size * experts_dim.size))
        # The second-place loss is added with half the weight of the first.
        loss += loss_2 * 0.5

    # Depending on the policy in the hparams, we may drop out some of the
    # second-place experts.
    if train:
        policy = hparams.moe_second_policy_train
        threshold = hparams.moe_second_threshold_train
    else:
        policy = hparams.moe_second_policy_eval
        threshold = hparams.moe_second_threshold_eval
    if policy == "all":
        # Use second-place experts for all examples.
        pass
    elif policy == "none":
        # Never use second-place experts.
        mask_2 = mtf.zeros_like(mask_2)
    elif policy == "threshold":
        # Use second-place experts if gate_2 > threshold.
        mask_2 *= mtf.to_float(mtf.greater(gate_2, threshold))
    elif policy == "random":
        # Use second-place experts with probability min(1.0, gate_2 / threshold).
        # max(threshold, 1e-9) avoids dividing by a zero threshold.
        mask_2 *= mtf.to_float(
            mtf.less(mtf.random_uniform(gate_2.mesh, gate_2.shape),
                     gate_2 / max(threshold, 1e-9)))
    else:
        raise ValueError("Unknown policy %s" % policy)

    # COMPUTE ASSIGNMENT TO EXPERTS
    # [batch, group, experts]
    # This is the position within the expert's mini-batch for this sequence
    position_in_expert_1 = mtf.cumsum(mask_1, group_size_dim,
                                      exclusive=True) * mask_1
    # Remove the elements that don't fit. [batch, group, experts]
    mask_1 *= mtf.to_float(mtf.less(position_in_expert_1, expert_capacity_f))
    # [batch, experts]
    # How many examples in this sequence go to this expert
    mask_1_count = mtf.reduce_sum(mask_1, reduced_dim=group_size_dim)
    # [batch, group] - mostly ones, but zeros where something didn't fit
    mask_1_flat = mtf.reduce_sum(mask_1, reduced_dim=experts_dim)
    # [batch, group]
    position_in_expert_1 = mtf.reduce_sum(position_in_expert_1,
                                          reduced_dim=experts_dim)
    # Weight assigned to first expert.  [batch, group]
    gate_1 *= mask_1_flat

    # [batch, group, experts]
    # Second-choice positions are offset by mask_1_count, i.e. they come
    # after the first-choice entries already assigned to the same expert.
    position_in_expert_2 = (
        mtf.cumsum(mask_2, group_size_dim, exclusive=True) + mask_1_count)
    position_in_expert_2 *= mask_2
    # Drop second-choice assignments that exceed the expert capacity.
    mask_2 *= mtf.to_float(mtf.less(position_in_expert_2, expert_capacity_f))
    # [batch, group] - zero where the second choice was dropped
    mask_2_flat = mtf.reduce_sum(mask_2, reduced_dim=experts_dim)
    # Weight assigned to second expert.  [batch, group]
    gate_2 *= mask_2_flat
    position_in_expert_2 = mtf.reduce_sum(position_in_expert_2,
                                          reduced_dim=experts_dim)

    # [batch, group, experts, expert_capacity]
    # Each term places the gate weight at (chosen expert, slot in that
    # expert); the flat masks zero out assignments that did not fit.
    combine_tensor = (
        gate_1 * mask_1_flat * mtf.one_hot(index_1, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_1), expert_capacity_dim) +
        gate_2 * mask_2_flat * mtf.one_hot(index_2, experts_dim) *
        mtf.one_hot(mtf.to_int32(position_in_expert_2), expert_capacity_dim))

    combine_tensor = mtf.cast(combine_tensor, inputs.dtype)
    loss = mtf.cast(loss, inputs.dtype)

    # Casting through bool turns every nonzero combine weight into 1, giving
    # the 0/1 dispatch tensor in the same dtype as combine_tensor.
    dispatch_tensor = mtf.cast(mtf.cast(combine_tensor, tf.bool),
                               combine_tensor.dtype)

    return dispatch_tensor, combine_tensor, loss
        def ut_function(state, step, halting_probability, remainders,
                        n_updates, previous_state):
            """Run one ACT (adaptive computation time) step with halting.

            Besides its arguments, this nested function reads `self`,
            `context`, `threshold` and `mask` from the enclosing scope.

            Args:
              state: 3-D Tensor: [batch_size, length, channel]
              step: indicates number of steps taken so far
              halting_probability: halting probability
              remainders: act remainders
              n_updates: act n_updates
              previous_state: previous state

            Returns:
              transformed_state: transformed state
              step: step+1
              halting_probability: halting probability
              remainders: act remainders
              n_updates: act n_updates
              new_state: new state
            """
            state = self.step_preprocess(context, state, step)

            if self.act_type == "random":
                # random as halting probability
                p = mtf.random_uniform(context.mesh,
                                       shape=halting_probability.shape.dims,
                                       dtype=context.variable_dtype)
            else:
                # Learned halting probability: a sigmoid dense layer that
                # maps the last dimension of `state` to a size-1 dimension
                # of the same name.
                last_dim_name = state.shape.dimension_names[-1]
                new_dims = [mtf.Dimension(last_dim_name, 1)]
                with tf.variable_scope("sigmoid_activation_for_pondering",
                                       reuse=tf.AUTO_REUSE):
                    p = mtf.layers.dense(state,
                                         variable_dtype=context.variable_dtype,
                                         reduced_dims=[state.shape.dims[-1]],
                                         new_dims=new_dims,
                                         activation=mtf.sigmoid,
                                         use_bias=True)
                    if self.act_type == "global":
                        # average over all positions (as a global halting prob)
                        p = mtf.reduce_mean(p, reduced_dim=p.shape.dims[1])
                        p = mtf.squeeze(p)
                    else:
                        # maintain position-wise probabilities
                        # (drop the trailing size-1 dimension)
                        new_shape = p.shape.dims[:-1]
                        p = mtf.reshape(p, new_shape)
            # Mask for inputs which have not halted yet
            still_running = mtf.cast(mtf.less(halting_probability, 1.0),
                                     context.activation_dtype)

            # Mask of inputs which halted at this step
            new_halted = mtf.cast(
                mtf.greater(halting_probability + p * still_running,
                            threshold),
                context.activation_dtype) * still_running
            # Mask of inputs which haven't halted, and didn't halt this step
            # (the right-hand side still uses the previous still_running)
            still_running = mtf.cast(
                mtf.less_equal(halting_probability + p * still_running,
                               threshold),
                context.activation_dtype) * still_running

            # Add the halting probability for this step to the halting
            # probabilities for those input which haven't halted yet
            halting_probability += p * still_running

            # Compute remainders for the inputs which halted at this step
            remainders += new_halted * (1 - halting_probability)

            # Add the remainders to those inputs which halted at this step
            halting_probability += new_halted * remainders

            # Increment n_updates for all inputs which are still running
            # (including those that halted at this step)
            n_updates += still_running + new_halted

            # Compute the weight to be applied to the new state and output
            # 0 when the input has already halted
            # p when the input hasn't halted yet
            # the remainders when it halted this step
            input_tensor = p * still_running + new_halted * remainders
            update_weights = input_tensor

            # apply transformation on the state
            transformed_state = state

            for _ in range(self.num_inrecurrence_layers):
                transformed_state = self.vanilla_transformer_layer(
                    context, transformed_state, mask)

            # update running part in the weighted state and keep the rest
            new_state = ((transformed_state * update_weights) +
                         (previous_state * (1 - update_weights)))

            if self.act_type == "accumulated":
                # Add in the weighted state
                # (replaces the interpolation above: previous_state is kept
                # in full rather than scaled by 1 - update_weights)
                new_state = (transformed_state *
                             update_weights) + previous_state

            step += 1

            return (transformed_state, step, halting_probability, remainders,
                    n_updates, new_state)