    def testWithoutDefault(self):
        """Tests parameters without defaults."""
        p = _TestTFDataInputWithoutDefault.Params()
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertIsNone(p.args.begin)
        self.assertEqual(p.args.end, 10)

        p.args.begin = 0
        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInputWithoutDefault)

        with self.session(graph=tf.get_default_graph()) as sess:
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, tf.Tensor)
            self.assertAllEqual(data.shape, ())
            self.assertEqual(data.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.args.end):
                self.assertEqual(sess.run(data), i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
    def testToFromProto(self):
        """Similar to `testExample` but params will be restored from a proto."""
        serialized_proto = _TestTFDataInput.Params().ToProto()
        p = hyperparams.Params.FromProto(serialized_proto)
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.begin, 0)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInput)

        with self.session(graph=tf.get_default_graph()) as sess:
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, tf.Tensor)
            self.assertAllEqual(data.shape, ())
            self.assertEqual(data.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.args.end):
                self.assertEqual(sess.run(data), i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
Example #3
    def testDatasetV2(self):
        """Tests the TFv2 Dataset."""
        p = _TestTFDataInputV2.Params()
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.begin, 0)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInputV2)

        # We keep TFv1's Session here since v1/v2 behaviors cannot coexist.
        # TODO(oday): write TFv2-specific tests.
        with self.session(graph=tf.get_default_graph()) as sess:
            sess.run(ig.InitOps())
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, py_utils.NestedMap)
            self.assertIsInstance(data.value, tf.Tensor)
            self.assertAllEqual(data.value.shape, ())
            self.assertEqual(data.value.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.args.end):
                self.assertEqual(sess.run(data).value, i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
Example #4
def SetOverWriteGlobalStep(tensor, graph=None):
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  if len(mb_tensors) == 1:
    mb_tensors[0] = tensor
  else:
    graph.add_to_collection(_OVERWRITE_GLOBAL_STEP_COLLECTION, tensor)
Example #5
 def _CreateVariableStub(name,
                         params,
                         reuse=None,
                         trainable=True,
                         collections=None,
                         default_seed=None,
                         synchronization=None,
                         aggregation=None):
     """Return a zero tensor of the right shape instead of creating variable."""
     del reuse
     del default_seed
     del synchronization
     del aggregation
     dtype = params.dtype
     shape = py_utils.ToStaticShape(params.shape)
     # For total samples counters we have to actually create variables so that
     # we can access the 'value' attribute during construction.
     if 'total_samples' in name:
         var = tf.get_variable(name,
                               shape,
                               dtype,
                               tf.constant_initializer(0),
                               collections=collections,
                               trainable=trainable,
                               validate_shape=True)
     else:
         key = (tf.get_default_graph(), tuple(shape))
         if key in variable_cache:
             var = variable_cache[key]
         else:
             var = tf.zeros(shape, dtype)
             variable_cache[key] = var
     return var, var
Example #6
    def testDatasetV1(self):
        """Tests the TFv1 Dataset."""
        p = _TestTFDataInputV1.Params()
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.begin, 0)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInputV1)

        with self.session(graph=tf.get_default_graph()) as sess:
            sess.run(ig.InitOps())
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, py_utils.NestedMap)
            self.assertIsInstance(data.value, tf.Tensor)
            self.assertAllEqual(data.value.shape, ())
            self.assertEqual(data.value.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.args.end):
                self.assertEqual(sess.run(data).value, i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
Example #7
    def testWithBoundMethod(self):
        """Tests pipeline defined by a bound method: member function with self."""
        p = _TestTFDataInputWithBoundMethod.Params()
        self.assertIn('args', p)
        self.assertNotIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()

        self.assertIsInstance(ig, _TestTFDataInputWithBoundMethod)
        with self.session(graph=tf.get_default_graph()) as sess:
            sess.run(ig.InitOps())
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, py_utils.NestedMap)
            self.assertIsInstance(data.value, tf.Tensor)
            self.assertAllEqual(data.value.shape, ())
            self.assertEqual(data.value.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.end):
                self.assertEqual(sess.run(data).value, i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
Example #8
  def __init__(self, dtype, shape, send_device, recv_device, name=None):
    """Construct a channel.

    Args:
      dtype: The dtype of tensors sent through the channel.
      shape: The shape of tensors sent through the channel. Must be a fully
        defined shape for TPUs.
      send_device: A fully-specified tensorflow device.
      recv_device: A fully-specified tensorflow device.
      name: A name for the channel (optional).
    """
    current_graph = tf.get_default_graph()
    assert current_graph, "A channel is scoped within a tf.Graph"
    self._dtype = dtype
    self._send_device = send_device
    self._recv_device = recv_device
    self._name = current_graph.unique_name(name if name else "channel")

    assert shape is not None
    shape = tf.TensorShape(shape)

    self._shape = shape
    self._send_tpu_core = _TpuCore(send_device)
    self._recv_tpu_core = _TpuCore(recv_device)
    self._send_called = False
    self._recv_op = None
    assert ((self._send_tpu_core == -1) == (self._recv_tpu_core == -1)), (
        "Mixing TPU and non-TPU: %s and %s" % (send_device, recv_device))
    if self._send_tpu_core >= 0:
      assert self._shape.is_fully_defined(), (
          "TPU channel must have fully defined shape. Name: %s, shape: %s" %
          (self._name, self._shape))
      assert self._send_tpu_core != self._recv_tpu_core, (
          "TPU send/recv must be cross-core: %s and %s" %
          (send_device, recv_device))
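
A minimal construction sketch for the snippet above, assuming the __init__ shown belongs to a class named Channel (the class name and its send/recv methods are not shown in this excerpt):

ch = Channel(
    dtype=tf.float32,
    shape=[8, 128],  # must be fully defined when both endpoints are TPU cores
    send_device='/job:worker/replica:0/task:0/device:TPU:0',
    recv_device='/job:worker/replica:0/task:0/device:TPU:1',
    name='activations')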
Example #9
def scalar(name, value, while_loop_reduce='mean'):
    """Adds summary scalar.

  Outside of tpu_summary.context(), this does nothing.

  Args:
    name: string name
    value: scalar tensor value
    while_loop_reduce: optional argument, determines what to do when this
      summary appears inside a tf.while_loop. Can be 'mean' or 'sum'.

  Raises:
    RuntimeError: if the function is called in Eager mode.
  """
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)

    assert while_loop_reduce in ('mean', 'sum')
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = TpuSummaryScalar()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    if x.value.shape != ():  # pylint: disable=g-explicit-bool-comparison
        raise ValueError('use tpu_summary.tensor() instead: %r' % value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    x.while_loop_reduce = while_loop_reduce
    ctx.summary_tensors.append(x)
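
A minimal usage sketch relying only on names referenced in these snippets (tpu_summary.context(), scalar(), merge_all()); the exact import path of tpu_summary is assumed:

with tpu_summary.context():
  loss = tf.constant(1.25)
  # Recorded only inside a summary context; a no-op otherwise.
  tpu_summary.scalar('loss', loss, while_loop_reduce='mean')
  merged = tpu_summary.merge_all()  # {'loss/<name_scope>': <tf.Tensor>}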
Example #10
    def testWithMapArgs(self):
        """Tests the `map_args` parameter."""
        p = _TestTFDataInputWithMapArgs.Params()
        self.assertIn('args', p)
        self.assertIn('num_samples', p)  # Defined by BaseInputGenerator.
        self.assertIn('begin', p.args)
        self.assertNotIn('end', p.args)
        self.assertEqual(p.num_samples, 0)
        self.assertEqual(p.args.begin, 0)

        p.num_samples = 20
        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInputWithMapArgs)

        with self.session(graph=tf.get_default_graph()) as sess:
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, tf.Tensor)
            self.assertAllEqual(data.shape, ())
            self.assertEqual(data.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.num_samples):
                self.assertEqual(sess.run(data), i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
Example #11
def GetOverWriteGlobalStep(graph=None):
  graph = graph or tf.get_default_graph()
  mb_tensors = graph.get_collection_ref(_OVERWRITE_GLOBAL_STEP_COLLECTION)
  if len(mb_tensors) == 1:
    mb_tensor = mb_tensors[0]
  else:
    mb_tensor = py_utils.GetGlobalStep()
  return mb_tensor
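
A short sketch of the override pattern implemented by SetOverWriteGlobalStep and GetOverWriteGlobalStep above (assuming _OVERWRITE_GLOBAL_STEP_COLLECTION is a module-level collection name):

step_override = tf.constant(123, dtype=tf.int64)
SetOverWriteGlobalStep(step_override)  # stores or replaces the override tensor
step = GetOverWriteGlobalStep()  # the override if present, else py_utils.GetGlobalStep()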
Example #12
def pw_tensor(name, value):
    """Adds summary tensor."""
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = PwTpuSummaryTensor()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    ctx.summary_tensors.append(x)
Example #13
def tensor(name, value):
  """Adds summary tensor. Similar to scalar() but allows other shapes."""
  ctx = TpuSummaryContext.current()
  if ctx is None:
    return
  x = TpuSummaryScalar()
  x.name = str(name)
  x.value = tf.convert_to_tensor(value)
  x.name_scope = tf.get_default_graph().get_name_scope()
  x.while_loop_reduce = 'stack'
  ctx.summary_tensors.append(x)
Example #14
def pw_tensor(name, value):
    """Adds summary tensor."""
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = PwTpuSummaryTensor()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    ctx.summary_tensors.append(x)
Example #15
 def testTfRecordFileLargeBatch(self):
     p = ToyInputGenerator.Params()
     p.batch_size = 200
     self._tmpdir, p.input_files = _CreateFakeTFRecordFiles()
     p.dataset_type = tf.data.TFRecordDataset
     p.randomize_order = False
     p.parallel_readers = 1
     ig = p.Instantiate()
     with self.session(graph=tf.get_default_graph()) as sess:
         inputs = ig.GetPreprocessedInputBatch()
         eval_inputs = sess.run(inputs)
         input_shapes = eval_inputs.Transform(lambda t: t.shape)
         self.assertEqual(input_shapes.audio, (200, 48000))
Example #16
def tensor(name, value):
    """Adds summary tensor. Similar to scalar() but allows other shapes."""
    if py_utils.IsEagerMode():
        raise RuntimeError(EAGER_MODE_EXCEPTION_STR)
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return
    x = TpuSummaryScalar()
    x.name = str(name)
    x.value = tf.convert_to_tensor(value)
    x.name_scope = tf.get_default_graph().get_name_scope()
    x.while_loop_reduce = 'stack'
    ctx.summary_tensors.append(x)
Example #17
 def _GetSaver(self):
   """Returns a saver."""
   assert tf.get_default_graph() == self._graph
   p = self.params
   if p.is_eval and self._model.ema:
     tf.logging.info('Using EMA for evaluation.')
     return tf.train.Saver(self._model.ema.variables_to_restore())
   tp = p.train
   return tf.train.Saver(
       sharded=True,
       max_to_keep=tp.save_max_to_keep,
       keep_checkpoint_every_n_hours=tp.save_keep_checkpoint_every_n_hours,
       pad_step_number=True,  # %08d
       write_version=saver_pb2.SaverDef.V2)
Example #18
 def testNumEpochs(self):
   p = ToyInputGenerator.Params()
   p.batch_size = 3
   p.num_epochs = 7
   self._tmpdir, p.input_files = _CreateFakeTFRecordFiles(
       record_count=p.batch_size)
   p.dataset_type = tf.data.TFRecordDataset
   p.randomize_order = False
   p.parallel_readers = 1
   ig = p.Instantiate()
   with self.session(graph=tf.get_default_graph()) as sess:
     inputs = ig.GetPreprocessedInputBatch()
     for _ in range(p.num_epochs):
       eval_inputs = sess.run(inputs)
       self.assertEqual(eval_inputs.audio.shape, (p.batch_size, 48000))
     with self.assertRaisesRegex(tf.errors.OutOfRangeError, 'End of sequence'):
       sess.run(inputs)
Example #19
def merge_all():
    """Returns all summary tensors as a dict of {name: tensor}.

  Note that this is not the same return type as tf.summary.merge_all
  which returns a serialized proto string.

  Outside of tpu_summary.context(), this returns {}.
  """
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return {}
    g = tf.get_default_graph()
    ret = {}
    for x in ctx.summary_tensors:
        if x.value.graph is not g:
            raise ValueError('Tensor %r %r is not an element of this graph.' %
                             (x.name, x.value))
        ret['%s/%s' % (x.name, x.name_scope)] = x.value
    return ret
Example #20
def merge_all_pw_tensor():
    """Returns all summary tensors as a dict of {name: tensor}.

  Note this function only returns summary tensors of type PwTpuSummaryTensor.

  Outside of tpu_summary.context(), this returns {}.
  """
    ctx = TpuSummaryContext.current()
    if ctx is None:
        return {}
    g = tf.get_default_graph()
    ret = {}
    for x in ctx.summary_tensors:
        if isinstance(x, PwTpuSummaryTensor):
            # Only keep summaries of the desired type.
            if x.value.graph is not g:
                raise ValueError(
                    'Tensor %r %r is not an element of this graph.' %
                    (x.name, x.value))
            name = ('%s/%s' % (x.name_scope, x.name)).replace('/', '__')
            ret[name] = x.value
    return ret
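
A minimal sketch contrasting pw_tensor() and merge_all_pw_tensor() above; note that the scoped name has '/' rewritten to '__' (import path of tpu_summary assumed):

with tpu_summary.context():
  tpu_summary.pw_tensor('probs', tf.constant([0.1, 0.9]))
  pw_summaries = tpu_summary.merge_all_pw_tensor()  # {'<name_scope>__probs': <tf.Tensor>}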
Example #21
    def testWithRepeat(self):
        """Tests if the repeated dataset runs forever."""
        p = _TestTFDataInputWithRepeat.Params()
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.begin, 0)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInputWithRepeat)

        with self.session(graph=tf.get_default_graph()) as sess:
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, tf.Tensor)
            self.assertAllEqual(data.shape, ())
            self.assertEqual(data.dtype, tf.int32)

            # Runs the dataset several times: it should not raise OutOfRangeError.
            for _ in range(3):
                for i in range(p.args.begin, p.args.end):
                    self.assertEqual(sess.run(data), i)
Example #22
    def testExample(self):
        """Tests the example code in the function docstring."""
        p = _TestTFDataInput.Params()
        self.assertIn('args', p)
        self.assertIn('begin', p.args)
        self.assertIn('end', p.args)
        self.assertEqual(p.args.begin, 0)
        self.assertEqual(p.args.end, 10)

        ig = p.Instantiate()
        self.assertIsInstance(ig, _TestTFDataInput)

        with self.session(graph=tf.get_default_graph()) as sess:
            data = ig.GetPreprocessedInputBatch()
            self.assertIsInstance(data, tf.Tensor)
            self.assertAllEqual(data.shape, ())
            self.assertEqual(data.dtype, tf.int32)

            # Consumes all data.
            for i in range(p.args.begin, p.args.end):
                self.assertEqual(sess.run(data), i)

            with self.assertRaises(tf.errors.OutOfRangeError):
                sess.run(data)
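
The input generators in these tests wrap a tf.data pipeline that yields scalar tf.int32 values from begin to end and then raises OutOfRangeError. A hypothetical sketch of such a pipeline function (illustrative only; not the actual _TestTFDataInput definition):

def _MyDatasetFn(begin=0, end=10):
  """Yields scalar tf.int32 values in [begin, end), as the tests above expect."""
  return tf.data.Dataset.range(begin, end).map(lambda x: tf.cast(x, tf.int32))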
Example #23
def Top2GatingOnLogits(inputs,
                       paddings,
                       logits,
                       num_devices,
                       experts_dim,
                       expert_capacity_dim,
                       fprop_dtype,
                       use_xla_sharding=True,
                       second_expert_policy='all',
                       second_expert_threshold=0.0,
                       legacy_mtf_behavior=True,
                       capacity_factor=None):
  """Computes Top-2 gating for Mixture-of-Experts.

  There are two expected usages of this function:

  1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
     tensor across multiple tpu cores. The operations within this function are
     automatically sharded/replicated across tpu cores.
  2. used within ML-Pathways. In this case, 'inputs' is always local to one tpu
     core. All computations below are carried out on one tpu core only. This
     function tries to dispatch examples across tpu cores in such a way that
     each expert is assigned no more than 'expert_capacity_dim' number of
     examples.

  Below, ` indicates the common way of splitting along a mesh dimension.

  Dimensions cheat sheet:

    G: group_dim
    S: group_size_dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim, same as output_dim)
    B: original batch_dim
    L: original sequence_length_dim

  Note that for local_dispatch the original batch BLM is reshaped into GSM;
  each group `g = 0...G-1` is dispatched independently.

  Args:
    inputs: G`SM Tensor.
    paddings: G`S Tensor.
    logits: G`SE Tensor.
    num_devices: number of MoE devices for local dispatch
    experts_dim: number of experts.
    expert_capacity_dim: number of examples per minibatch (group) per expert.
      Each example is typically a vector of size input_dim, representing an
      embedded token or an element of a Transformer layer's output.
    fprop_dtype: activations datatype to use.
    use_xla_sharding: bool, True if this function is used for the xla_sharding
      case.
    second_expert_policy: 'all', 'sampling' or 'random'.

      - 'all': we greedily pick the 2nd expert.
      - 'sampling': we sample the 2nd expert from the softmax.
      - 'random': we optionally 'random'-ize dispatch to second-best expert
        proportional to (weight / second_expert_threshold).

    second_expert_threshold: threshold for probability normalization for
      second_expert_policy == 'random'.
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim
      where `group_size` is the size of G dimension of `inputs`. If the
      value of expert_capacity_dim is already big enough no change is made.

  TODO(lepikhin): get rid of the legacy_mtf_behavior flag.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: G`SEC Tensor for combining expert outputs.
    - dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
      experts.
  """
  del inputs  # inputs is currently not used.
  raw_gates = tf.nn.softmax(logits)  # along E dim

  if capacity_factor is not None:
    # Determine expert capacity automatically depending on the input size.
    group_size_dim = int(logits.shape[1])
    auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
    if expert_capacity_dim < auto_expert_capacity:
      expert_capacity_dim = auto_expert_capacity
      # Round up to a multiple of 4 to avoid possible padding.
      while expert_capacity_dim % 4:
        expert_capacity_dim += 1
      tf.logging.info(
          'Setting expert_capacity_dim=%r (capacity_factor=%r '
          'group_size_dim=%r experts_dim=%r name_scope=%r)',
          expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
          tf.get_default_graph().get_name_scope())
    tpu_summary.scalar('expert_capacity', expert_capacity_dim)

  # top first and second gate value and expert index for each input
  #
  # GSK Tensors, K=2
  def _MaybeSplit(x):
    if use_xla_sharding:
      return Split(x, 0, num_devices)
    else:
      return x

  def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
    over_capacity = tf.reduce_sum(
        tf.cast(
            tf.greater_equal(mask * position_in_expert, capacity), mask.dtype))
    over_capacity_ratio = over_capacity / tf.reduce_sum(mask)
    py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
    tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')

  # As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
  # and krikun@ will:
  #   - expand moe_spmd_test to compare Adafactor updates, slots on TPU
  #   including 2x2 with sharding
  #
  #   - add more tests for policy="random"
  #
  #   - add single step test for full size WMT model on CPU
  #
  # and then break this function into modules.
  #
  # GS
  index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
  index_1 = _MaybeSplit(index_1)
  tpu_summary.tensor('index_1', index_1)

  # GSE
  mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
  mask_1 = _MaybeSplit(mask_1)
  density_1_proxy = raw_gates

  importance = tf.ones_like(mask_1[:, :, 0])

  if paddings is not None:
    importance = 1.0 - paddings
    mask_1 *= tf.expand_dims(importance, -1)
    density_1_proxy *= tf.expand_dims(importance, -1)

  gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
  gates_without_top_1 = raw_gates * (1.0 - mask_1)

  if second_expert_policy == 'sampling':
    # We sample the 2nd expert index directly from the softmax over the
    # remaining experts, obtained by masking out the 1st expert selected
    # above. To do so, we set the logit corresponding to the 1st expert to a
    # very negative value.
    # Then we sample from the softmax (categorical) distribution using the
    # Gumbel max trick.
    noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
    # Generates standard Gumbel(0, 1) noise, GSE Tensors
    noise = -tf.math.log(-tf.math.log(noise))
    very_negative_logits = _MaybeSplit(
        (tf.ones_like(logits) * logits.dtype.max *
         tf.constant(-0.7, dtype=logits.dtype)))
    # Gets rid of the first expert by setting its logit to be very negative
    updated_logits = _MaybeSplit(
        tf.where(mask_1 > 0.0, very_negative_logits, logits))
    # Adds the Gumbel noise to the updated logits
    noised_logits = _MaybeSplit(updated_logits + noise)
    # Picks the index of the largest noised logit as the 2nd expert. This is
    # equivalent to sampling from the softmax over the 2nd experts.
    index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
  else:
    index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)

  index_2 = _MaybeSplit(index_2)
  mask_2 = tf.one_hot(index_2, experts_dim, dtype=fprop_dtype)
  mask_2 = _MaybeSplit(mask_2)
  if paddings is not None:
    mask_2 *= tf.expand_dims(importance, -1)
  gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)

  if legacy_mtf_behavior:
    # cl/298510175 moved this branch for gate_{1,2} denom calculation here.
    #
    # For policy=random, it's better to normalize gate_{1,2} before taking
    # capacity into account and before potentially dropping the second expert.
    #
    # According to mean_xent (http://short/_NzbZ5rINr5):
    #   MoE_512_102xen_PolicyAll_298510175
    #   MoE_512_102xen_PolicyRandom_298510175
    #
    # vs pre-cl/298510175
    #   MoE_512_102xen_PolicyRandom
    #   MoE_512_102xen_PolicyAll
    #
    # it substantially improves policy=random with threshold=0.5 which
    # historically was better than policy="all"
    #
    # Also confirmed this by decoding
    #   nmt_train/m4/data/es_en/test.txt
    #   nmt_train/m4/data/ru_en/test.txt
    #   nmt_train/m4/data/zh_en/test.txt
    # and improving BLEU
    #
    # moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.421443
    #   0.327102
    #   0.315693
    # vs
    # moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
    #   0.399232
    #   0.310606
    #   0.288229
    #
    # Additional comparison, see mean_xent http://short/_YHccOhQtdu with
    # legacy_mtf_behavior=False models
    #   3 - MoE_512_102xen_PolicyAll_LegacyFalse
    #   6 - MoE_512_102xen_PolicyRandom_LegacyFalse
    # shows that policy="random" gets worse with legacy_mtf_behavior=False, and
    # is similar to pre-cl/298510175
    #   4 - MoE_512_102xen_PolicyRandom
    #
    # gate_1 can become 0 due to Expert being out of capacity.
    #
    # gate_2 can become 0 due to
    #   second_expert_policy == 'random'
    # or "out of capacity" scenario.
    #
    # Here we renormalize regardless of cases above.
    denom = gate_1 + gate_2 + 1e-9
    gate_1 /= denom
    gate_2 /= denom

  # We reshape the mask as [X*S, E], and compute cumulative sums of
  # assignment indicators for each expert index e \in 0..E-1 independently.
  # First occurrence of assignment indicator is excluded, see exclusive=True
  # flag below.
  position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)

  # GS Tensor
  capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)

  # GE Tensor (reducing S out of GSE tensor mask_1)
  # density_1[:, e] represents assignment ratio (num assigned / total) to
  # expert e as top_1 expert without taking capacity into account.
  if legacy_mtf_behavior:
    density_denom = 1.0
  else:
    density_denom = tf.reduce_mean(
        importance, axis=(1))[:, tf.newaxis] + 1e-6
  density_1 = tf.reduce_mean(mask_1, axis=(1)) / density_denom
  # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
  # those of examples not assigned to e with top_k.
  density_1_proxy = tf.reduce_mean(density_1_proxy, axis=1) / density_denom

  # The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
  # reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
  # the density_1_proxy with the discrete density_1 following
  # mesh_tensorflow/transformer/moe.py?rcl=283569345.
  aux_loss = tf.reduce_mean(density_1_proxy * density_1)  # element-wise
  aux_loss *= experts_dim * experts_dim  # const coefficient

  # Add the over capacity ratio for expert 1
  _CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
                                  'over_capacity_1_ratio')

  mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
  position_in_expert_1 = tf.einsum('GSE,GSE->GS', position_in_expert_1, mask_1)

  # How many examples in this sequence go to this expert
  mask_1_count = tf.einsum('GSE->GE', mask_1)
  # [batch, group] - mostly ones, but zeros where something didn't fit
  mask_1_flat = tf.einsum('GSE->GS', mask_1)

  if second_expert_policy == 'all' or second_expert_policy == 'sampling':
    pass
  elif second_expert_policy == 'random':
    # gate_2 is between 0 and 1, reminder:
    #
    #   raw_gates = tf.nn.softmax(logits)
    #   index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
    #   mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
    #   gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
    #
    # E.g. if gate_2 exceeds second_expert_threshold, then we definitely
    # dispatch to second-best expert. Otherwise we dispatch with probability
    # proportional to (gate_2 / threshold).
    #
    sampled_2 = tf.less(
        _MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
        (gate_2 / max(second_expert_threshold, 1e-9)))
    gate_2 *= tf.cast(sampled_2, gate_2.dtype)
    mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
  else:
    raise ValueError(second_expert_policy)

  position_in_expert_2 = tf.cumsum(
      mask_2, exclusive=True, axis=1) + tf.expand_dims(mask_1_count, 1)

  # Add the over capacity ratio for expert 2
  _CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
                                  'over_capacity_2_ratio')

  mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
  position_in_expert_2 = tf.einsum('GSE,GSE->GS', position_in_expert_2, mask_2)
  mask_2_flat = tf.reduce_sum(mask_2, axis=-1)

  # Equivalent non-einsum implementation:
  #
  # position_in_expert_2 *= mask_2
  # position_in_expert_2 = tf.reduce_sum(
  #     position_in_expert_2, axis=-1, name='position_in_expert_2')

  gate_1 *= mask_1_flat
  gate_2 *= mask_2_flat

  if not legacy_mtf_behavior:
    denom = gate_1 + gate_2
    # To avoid divide by 0.
    denom = tf.where(denom > 0, denom, tf.ones_like(denom))
    gate_1 /= denom
    gate_2 /= denom

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_1, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_0')
  # GSE Tensor
  a = tf.expand_dims(gate_1 * mask_1_flat, -1) * tf.one_hot(
      index_1, experts_dim, dtype=fprop_dtype)
  # GSEC Tensor
  first_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='first_part_of_combine_tensor')

  # GSC Tensor
  b = tf.one_hot(
      tf.cast(position_in_expert_2, dtype=tf.int32),
      expert_capacity_dim,
      dtype=fprop_dtype,
      name='one_hot_b_1')
  # GSE Tensor
  a = tf.expand_dims(gate_2 * mask_2_flat, -1) * tf.one_hot(
      index_2, experts_dim, dtype=fprop_dtype)
  second_part_of_combine_tensor = tf.einsum(
      'GSE,GSC->GSEC', a, b, name='second_part_of_combine_tensor')

  # GSEC Tensor
  combine_tensor = (
      first_part_of_combine_tensor + second_part_of_combine_tensor)
  combine_tensor = _MaybeSplit(combine_tensor)

  # GSEC Tensor
  dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), fprop_dtype)
  dispatch_tensor = _MaybeSplit(dispatch_tensor)

  # TODO(yonghui): compute and return per-group aux_loss.
  return aux_loss, combine_tensor, dispatch_tensor
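
At its core, the routine above performs top-2 selection over the softmaxed logits; the rest handles capacity limits, paddings, dispatch policies, and sharding. A simplified, self-contained sketch of that core (illustrative only, not a replacement for the full function):

logits = tf.random.normal([2, 16, 4])  # G=2 groups, S=16 examples, E=4 experts
raw_gates = tf.nn.softmax(logits)  # GSE
index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)  # GS
mask_1 = tf.one_hot(index_1, 4, dtype=raw_gates.dtype)  # GSE
gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
gates_without_top_1 = raw_gates * (1.0 - mask_1)
index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)
mask_2 = tf.one_hot(index_2, 4, dtype=raw_gates.dtype)
gate_2 = tf.einsum('GSE,GSE->GS', gates_without_top_1, mask_2)
# Position of each example within its top-1 expert; the full function compares
# this against expert_capacity_dim and drops whatever overflows.
position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=1)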
Example #24
    def __init__(self,
                 inference_graph,
                 subgraph_name=None,
                 checkpoint=None,
                 device_type="gpu",
                 tf_master="",
                 session_config=None,
                 clear_device_placement=False,
                 load_graph_def_from_inference_graph=True):
        """Constructor.

    Args:
      inference_graph: A saved InferenceGraph proto.
      subgraph_name: The default subgraph to use for Run().
      checkpoint: An optional checkpoint to load.
      device_type: Device type string. Either "cpu", "gpu", or "tpu".
      tf_master: The tf_master.
      session_config: A tf.SessionConfig to use. By default
        py_utils.SessionConfig() is used.
      clear_device_placement: If set, clears device field of loaded inference
        graph.
      load_graph_def_from_inference_graph: Whether to load a graph def.
        If False, assumes the names in the inference graph correspond to tensors
        in the current default graph.
    """
        assert device_type in ["cpu", "gpu", "tpu"]
        subgraph_name = subgraph_name or "default"
        if isinstance(inference_graph, str):
            tf.logging.info("Reading inference graph from %s.",
                            inference_graph)
            inference_graph = LoadInferenceGraph(inference_graph,
                                                 clear_device_placement)

        if not inference_graph.subgraphs:
            raise ValueError(
                "No subgraphs were defined in inference_graph. "
                "Check that subgraphs were defined and subgraph filters "
                "did not filter out all defined subgraphs.")

        self._inference_graph = inference_graph

        if subgraph_name not in inference_graph.subgraphs:
            raise ValueError(
                f"Subgraph {subgraph_name} not defined. Valid subgraphs: "
                f"{self.subgraphs}")
        subgraph = inference_graph.subgraphs[subgraph_name]
        self._fetches = subgraph.fetches
        self._feeds = subgraph.feeds

        self._default_subgraph_name = subgraph_name
        self._checkpoint = checkpoint
        self._device_type = device_type
        self._tf_master = tf_master
        self._session_config = session_config

        if load_graph_def_from_inference_graph:
            tf.logging.info(
                "Loading inference graph for prediction subgraph_name={}.".
                format(subgraph_name))
            self._graph = self._load_graph_from_inference_graph(
                inference_graph)
        else:
            self._graph = tf.get_default_graph()

        if device_type == "tpu":
            # If no tpu init op exists, create it here.
            try:
                self._graph.get_operation_by_name("tpu_init_op")
            except KeyError:
                with self._graph.as_default():
                    tf.group(tf.tpu.initialize_system(), name="tpu_init_op")

        self._graph.finalize()

        # Lock for creating new sessions.
        self._sess_lock = threading.Lock()
        self._cur_sess_id = 0
        self._create_new_session()
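
A minimal usage sketch, assuming the constructor above belongs to a class named Predictor (the class name is not shown in this excerpt); the paths below are placeholders:

predictor = Predictor(
    inference_graph='/path/to/inference_graph.pbtxt',  # hypothetical path
    subgraph_name='default',
    checkpoint='/path/to/model.ckpt',  # hypothetical path
    device_type='cpu')
# Run() then executes the fetches of the chosen default subgraph
# (its signature is not shown in this excerpt).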