Example 1
    def _moments(self, inputs, reduction_axes, keep_dims):
        """Compute the mean and variance: it overrides the original _moments."""
        shard_mean, shard_variance = super(TpuBatchNormalization,
                                           self)._moments(inputs,
                                                          reduction_axes,
                                                          keep_dims=keep_dims)

        num_shards = tpu_function.get_tpu_context().number_of_shards or 1
        if num_shards <= 8:  # Skip cross_replica for 2x2 or smaller slices.
            num_shards_per_group = 1
        else:
            num_shards_per_group = max(8, num_shards // 8)
        logging.info(
            'TpuBatchNormalization with num_shards_per_group {}'.format(
                num_shards_per_group))
        if num_shards_per_group > 1:
            # Compute variance using: Var[X]= E[X^2] - E[X]^2.
            shard_square_of_mean = tf.math.square(shard_mean)
            shard_mean_of_square = shard_variance + shard_square_of_mean
            group_mean = self._cross_replica_average(shard_mean,
                                                     num_shards_per_group)
            group_mean_of_square = self._cross_replica_average(
                shard_mean_of_square, num_shards_per_group)
            group_variance = group_mean_of_square - tf.math.square(group_mean)
            return (group_mean, group_variance)
        else:
            return (shard_mean, shard_variance)
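The aggregation above relies on the identity Var[X] = E[X^2] - E[X]^2: the group variance is rebuilt from two cross-replica averages. A minimal NumPy check of that aggregation (not TPU code; it assumes all shards hold the same number of examples):

import numpy as np

# Four "shards", each holding an equal-sized slice of the global batch.
shards = [np.random.randn(16, 8) for _ in range(4)]
shard_means = [x.mean(axis=0) for x in shards]
shard_vars = [x.var(axis=0) for x in shards]

# Per-shard E[X^2] = Var[X] + E[X]^2, then average across shards
# (the role played by _cross_replica_average in the code above).
mean_of_square = np.mean([v + m ** 2 for v, m in zip(shard_vars, shard_means)],
                         axis=0)
group_mean = np.mean(shard_means, axis=0)
group_variance = mean_of_square - group_mean ** 2

full_batch = np.concatenate(shards, axis=0)
assert np.allclose(group_mean, full_batch.mean(axis=0))
assert np.allclose(group_variance, full_batch.var(axis=0))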
Example 2
    def _moments(self, inputs, reduction_axes, keep_dims):
        """Compute the mean and variance: it overrides the original _moments."""
        shard_mean, shard_variance = super(BatchNormalization,
                                           self)._moments(inputs,
                                                          reduction_axes,
                                                          keep_dims=keep_dims)

        num_shards = tpu_function.get_tpu_context().number_of_shards or 1
        if num_shards <= 8:  # Skip cross_replica for 2x2 or smaller slices.
            num_shards_per_group = 1
        else:
            num_shards_per_group = max(8, num_shards // 1)
        logging.info('BatchNormalization with num_shards_per_group %s',
                     num_shards_per_group)
        if num_shards_per_group > 1:
            # Each group has multiple replicas: here we compute group mean/variance by
            # aggregating per-replica mean/variance.
            group_mean = self._cross_replica_average(shard_mean,
                                                     num_shards_per_group)
            group_variance = self._cross_replica_average(
                shard_variance, num_shards_per_group)

            # Group variance needs to also include the difference between shard_mean
            # and group_mean.
            mean_distance = tf.square(group_mean - shard_mean)
            group_variance += self._cross_replica_average(
                mean_distance, num_shards_per_group)
            return (group_mean, group_variance)
        else:
            return (shard_mean, shard_variance)
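This variant aggregates the per-shard moments directly: by the law of total variance, the group variance is the average within-shard variance plus the average squared distance between each shard mean and the group mean (exact when shards are equally sized). A small NumPy check of that decomposition, under the same equal-shard assumption:

import numpy as np

shards = [np.random.randn(16, 8) for _ in range(4)]
shard_means = [x.mean(axis=0) for x in shards]
shard_vars = [x.var(axis=0) for x in shards]

group_mean = np.mean(shard_means, axis=0)
# Average within-shard variance plus average squared mean distance.
group_variance = (np.mean(shard_vars, axis=0) +
                  np.mean([(group_mean - m) ** 2 for m in shard_means], axis=0))

full_batch = np.concatenate(shards, axis=0)
assert np.allclose(group_mean, full_batch.mean(axis=0))
assert np.allclose(group_variance, full_batch.var(axis=0))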
Example 3
def cross_replica_concat(tensor):
    """A cross-replica concatenation of a single Tensor across TPU cores.

  Input tensor is assumed to have batch dimension as the first dimension. The
  concatenation is done along the batch dimension.

  Args:
    tensor: Input Tensor which should be concatenated across TPU cores.

  Returns:
    The concatenated Tensor with batch dimension multiplied by the number of
      TPU cores.
  """
    num_tpu_replicas = tpu_function.get_tpu_context().number_of_shards

    if num_tpu_replicas is not None:
        # Scattered tensor has shape [num_replicas, local_batch_size, ...]
        scattered_tensor = tf.scatter_nd(indices=[[local_tpu_replica_id()]],
                                         updates=[tensor],
                                         shape=[num_tpu_replicas] +
                                         tensor.shape.as_list())
        reduced_tensor = tf.tpu.cross_replica_sum(scattered_tensor)
        # Returned tensor has shape [num_replicas * local_batch_size, ...]
        return tf.reshape(reduced_tensor,
                          [-1] + scattered_tensor.shape.as_list()[2:])
    else:
        # This is a no op if not running on TPU
        return tensor
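The concatenation works because each replica scatters its local batch into a distinct slot of an otherwise-zero tensor, so summing those tensors across replicas fills every slot exactly once. A NumPy sketch of the same idea, with an explicit sum standing in for tf.tpu.cross_replica_sum:

import numpy as np

num_replicas, local_batch = 4, 2
local_batches = [np.full((local_batch, 3), r, dtype=np.float32)
                 for r in range(num_replicas)]

# scattered[r] is what replica r would build with tf.scatter_nd:
# shape [num_replicas, local_batch, ...], zero everywhere except slot r.
scattered = np.zeros((num_replicas, num_replicas, local_batch, 3), np.float32)
for r in range(num_replicas):
    scattered[r, r] = local_batches[r]

summed = scattered.sum(axis=0)        # stand-in for cross_replica_sum
concatenated = summed.reshape(-1, 3)  # [num_replicas * local_batch, 3]
assert np.array_equal(concatenated, np.concatenate(local_batches, axis=0))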
Example 4
 def _Moments(self, inputs, group_size):
   """Computes mean and variance over N,H,W dimensions in inputs."""
   counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
       inputs, axes=[0, 1, 2], keepdims=False)
   self.accumulators.counts.Update(counts)
   self.accumulators.mean_ss.Update(mean_ss)
   self.accumulators.variance_ss.Update(variance_ss)
   # Distributed batch norm that computes sufficient statistics from group_size
   # replicas. This is useful when batch_size_per_replica is too small to
   # compute reliable sufficient statistics.
   if py_utils.use_tpu() and group_size > 1:
     group_assignment = None
     num_shards = tpu_function.get_tpu_context().number_of_shards
     if num_shards is not None:
       if num_shards < group_size:
         raise ValueError('TPU shards={} less than bn_group_size={}.'.format(
             num_shards, group_size))
       if num_shards % group_size:
         raise ValueError(
             'TPU shards={} not divisible by bn_group_size={}.'.format(
                 num_shards, group_size))
       num_groups = num_shards // group_size
       group_assignment = []
       for g in range(num_groups):
         replica_ids = [g * group_size + i for i in range(group_size)]
         group_assignment.append(replica_ids)
       counts *= group_size
     mean_ss = tf.tpu.cross_replica_sum(mean_ss, group_assignment)
     variance_ss = tf.tpu.cross_replica_sum(variance_ss, group_assignment)
   # At each micro-step, batch_mean and batch_variance are computed
   # to normalize inputs. But they are not used to update moving_mean and
   # moving_variance variables until the last micro batch.
   mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
   return mean, variance
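For concreteness, the group_assignment built above expands as follows (plain Python, for a hypothetical num_shards=8 and group_size=4). counts is multiplied by group_size because, after the cross_replica_sum calls, mean_ss and variance_ss hold sums over the whole group rather than over a single replica's batch, and tf.nn.normalize_moments divides by counts:

num_shards, group_size = 8, 4
num_groups = num_shards // group_size
group_assignment = [[g * group_size + i for i in range(group_size)]
                    for g in range(num_groups)]
print(group_assignment)  # [[0, 1, 2, 3], [4, 5, 6, 7]]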
Example 5
def is_tpu_replicated():
  is_tpu_strategy = (tf.distribute.has_strategy() and
                     tf.distribute.get_replica_context() and
                     isinstance(tf.distribute.get_strategy(),
                                tf.distribute.experimental.TPUStrategy))
  num_shards = tpu_function.get_tpu_context().number_of_shards
  return is_tpu_strategy or num_shards is not None
Example 6
def _create_default_group_assignment():
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
        logging.warning(
            "cross_replica_sum should be used within a tpu_shard_context, but "
            "got unset number_of_shards. Assuming 1.")
        num_shards = 1
    group_assignment = [list(range(num_shards))]
    return group_assignment
Example 7
def _create_default_group_assignment():
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  group_assignment = [list(range(num_shards))]
  return group_assignment
Example 8
def local_tpu_replica_id():
  """Returns the index of the current TPU replica."""
  num_tpu_replicas = tpu_function.get_tpu_context().number_of_shards
  if num_tpu_replicas is not None:
    # Need tf.control_dependencies(None) in order to make sure this is run
    # on CPU (not TPU)
    with tf.control_dependencies(None):
      return tpu_ops.tpu_replicated_input(
          list(range(num_tpu_replicas)), name='local_replica_id')
  else:
    # The non-TPU case.
    return 0
Example 9
    def compute_gradients(self, loss, var_list=None, **kwargs):
        """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps `compute_gradients()` from the real optimizer. The
    gradients will be aggregated in `apply_gradients()` so that the user can
    modify the gradients beforehand, e.g. clip them with the per-replica
    global norm if needed. Clipping by the global norm of the aggregated
    gradients can behave poorly, as one replica's huge gradients can hurt the
    gradients from other replicas.

    When the CrossShardOptimizer is constructed with
    `reduction == losses.Reduction.MEAN` (default), this function scales the
    loss by `1.0 / num_shards` before computing the gradients. Assuming the
    optimizer uses the default implementation of `compute_gradients()`, the
    gradients of the scaled loss are scaled by `1.0 / num_shards` compared to
    the gradients of the original loss. This scaling factor is important because
    `apply_gradients()` sums gradients across shards, rather than averaging
    them. However, the scaling factor must be taken into account when clipping
    the norm of the gradients or performing other postprocessing.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKey.TRAINABLE_VARIABLES`.
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.

    Raises:
      ValueError: If not within a tpu_shard_context or group_assignment is
        invalid.
    """
        num_shards = tpu_function.get_tpu_context().number_of_shards
        if num_shards is None:
            logging.warning(
                "CrossShardOptimizer should be used within a tpu_shard_context, but "
                "got unset number_of_shards. Assuming 1.")
            num_shards = 1

        subgroup_size = self._verify_and_get_subgroup_size(
            self._group_assignment, num_shards)

        if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
            if self._group_assignment:
                scale = 1.0 / subgroup_size
            else:
                scale = 1.0 / num_shards
            loss *= scale

        return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
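A plain-Python illustration of why the MEAN case scales the loss by 1.0 / num_shards: since apply_gradients() sums gradients across shards (as the docstring notes), the sum of the gradients of the scaled per-replica losses equals the gradient of the mean loss. The quadratic losses and the scalar weight w here are hypothetical:

num_shards = 4
w, targets = 3.0, (1.0, 2.0, 4.0, 5.0)

# d/dw of each replica's unscaled loss (w - t)^2.
per_replica_grads = [2 * (w - t) for t in targets]

# CrossShardOptimizer path: gradients of the scaled losses, summed over shards.
summed_scaled = sum(g / num_shards for g in per_replica_grads)

# Reference: gradient of the mean of the unscaled losses.
mean_loss_grad = sum(per_replica_grads) / num_shards
assert abs(summed_scaled - mean_loss_grad) < 1e-12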
Example 10
 def _cross_replica_average(self, t, num_shards_per_group):
   """Calculates the average value of input tensor across TPU replicas."""
   num_shards = tpu_function.get_tpu_context().number_of_shards
   group_assignment = None
   if num_shards_per_group > 1:
     if num_shards % num_shards_per_group != 0:
       raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
                        % (num_shards, num_shards_per_group))
     num_groups = num_shards // num_shards_per_group
     group_assignment = [[
         x for x in range(num_shards) if x // num_shards_per_group == y
     ] for y in range(num_groups)]
   return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
       num_shards_per_group, t.dtype)
Example 11
def cross_replica_mean(inputs, group_size=None):
    """Calculates the average value of inputs tensor across TPU replicas."""
    num_replicas = tpu_function.get_tpu_context().number_of_shards
    if not group_size:
        group_size = num_replicas
    if group_size == 1:
        return inputs
    if group_size != num_replicas:
        group_assignment = []
        assert num_replicas % group_size == 0
        for g in range(num_replicas // group_size):
            replica_ids = [g * group_size + i for i in range(group_size)]
            group_assignment.append(replica_ids)
    else:
        group_assignment = None
    return tf.compat.v1.tpu.cross_replica_sum(
        inputs, group_assignment) / tf.cast(group_size, inputs.dtype)
Example 12
def make_train_op(optimizer, loss, trainable_variables, global_step,
                  grad_clip_norm):
  num_cores = tpu_function.get_tpu_context().number_of_shards

  # compute scaled gradient
  grads_and_vars = optimizer.compute_gradients(
      loss / float(num_cores), var_list=trainable_variables)

  # clip gradient
  clipped_grads, gnorm = tf.clip_by_global_norm(
      [g for (g, _) in grads_and_vars], grad_clip_norm / float(num_cores))
  grads_and_vars = [(g, v) for g, (_, v) in zip(clipped_grads, grads_and_vars)]

  # optimize
  optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

  return train_op, gnorm
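The clipping step above uses the global norm across all gradients. A pure-Python sketch of the scaling rule that tf.clip_by_global_norm applies (hypothetical gradient values; each gradient is scaled by clip_norm / global_norm whenever the global norm exceeds clip_norm):

import math

grads = [3.0, 4.0]      # global norm = 5.0
clip_norm = 2.5
global_norm = math.sqrt(sum(g * g for g in grads))
scale = min(1.0, clip_norm / global_norm)
clipped = [g * scale for g in grads]
print(global_norm, clipped)  # 5.0 [1.5, 2.0]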
Example 13
  def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute the mean and variance: it overrides the original _moments."""
    shard_mean, shard_variance = super(BatchNormalization, self)._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards and num_shards > 1:
      # Each group has multiple replicas: here we compute group mean/variance by
      # aggregating per-replica mean/variance.
      group_mean = self._cross_replica_average(shard_mean)
      group_variance = self._cross_replica_average(shard_variance)

      # Group variance needs to also include the difference between shard_mean
      # and group_mean.
      mean_distance = tf.square(group_mean - shard_mean)
      group_variance += self._cross_replica_average(mean_distance)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
Example 14
  def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute the mean and variance: it overrides the original _moments."""
    shard_mean, shard_variance = super()._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = tpu_function.get_tpu_context().number_of_shards or 1
    num_shards_per_group = min(32, num_shards)  # aggregate up to 32 cores.
    logging.info('TpuBatchNormalization with num_shards_per_group {}'.format(
        num_shards_per_group))
    if num_shards_per_group > 1:
      # Compute variance using: Var[X]= E[X^2] - E[X]^2.
      shard_square_of_mean = tf.math.square(shard_mean)
      shard_mean_of_square = shard_variance + shard_square_of_mean
      group_mean = cross_replica_mean(shard_mean, num_shards_per_group)
      group_mean_of_square = cross_replica_mean(
          shard_mean_of_square, num_shards_per_group)
      group_variance = group_mean_of_square - tf.math.square(group_mean)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
Example 15
  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() from the real optimizer. The
    gradients will be aggregated in the apply_gradients() so that the user
    can modify the gradients beforehand, e.g. clip them with the per-replica
    global norm if needed. Clipping by the global norm of the aggregated
    gradients can behave poorly, as one replica's huge gradients can hurt the
    gradients from other replicas.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKey.TRAINABLE_VARIABLES`.
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.

    Raises:
      ValueError: If not within a tpu_shard_context or group_assignment is
        invalid.
    """
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      logging.warning(
          "CrossShardOptimizer should be used within a tpu_shard_context, but "
          "got unset number_of_shards. Assuming 1.")
      num_shards = 1

    subgroup_size = self._verify_and_get_subgroup_size(self._group_assignment,
                                                       num_shards)

    if num_shards > 1 and self._reduction == losses.Reduction.MEAN:
      if self._group_assignment:
        scale = 1.0 / subgroup_size
      else:
        scale = 1.0 / num_shards
      loss *= scale

    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
Example 16
def get_num_replicas():
  """Returns the number of replicas.

  If not operating in a supported replicated context this function will return
  1.
  """

  tf_replicator = get_tf_replicator()

  if tf_replicator:
    return tf_replicator.num_replicas_in_sync
  elif tf.distribute.has_strategy():
    return tf.distribute.get_strategy().num_replicas_in_sync
  else:
    # I'm assuming replicas and shards are always equal until someone tells me
    # different.
    num_replicas = tpu_function.get_tpu_context().number_of_shards
    if num_replicas:
      return num_replicas
    else:
      return 1
Example 17
def _is_running_on_cpu():
  """Returns True if the current context is CPU mode."""
  return tpu_function.get_tpu_context().number_of_shards is None
Example 18
 def _cross_replica_average(self, t):
   """Calculates the average value of input tensor across TPU replicas."""
   num_shards = tpu_function.get_tpu_context().number_of_shards
   return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype)
Example 19
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if 'bucket_keys' in batch:
                        # Hack: bucket_keys are not needed on TPU.
                        del batch['bucket_keys']
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                             tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        # For executor-driven multiple programs, we need more fine-grained
        # access rather than using a single global graph collection.
        self.tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
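The per-feature dense-to-sparse conversion above (tf.where followed by tf.gather_nd) can be pictured with NumPy; the padding id of -1 is the assumption already noted in the code comment:

import numpy as np

split = np.array([[3, 7, -1],
                  [-1, 5, -1]])
sample_indices = np.argwhere(split != -1)  # like tf.where(tf.not_equal(split, -1))
embedding_indices = split[split != -1]     # like tf.gather_nd(split, sample_indices)
print(sample_indices.tolist())             # [[0, 0], [0, 1], [1, 1]]
print(embedding_indices.tolist())          # [3, 7, 5]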
Example 20
def num_tpu_replicas():
    return tpu_function.get_tpu_context().number_of_shards
Example 21
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch."""
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info(
            'CreateTPUFeeds num_splits_per_client={} '
            'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.
            format(cluster.num_splits_per_client,
                   cluster.num_devices_per_split, num_tpu_hosts,
                   p.use_per_host_infeed))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            tf.logging.info('shards {}'.format(shards))

            input_ops_list = []
            queues = []
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            if num_tpu_hosts > 1 and tpu_embedding is not None:
                if not p.use_per_host_infeed:
                    tf.logging.fatal(
                        'TPU Embedding must be used with per_host_infeed with multiple '
                        'TPU host topologies.')
            tpu_emb_input_keys = (list(
                tpu_embedding.feature_to_config_dict.keys())
                                  if tpu_embedding is not None else [])
            tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

            batch = None
            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    if isinstance(batch, py_utils.NestedMap):
                        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
                        # Note that when MultiTaskData is used, bucket_keys will be at the
                        # second level of the dictionary.
                        batch = batch.FilterKeyVal(
                            lambda k, _: not k.endswith('bucket_keys'))
                    tf.logging.info('host_device: %s, batch: %r', host_device,
                                    batch)

                    if tpu_embedding is not None:
                        enqueue_dict_per_core = [
                            {} for _ in range(tpu_embedding.num_cores_per_host)
                        ]
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        for key in tpu_emb_input_keys:
                            feat = batch[key]
                            tpu_emb_feat_splitted = tf.split(
                                feat, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_emb_feat_splitted):
                                # Dense to sparse. Note the assumption of a padding id.
                                sample_indices = tf.where(
                                    tf.not_equal(split, -1))
                                embedding_indices = tf.gather_nd(
                                    split, sample_indices)
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    embedding_indices, sample_indices)
                                enqueue_dict_per_core[core][key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    for k, x in batch.FlattenItems():
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                    shapes = batch.Transform(lambda x: x.shape).Flatten()
                    dtypes = batch.Transform(lambda x: x.dtype).Flatten()
                    tf.logging.info('host_device: %s infeed shapes: %r',
                                    host_device, shapes)
                    tf.logging.info('host_device: %s infeed dtypes: %r',
                                    host_device, dtypes)
                    if p.use_partitioned_infeed_queue:
                        device_assignment = py_utils.GetTpuDeviceAssignment()

                        host_device = device_assignment.host_device(
                            replica=0, job=tf.flags.FLAGS.tf_master)
                        host_id = int(
                            host_device.split('/task:')[1].split('/device:')
                            [0])
                        tf.logging.info('host_id: {} host_device: {}'.format(
                            host_id, host_device))
                        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                            number_of_tuple_elements=len(dtypes),
                            device_assignment=device_assignment,
                            host_id=host_id,
                            input_partition_dims=[[p.num_partitions, 1]
                                                  for _ in dtypes],
                            tuple_types=dtypes,
                            tuple_shapes=shapes)
                    else:
                        q = tpu_feed.InfeedQueue(tuple_types=dtypes,
                                                 tuple_shapes=shapes)
                        assert shards is not None
                        q.set_number_of_shards(shards)

                    queues.append(q)
                    tf.logging.info('q=%r', q)

                    if p.use_partitioned_infeed_queue:
                        input_ops = q.generate_enqueue_ops([batch.Flatten()])
                    elif p.use_per_host_infeed:
                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def TPUOrdinalFunction(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=TPUOrdinalFunction)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            batch.Flatten(),
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallelism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        self._tpu_infeed_op = tpu_infeed_op

        with tf.device(tf.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return batch.Pack(tensors)
Example 22
  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors
Example 23
def _is_running_on_cpu():
    """Returns True if the current context is CPU mode."""
    return tpu_function.get_tpu_context().number_of_shards is None
Example 24
def num_tpu_shards():
    """Get the number of TPU shards."""
    return tpu_function.get_tpu_context().number_of_shards
Example 25
def standardize_batch(inputs,
                      is_training,
                      offset=None,
                      scale=None,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format='NHWC',
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
    """Adds TPU-enabled batch normalization layer.

  Details on Batch Normalization can be found in 'Batch Normalization:
  Accelerating Deep Network Training by Reducing Internal Covariate Shift',
  Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].

  Note #1: This method computes the batch statistics across all TPU replicas,
  thus simulating true batch norm in the distributed setting. To avoid the
  cross-replica communication, set use_cross_replica_mean=False.

  Note #2: When is_training is True the moving_mean and moving_variance need
  to be updated in each training step. By default, the update_ops are placed
  in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a dependency to
  the `train_op`. For example:

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if update_ops:
      updates = tf.group(*update_ops)
      total_loss = control_flow_ops.with_dependencies([updates], total_loss)

  Note #3: Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value (try
  `decay`=0.9) if the model experiences reasonably good training performance
  but poor validation and/or test performance.

  Args:
    inputs: A tensor with 2 or 4 dimensions, where the first dimension is
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC`, and the second dimension if `data_format` is
      `NCHW`.
    is_training: Whether or not the layer is in training mode. In training
      mode it accumulates the statistics of the moments into the
      `moving_mean` and `moving_variance` using an exponential moving average
      with the given `decay`. When is_training=False, these variables are not
      updated, and the precomputed values are used verbatim.
    offset: An offset `Tensor`, often denoted `beta` in equations, or
      None. If present, will be added to the normalized tensor.
    scale: A scale `Tensor`, often denoted `gamma` in equations, or
      `None`. If present, the scale is applied to the normalized tensor.
    decay: Decay for the moving averages. See notes above for reasonable
      values.
    epsilon: Small float added to variance to avoid dividing by zero.
    data_format: Input data format. NHWC or NCHW.
    use_moving_averages: If True, keep moving averages of mean and variance
      that are used during inference. Otherwise use accumulators.
    use_cross_replica_mean: If True, add operations that compute batch norm
      statistics across all TPU cores. These ops are not compatible with other
      platforms. The default (None) will only add the operations if running
      on TPU.

  Returns:
    The normalized tensor with the same type and shape as `inputs`.
  """
    if data_format not in {'NCHW', 'NHWC'}:
        raise ValueError(
            'Invalid data_format {}. Allowed: NCHW, NHWC.'.format(data_format))
    if use_cross_replica_mean is None:
        # Default to global batch norm only on TPUs.
        use_cross_replica_mean = (
            tpu_function.get_tpu_context().number_of_shards is not None)
        logging.debug('Automatically determined use_cross_replica_mean=%s.',
                      use_cross_replica_mean)

    inputs = tf.convert_to_tensor(value=inputs)
    inputs_dtype = inputs.dtype
    inputs_shape = inputs.get_shape()

    num_channels = tf.compat.dimension_value(inputs.shape[-1])
    if num_channels is None:
        raise ValueError('`C` dimension must be known but is None')

    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
        raise ValueError('Inputs %s has undefined rank' % inputs.name)
    elif inputs_rank not in [2, 4]:
        raise ValueError('Inputs %s has unsupported rank.'
                         ' Expected 2 or 4 but got %d' %
                         (inputs.name, inputs_rank))
    # Bring 2-D inputs into 4-D format.
    if inputs_rank == 2:
        new_shape = [-1, 1, 1, num_channels]
        if data_format == 'NCHW':
            new_shape = [-1, num_channels, 1, 1]
        inputs = tf.reshape(inputs, new_shape)
        if offset is not None:
            offset = tf.reshape(offset, new_shape)
        if scale is not None:
            scale = tf.reshape(scale, new_shape)

    # Execute a distributed batch normalization
    axis = 1 if data_format == 'NCHW' else 3
    inputs = tf.cast(inputs, tf.float32)
    reduction_axes = [i for i in range(4) if i != axis]
    if use_cross_replica_mean:
        mean, variance = cross_replica_moments(inputs, reduction_axes)
    else:
        counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
            inputs, reduction_axes, keepdims=False)
        mean, variance = tf.nn.normalize_moments(counts,
                                                 mean_ss,
                                                 variance_ss,
                                                 shift=None)

    if use_moving_averages:
        mean, variance = moving_moments_for_inference(mean=mean,
                                                      variance=variance,
                                                      is_training=is_training,
                                                      decay=decay)
    else:
        mean, variance = accumulated_moments_for_inference(
            mean=mean, variance=variance, is_training=is_training)

    outputs = tf.nn.batch_normalization(inputs,
                                        mean=mean,
                                        variance=variance,
                                        offset=offset,
                                        scale=scale,
                                        variance_epsilon=epsilon)
    outputs = tf.cast(outputs, inputs_dtype)
    # Bring 2-D inputs back into 2-D format.
    if inputs_rank == 2:
        outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
    outputs.set_shape(inputs_shape)
    return outputs
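The reduction axes computed in the function above depend only on the data format: batch norm reduces over every axis except the channel axis. Expanded in plain Python for both supported layouts:

for data_format in ('NHWC', 'NCHW'):
    axis = 1 if data_format == 'NCHW' else 3
    reduction_axes = [i for i in range(4) if i != axis]
    print(data_format, reduction_axes)  # NHWC [0, 1, 2], NCHW [0, 2, 3]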
Example 26
    def body_wrapper(*inputs):
        """Wrapper around `body` that handles infeed queues and control deps."""
        inputs = list(inputs)

        # Discards the dummy output added for arity-0 loops.
        if input_arity == 0:
            inputs = []

        # Runs `body` with the dequeue_ops appended.
        if infeed_queue:
            number_of_shards = tpu_function.get_tpu_context().number_of_shards
            if number_of_shards is None:
                raise ValueError(
                    "Can't build training loop with infeed when there is "
                    "no tpu_shard_context. Are you building a loop or "
                    "graph directly rather than from inside tpu.rewrite, "
                    "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
            infeed_queue.set_number_of_shards(number_of_shards)
            dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
        else:
            dequeue_ops = []
        outputs = body(*(inputs + dequeue_ops))

        # If the computation only returned one value, make it a tuple.
        if not isinstance(outputs, (list, tuple)):
            outputs = (outputs, )

        outputs = [
            o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
            for o in outputs
        ]

        # Separates the returned Operations and Tensors.
        output_operations = [
            o for o in outputs if isinstance(o, ops.Operation)
        ]
        output_tensors = [
            o for o in outputs if not isinstance(o, ops.Operation)
        ]

        if outputs != output_tensors + output_operations:
            raise ValueError(
                "TPU training loop body must return zero or more Tensor values "
                "followed by zero or more Operations.")

        output_types = [op.dtype for op in output_tensors]
        if input_types != output_types:
            raise TypeError(
                "Mismatch between input types and output types for training loop "
                "body: {} vs {}".format(input_types, output_types))

        # Add the dequeue operations to output_operations to ensure they are run
        # by the loop, even if the programmer's loop body does not use them.
        output_operations += dequeue_ops

        # Add a dummy output, if needed.
        if not output_tensors:
            output_tensors = array_ops.constant(0)

        if output_operations:
            # TODO(phawkins): in principle this is too restrictive since it serializes
            # the training loop steps. In practice it does not matter since this loop
            # will be compiled by XLA.
            output_tensors = control_flow_ops.tuple(
                output_tensors, control_inputs=output_operations)

        if tensor_tracer.TensorTracer.is_enabled():
            num_replicas = tpu_function.get_tpu_context().number_of_shards
            if num_replicas is None:
                num_replicas = 1
            tt = tensor_tracer.TensorTracer()
            output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                          output_tensors, None, num_replicas)
        return output_tensors