Exemple #1
0
def _rewire_summaries():
    """Rewire Tensorflow summaries to be no-ops when running on TPU.

    Summaries are not currently supported on TPU. While the context is active,
    the audio/histogram/image/scalar/tensor_summary/text functions on both the
    `summary` module and `tf.summary` are replaced by a function that ignores
    its arguments and returns a constant empty-string tensor.

    The originals of each module are captured separately, so leaving the
    context restores exactly what each module exposed beforehand even when
    `summary` and `tf.summary` do not share the same function objects.

    Yields:
      Context where summary functions are rewired to be no-ops when on TPU.
    """

    # NOTE(review): only a shard count that compares equal to 0 skips the
    # patching; `None == 0` is False, so an unset shard count still patches.
    # Confirm that this is the intended check.
    if tpu_function.get_tpu_context().number_of_shards == 0:
        yield
        return

    tf.logging.log_first_n(
        tf.logging.WARN,
        "Converting summaries to no-ops on TPU since they are not supported.",
        1)

    patched_names = (
        "audio",
        "histogram",
        "image",
        "scalar",
        "tensor_summary",
        "text",
    )
    # Capture every original before anything is patched.
    old_summary_fns = {name: getattr(summary, name) for name in patched_names}
    old_tf_summary_fns = {
        name: getattr(tf.summary, name) for name in patched_names
    }

    def _no_op(*args, **kwargs):
        del args, kwargs  # Unused
        return tf.constant("", name="summary_no_op")

    # Monkey-patch global attributes.
    for name in patched_names:
        setattr(summary, name, _no_op)
    for name in patched_names:
        setattr(tf.summary, name, _no_op)

    try:
        yield
    finally:
        # Revert monkey-patches, each module from its own saved originals.
        for name in patched_names:
            setattr(summary, name, old_summary_fns[name])
        for name in patched_names:
            setattr(tf.summary, name, old_tf_summary_fns[name])
Exemple #2
0
    def __init__(self, scope=None, skip_summary=False):
        """Sets up a `_ScopedSummary`.

        When the current TPU context reports a truthy shard count, a one-time
        warning is logged and `skip_summary` is forced to True.

        Args:
          scope: String scope name.
          skip_summary: Flag stored as `_skip_summary`; overridden to True
            when a TPU shard count is present.
        """
        on_tpu = bool(tpu_function.get_tpu_context().number_of_shards)
        if on_tpu:
            tf.logging.log_first_n(
                tf.logging.WARN,
                "Scoped summaries will be skipped since they do not support TPU",
                1)

        self._scope = scope
        self._additional_scope = None
        self._skip_summary = True if on_tpu else skip_summary
        self._summary_ops = []

        # Keep handles to the summary functions as they are at construction.
        self._actual_summary_scalar_fn = tf.summary.scalar
        self._actual_summary_image_fn = tf.summary.image
        self._actual_summary_histogram_fn = tf.summary.histogram
        self._actual_summary_audio_fn = tf.summary.audio
Exemple #3
0
    def compute_gradients(self, loss, var_list=None, **kwargs):
        """Delegates gradient computation to the wrapped optimizer.

        No cross-shard aggregation happens here; per the class design that is
        left to apply_gradients() so callers can still post-process the
        per-replica gradients (e.g. clip by per-replica global norm) first.
        The only adjustment made is scaling the loss by 1 / number_of_shards
        when more than one shard is in use and the reduction is
        `losses.Reduction.MEAN`.

        Args:
          loss: A Tensor containing the value to minimize.
          var_list: Optional list or tuple of `tf.Variable` to update to
            minimize `loss`. Passed through to the wrapped optimizer.
          **kwargs: Keyword arguments for compute_gradients().

        Returns:
          The result of the wrapped optimizer's compute_gradients(), i.e. a
          list of (gradient, variable) pairs.

        Raises:
          ValueError: If not within a tpu_shard_context.
        """
        shard_count = tpu_function.get_tpu_context().number_of_shards
        if shard_count is None:
            raise ValueError(
                "CrossShardOptimizer must be used within a tpu_shard_context.")

        if shard_count > 1:
            if self._reduction == losses.Reduction.MEAN:
                loss *= 1.0 / shard_count

        return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
Exemple #4
0
 def _Moments(self, inputs, group_size):
   """Computes mean and variance over N,H,W dimensions in inputs.

   The per-replica sufficient statistics (reduced over axes 0, 1 and 2) are
   first recorded in `self.accumulators`. When running on TPU with
   `group_size > 1`, the sums are additionally combined across replicas with
   `cross_replica_sum` before being normalized into a mean and variance.

   Args:
     inputs: Tensor to compute statistics for; reduced over axes [0, 1, 2].
     group_size: Number of consecutive replicas whose statistics are summed
       together. Values <= 1 disable cross-replica aggregation.

   Returns:
     A (mean, variance) tuple from `tf.nn.normalize_moments`.

   Raises:
     ValueError: If the TPU shard count is known and is smaller than, or not
       divisible by, `group_size`.
   """
   counts, mean_ss, variance_ss, _, = tf.nn.sufficient_statistics(
       inputs, axes=[0, 1, 2], keep_dims=False)
   self.accumulators.counts.Update(counts)
   self.accumulators.mean_ss.Update(mean_ss)
   self.accumulators.variance_ss.Update(variance_ss)
   # Distributed batch norm that computes sufficient statistics from group_size
   # replicas. This is useful when batch_size_per_replica is too small to
   # compute reliable sufficient statistics.
   if py_utils.use_tpu() and group_size > 1:
     # Stays None (and counts stays unscaled) when the shard count is unset.
     group_assignment = None
     num_shards = tpu_function.get_tpu_context().number_of_shards
     if num_shards is not None:
       if num_shards < group_size:
         raise ValueError('TPU shards={} less than bn_group_size={}.'.format(
             num_shards, group_size))
       if num_shards % group_size:
         raise ValueError(
             'TPU shards={} not divisible by bn_group_size={}.'.format(
                 num_shards, group_size))
       # Consecutive replica ids are grouped: [0..g-1], [g..2g-1], ...
       group_assignment = [
           list(range(start, start + group_size))
           for start in range(0, num_shards, group_size)
       ]
       # Each group sums group_size replicas, so the element count scales too.
       counts *= group_size
     mean_ss = tf.contrib.tpu.cross_replica_sum(mean_ss, group_assignment)
     variance_ss = tf.contrib.tpu.cross_replica_sum(variance_ss,
                                                    group_assignment)
   # At each micro-step, batch_mean and batch_variance are computed
   # to normalize inputs. But they are not used to update moving_mean and
   # moving_variance variables until the last micro batch.
   mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
   return mean, variance
Exemple #5
0
def is_tpu_replicated():
    """Reports whether the code is running replicated on TPU.

    True when a distribution strategy is active, a replica context exists and
    the strategy is a `TPUStrategy`, or when the TPU context has a shard count
    set (i.e. it is not None).
    """
    under_tpu_strategy = tf.distribute.has_strategy()
    under_tpu_strategy = (under_tpu_strategy
                          and tf.distribute.get_replica_context())
    under_tpu_strategy = under_tpu_strategy and isinstance(
        tf.distribute.get_strategy(), tf.distribute.experimental.TPUStrategy)

    shard_count = tpu_function.get_tpu_context().number_of_shards
    if under_tpu_strategy:
        return under_tpu_strategy
    return shard_count is not None
Exemple #6
0
    def _moments(self, inputs, reduction_axes, keep_dims):
        """Overrides `_moments` to pool statistics across groups of shards.

        The parent implementation supplies the per-shard mean and variance.
        With more than 8 shards, shards are pooled in groups of
        max(8, num_shards // 4) and the group mean/variance is returned;
        otherwise the per-shard values are returned unchanged.

        Args:
          inputs: Forwarded to the parent `_moments`.
          reduction_axes: Forwarded to the parent `_moments`.
          keep_dims: Forwarded to the parent `_moments` as `keep_dims`.

        Returns:
          A (mean, variance) tuple.
        """
        parent = super(TpuBatchNormalization, self)
        local_mean, local_variance = parent._moments(
            inputs, reduction_axes, keep_dims=keep_dims)

        shard_count = tpu_function.get_tpu_context().number_of_shards or 1
        # Up to 8 shards (noted upstream as "2x2 or smaller slices") do not
        # use cross-replica aggregation.
        group_width = 1 if shard_count <= 8 else max(8, shard_count // 4)
        tf.logging.info('TpuBatchNormalization with num_shards_per_group %s',
                        group_width)

        if group_width <= 1:
            return (local_mean, local_variance)

        # Average the per-replica statistics within each group.
        pooled_mean = self._cross_replica_average(local_mean, group_width)
        pooled_variance = self._cross_replica_average(local_variance,
                                                      group_width)

        # The pooled variance must also account for how far each shard's mean
        # sits from the pooled mean.
        mean_offset_sq = tf.square(pooled_mean - local_mean)
        pooled_variance += self._cross_replica_average(mean_offset_sq,
                                                       group_width)
        return (pooled_mean, pooled_variance)
  def compute_gradients(self, loss, var_list=None, **kwargs):
    """Forwards to the wrapped optimizer's compute_gradients().

    Aggregation across replicas is deferred to apply_gradients(), which lets
    callers modify per-replica gradients (for instance clipping by the
    per-replica global norm) beforehand. If there are several shards and the
    reduction is `losses.Reduction.MEAN`, the loss is pre-scaled by
    1 / number_of_shards.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Passed through to the wrapped optimizer.
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      The wrapped optimizer's result: a list of (gradient, variable) pairs.

    Raises:
      ValueError: If not within a tpu_shard_context.
    """
    tpu_context = tpu_function.get_tpu_context()
    replica_total = tpu_context.number_of_shards
    if replica_total is None:
      message = ("CrossShardOptimizer must be used within a "
                 "tpu_shard_context.")
      raise ValueError(message)

    needs_scaling = (replica_total > 1
                     and self._reduction == losses.Reduction.MEAN)
    if needs_scaling:
      loss *= 1.0 / replica_total

    return self._opt.compute_gradients(loss, var_list=var_list, **kwargs)
Exemple #8
0
    def _moments(self, inputs, reduction_axes, keep_dims):
        """Overrides `_moments` to pool statistics across groups of shards.

        Per-shard moments come from the parent class. With more than 8
        shards, groups of max(8, num_shards // 8) shards are combined using
        Var[X] = E[X^2] - E[X]^2; otherwise the per-shard moments are
        returned as they are.

        Args:
          inputs: Forwarded to the parent `_moments`.
          reduction_axes: Forwarded to the parent `_moments`.
          keep_dims: Forwarded to the parent `_moments` as `keep_dims`.

        Returns:
          A (mean, variance) tuple.
        """
        parent = super(TpuBatchNormalization, self)
        local_mean, local_variance = parent._moments(inputs,
                                                     reduction_axes,
                                                     keep_dims=keep_dims)

        shard_count = tpu_function.get_tpu_context().number_of_shards or 1
        # Up to 8 shards (noted upstream as "2x2 or smaller slices") do not
        # use cross-replica aggregation.
        group_width = 1 if shard_count <= 8 else max(8, shard_count // 8)
        tf.logging.info('TpuBatchNormalization with num_shards_per_group %s',
                        group_width)

        if group_width <= 1:
            return (local_mean, local_variance)

        # Recover E[X^2] on each shard from its variance and mean.
        local_mean_sq = tf.math.square(local_mean)
        local_second_moment = local_variance + local_mean_sq

        pooled_mean = self._cross_replica_average(local_mean, group_width)
        pooled_second_moment = self._cross_replica_average(
            local_second_moment, group_width)

        pooled_variance = pooled_second_moment - tf.math.square(pooled_mean)
        return (pooled_mean, pooled_variance)
Exemple #9
0
    def compute_gradients(self, loss, var_list=None, **kwargs):
        """Rescales the loss(es) and forwards to the wrapped optimizer.

        Adapted from:
        https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py#L100

        In multi mode `loss` holds one entry per optimizer (a bare value is
        wrapped in a one-element list) and each entry is rescaled on its own.
        Otherwise the single loss is rescaled directly. An unset shard count
        is treated as 1 after logging a warning.

        Args:
          loss: A single loss, or in multi mode a list with one loss per
            optimizer.
          var_list: Passed through to the wrapped optimizer.
          **kwargs: Keyword arguments for compute_gradients().

        Returns:
          Whatever the wrapped optimizer's compute_gradients() returns.
        """
        shard_count = tpu_function.get_tpu_context().number_of_shards
        if shard_count is None:
            logging.warning(
                "CrossShardMultiOptimizer should be used within a tpu_shard_context, but "
                "got unset number_of_shards. Assuming 1.")
            shard_count = 1

        group_width = self._verify_and_get_subgroup_size(
            self._group_assignment, shard_count)

        if not self._multi_mode:
            rescaled = self._rescale_loss(loss, shard_count, group_width)
        else:
            per_optimizer = loss if isinstance(loss, list) else [loss]
            rescaled = [
                self._rescale_loss(single_loss, shard_count, group_width)
                for single_loss in per_optimizer
            ]

        return self._opt.compute_gradients(rescaled,
                                           var_list=var_list,
                                           **kwargs)
Exemple #10
0
 def _create_default_group_assignment():
     """Returns one group that contains every shard id of the TPU context.

     If the shard count is unset, a warning is logged and a single shard is
     assumed, which yields `[[0]]`.
     """
     shard_count = tpu_function.get_tpu_context().number_of_shards
     if shard_count is not None:
         return [list(range(shard_count))]

     logging.warning(
         "cross_replica_sum should be used within a tpu_shard_context, but "
         "got unset number_of_shards. Assuming 1.")
     # One assumed shard: the only replica id is 0.
     return [[0]]
Exemple #11
0
def _replicated_optimizer(opt):
  """Returns `opt` wrapped for cross-shard use unless there is one shard.

  With exactly one shard the optimizer is returned untouched. Otherwise a
  `keras_optimizers.TFOptimizer` has its inner `optimizer` wrapped in
  `tpu_optimizer.CrossShardOptimizer`, and any other optimizer is wrapped in
  `KerasCrossShardOptimizer`.
  """
  single_shard = tpu_function.get_tpu_context().number_of_shards == 1
  if single_shard:
    return opt

  if not isinstance(opt, keras_optimizers.TFOptimizer):
    return KerasCrossShardOptimizer(opt)
  return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
Exemple #12
0
 def _create_default_group_assignment():
   """Builds the default assignment: a single group with all shard ids.

   An unset shard count triggers a warning and is treated as 1.
   """
   tpu_context = tpu_function.get_tpu_context()
   replica_total = tpu_context.number_of_shards
   if replica_total is None:
     logging.warning(
         "cross_replica_sum should be used within a tpu_shard_context, but "
         "got unset number_of_shards. Assuming 1.")
     replica_total = 1
   every_replica = list(range(replica_total))
   return [every_replica]
Exemple #13
0
def _replicated_optimizer(opt):
    """Wraps `opt` in a cross-shard optimizer when more than one shard is used.

    Returns `opt` unchanged when the TPU context reports exactly one shard.
    For any other shard count, a `keras_optimizers.TFOptimizer` is unwrapped
    and handed to `tpu_optimizer.CrossShardOptimizer`; every other optimizer
    goes through `KerasCrossShardOptimizer`.
    """
    shard_count = tpu_function.get_tpu_context().number_of_shards
    if shard_count == 1:
        return opt

    is_tf_optimizer = isinstance(opt, keras_optimizers.TFOptimizer)
    return (tpu_optimizer.CrossShardOptimizer(opt.optimizer)
            if is_tf_optimizer else KerasCrossShardOptimizer(opt))
Exemple #14
0
    def __init__(self,
                 logdir,
                 namespace=None,
                 scope=None,
                 skip_summary=False,
                 global_step=None):
        """Sets up a `_ScopedSummary`.

        The effective log directory is `logdir`, extended by `namespace` and
        then `scope` when those are non-empty. Summaries are marked lazy when
        the TPU context reports a truthy shard count.

        Args:
          logdir: String directory path for logging summaries. Must be truthy.
          namespace: String namespace to append to the logdir. Can be shared
            with other `_ScopedSummary` objects.
          scope: String scope name; must not equal `_DEFAULT_SCOPE`.
          skip_summary: Flag stored as `_skip_summary`.
          global_step: Global step `Tensor`. Defaults to
            `tf.train.get_global_step()` when None.

        Raises:
          ValueError: If `scope` equals `_DEFAULT_SCOPE`.
        """
        assert logdir

        if scope == _DEFAULT_SCOPE:
            raise ValueError("scope cannot be 'default'.")

        on_tpu = bool(tpu_function.get_tpu_context().number_of_shards)
        if on_tpu:
            tf.logging.log_first_n(
                tf.logging.INFO,
                "Summaries will be created lazily to work with TPU.", 1)
        self._lazy = on_tpu

        # Append the optional path components in order: namespace, then scope.
        for path_part in (namespace, scope):
            if path_part:
                logdir = os.path.join(logdir, path_part)

        self._logdir = logdir
        self._namespace = namespace
        self._scope = scope
        self._additional_scope = None
        self._skip_summary = skip_summary
        self._summary_ops = []

        contrib_summary = tf.contrib.summary
        self._actual_summary_scalar_fn = contrib_summary.scalar
        self._actual_summary_image_fn = contrib_summary.image
        self._actual_summary_histogram_fn = contrib_summary.histogram
        self._actual_summary_audio_fn = contrib_summary.audio

        self._global_step = (tf.train.get_global_step()
                             if global_step is None else global_step)
        self._lazy_summaries = []
        self._flush_op = {}
Exemple #15
0
 def _cross_replica_average(self, t, num_shards_per_group):
     """Averages tensor `t` over groups of TPU replicas.

     Args:
       t: Tensor to average.
       num_shards_per_group: Size of each averaging group. When greater than
         1, consecutive shard ids are grouped together; otherwise no explicit
         group assignment is passed to `cross_replica_sum`.

     Returns:
       The cross-replica sum of `t` divided by `num_shards_per_group` (cast
       to `t.dtype`).

     Raises:
       ValueError: If the shard count is not a multiple of
         `num_shards_per_group`.
     """
     shard_count = tpu_function.get_tpu_context().number_of_shards
     if num_shards_per_group > 1:
         if shard_count % num_shards_per_group != 0:
             raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
                              % (shard_count, num_shards_per_group))
         group_total = shard_count // num_shards_per_group
         # Group g owns ids [g * size, (g + 1) * size).
         groups = [
             list(range(g * num_shards_per_group,
                        (g + 1) * num_shards_per_group))
             for g in range(group_total)
         ]
     else:
         groups = None

     summed = tpu_ops.cross_replica_sum(t, groups)
     divisor = tf.cast(num_shards_per_group, t.dtype)
     return summed / divisor
Exemple #16
0
def cross_replica_mean(inputs, group_size=None):
  """Averages `inputs` across TPU replicas, optionally within groups.

  Args:
    inputs: Tensor to average.
    group_size: Number of consecutive replicas per averaging group. A falsy
      value means "all replicas". A value of 1 returns `inputs` unchanged.

  Returns:
    `inputs` itself when the group size is 1, otherwise the cross-replica sum
    divided by the group size (cast to `inputs.dtype`).
  """
  replica_total = get_tpu_context().number_of_shards
  group_size = group_size or replica_total
  if group_size == 1:
    return inputs

  if group_size == replica_total:
    # A single group spanning every replica needs no explicit assignment.
    groups = None
  else:
    assert replica_total % group_size == 0
    groups = [
        list(range(g * group_size, (g + 1) * group_size))
        for g in range(replica_total // group_size)
    ]

  summed = tf.contrib.tpu.cross_replica_sum(inputs, groups)
  return summed / tf.cast(group_size, inputs.dtype)
Exemple #17
0
def cross_replica_mean(tensor, name=None):
    """Computes the mean of a Tensor across all TPU cores.

    The tensor is divided by the shard count before the cross-replica sum.

    Args:
      tensor: Tensor to be synchronized.
      name: None or string. Name of Op.

    Returns:
      Average of Tensor across all TPU cores; `tensor` itself when there is
      only one shard.

    Raises:
      ValueError: If called outside of TPU context.
    """
    with ops.name_scope(name, "cross_replica_mean", [tensor]):
        shard_count = tpu_function.get_tpu_context().number_of_shards
        if shard_count is None:
            raise ValueError(
                "Cannot take cross_replica_mean() outside of TPU Context.")
        if shard_count != 1:
            per_shard_share = tensor / shard_count
            return tpu_ops.cross_replica_sum(per_shard_share)
        return tensor
Exemple #18
0
def cross_replica_mean(tensor, name=None):
  """Averages a Tensor over every TPU core.

  Args:
    tensor: Tensor to be synchronized.
    name: None or string. Name of Op.

  Returns:
    The input unchanged when only one shard exists; otherwise the
    cross-replica sum of `tensor / number_of_shards`.

  Raises:
    ValueError: If called outside of TPU context.
  """
  with ops.name_scope(name, "cross_replica_mean", [tensor]):
    tpu_context = tpu_function.get_tpu_context()
    replica_total = tpu_context.number_of_shards

    if replica_total is None:
      raise ValueError(
          "Cannot take cross_replica_mean() outside of TPU Context.")

    single_replica = replica_total == 1
    if single_replica:
      return tensor

    # Scale first so the summed result is already the mean.
    scaled = tensor / replica_total
    return tpu_ops.cross_replica_sum(scaled)
Exemple #19
0
def get_num_replicas():
    """Returns the number of replicas.

    Sources are consulted in order: the TF replicator, then the active
    `tf.distribute` strategy, then the TPU context's shard count. If none of
    them provides a (truthy) value, 1 is returned.
    """
    replicator = get_tf_replicator()
    if replicator:
        return replicator.num_replicas_in_sync

    if tf.distribute.has_strategy():
        return tf.distribute.get_strategy().num_replicas_in_sync

    # Replicas and shards are assumed to be the same quantity here.
    shard_count = tpu_function.get_tpu_context().number_of_shards
    return shard_count if shard_count else 1
Exemple #20
0
 def _Moments(self, inputs, group_size):
   """Computes mean and variance over N,H,W dimensions in inputs.

   Per-replica sufficient statistics (reduced over axes 0, 1 and 2) are
   recorded in `self.accumulators`. On TPU with `group_size > 1` the sums are
   then combined across groups of `group_size` consecutive replicas before
   being normalized.

   Args:
     inputs: Tensor to compute statistics for.
     group_size: Number of replicas per aggregation group.

   Returns:
     A (mean, variance) tuple from `tf.nn.normalize_moments`.
   """
   counts, mean_ss, variance_ss, _, = tf.nn.sufficient_statistics(
       inputs, axes=[0, 1, 2], keep_dims=False)

   accumulators = self.accumulators
   accumulators.counts.Update(counts)
   accumulators.mean_ss.Update(mean_ss)
   accumulators.variance_ss.Update(variance_ss)

   if py_utils.use_tpu() and group_size > 1:
     shard_count = tpu_function.get_tpu_context().number_of_shards
     assert shard_count >= group_size
     assert shard_count % group_size == 0
     # Group g owns replica ids [g * group_size, (g + 1) * group_size).
     groups = [
         [g * group_size + offset for offset in range(group_size)]
         for g in range(shard_count // group_size)
     ]
     counts *= group_size
     mean_ss = tf.contrib.tpu.cross_replica_sum(mean_ss, groups)
     variance_ss = tf.contrib.tpu.cross_replica_sum(variance_ss, groups)

   mean, variance = tf.nn.normalize_moments(counts, mean_ss, variance_ss, None)
   return mean, variance
    def cross_replica_sum(x, group_assignment=None, name=None):
        """Sums the input tensor across replicas according to group_assignment.

        Args:
          x: The local tensor to sum.
          group_assignment: Optional 2d int32 lists with shape [num_groups,
            num_replicas_per_group]. `group_assignment[i]` represents the
            replica ids in the ith subgroup. When None, a single group with
            every shard id of the current TPU context is used (one shard is
            assumed, with a warning, if the shard count is unset).
          name: Optional op name.

        Returns:
          A `Tensor` which is summed across replicas.
        """
        if group_assignment is not None:
            return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)

        shard_count = tpu_function.get_tpu_context().number_of_shards
        if shard_count is None:
            logging.warning(
                "cross_replica_sum should be used within a tpu_shard_context, but "
                "got unset number_of_shards. Assuming 1.")
            shard_count = 1
        all_shards = [list(range(shard_count))]
        return gen_tpu_ops.cross_replica_sum(x, all_shards, name=name)
Exemple #22
0
def num_tpu_shards():
    """Returns `number_of_shards` from the current TPU context, unmodified."""
    tpu_context = tpu_function.get_tpu_context()
    return tpu_context.number_of_shards
Exemple #23
0
    def CreateTpuFeeds(self):
        """Creates the TPU infeed queue from preprocessed batch.

        One `InfeedQueue` is built per infeed host: every TPU host when
        `p.use_per_host_infeed` is set, otherwise a single host. For each
        host the preprocessed batch is fetched, any TPU-embedding features
        are popped out and enqueued per core, and the remaining tensors are
        split and enqueued through the queue. All enqueue ops are grouped into
        one op, which is added to the `py_utils.ENQUEUE_OPS` collection
        `p.tpu_infeed_parallism` times.

        Returns:
          The tensors dequeued from the first queue, packed into the
          structure of the first host's batch via `first_batch.Pack`.
        """
        p = self.params
        cluster = self.cluster
        num_tpu_hosts = cluster.num_tpu_hosts
        num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
        tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
        tf.logging.info('num_devices_per_split {}'.format(
            cluster.num_devices_per_split))

        assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
        if (cluster.num_devices_per_split > num_cores_per_host
                and p.use_per_host_infeed):
            tf.logging.fatal(
                'Doesn\'t support per host infeed mode when '
                'num_devices_per_split({}) > num_cores_per_host({})'.format(
                    cluster.num_devices_per_split, num_cores_per_host))
        num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

        with py_utils.outside_all_rewrites():
            assert py_utils.use_tpu()
            assert not self._made_tpu_infeed

            # Each infeed host feeds an equal share of the shards.
            shards = tpu_function.get_tpu_context(
            ).number_of_shards // num_infeed_hosts
            input_ops_list = []
            queues = []
            first_batch = None
            tpu_embedding_collection = tf.get_collection(
                py_utils.TPU_EMBEDDING)
            tpu_embedding = (tpu_embedding_collection[0]
                             if tpu_embedding_collection else None)

            tpu_embedding_input_keys = (
                tpu_embedding.feature_to_config_dict.keys()
                if tpu_embedding is not None else [])

            for task_id in range(num_infeed_hosts):
                host_device = '/task:{}/device:CPU:0'.format(task_id)
                with tf.device(host_device):
                    batch = self.GetPreprocessedInputBatch()
                    # Embedding features are removed from the batch so they
                    # are not also sent through the regular infeed queue.
                    tpu_embedding_features = []
                    for tpu_embedding_input_key in tpu_embedding_input_keys:
                        tpu_embedding_feature = batch.pop(
                            tpu_embedding_input_key)
                        tpu_embedding_features.append(
                            (tpu_embedding_input_key, tpu_embedding_feature))

                    if first_batch is None:
                        first_batch = batch
                    flat_batch = batch.FlattenItems()

                    if tpu_embedding is not None:
                        num_cores_per_host = tpu_embedding.num_cores_per_host
                        # One independent dict per core. `[{}] * n` would
                        # alias a single dict n times, so every core would
                        # end up holding the last split written.
                        enqueue_dict_per_core = [
                            {} for _ in range(num_cores_per_host)
                        ]
                        for tpu_embedding_input_key, tpu_embedding_feature in tpu_embedding_features:
                            tpu_embedding_feature_splitted = tf.split(
                                tpu_embedding_feature, num_cores_per_host)
                            for core, split in enumerate(
                                    tpu_embedding_feature_splitted):
                                enqueue_data = tpu_embedding_lib.EnqueueData(
                                    tf.squeeze(split, axis=[1]))
                                enqueue_dict_per_core[core][
                                    tpu_embedding_input_key] = enqueue_data
                        input_ops_list += tpu_embedding.generate_enqueue_ops(
                            enqueue_dict_per_core)

                    shapes, types = [], []
                    for k, x in flat_batch:
                        assert x.shape.is_fully_defined(), (
                            'Shape must be fully defined: %s: %s' % (k, x))
                        # TODO(cwhipkey): if it's a string (or other type not supported on
                        # TPU), drop it from feeding and on the other end add in an op that
                        # fails if used.
                        shapes.append(x.shape)
                        types.append(x.dtype)
                    q = tf.contrib.tpu.InfeedQueue(tuple_types=types,
                                                   tuple_shapes=shapes)
                    queues.append(q)
                    assert shards is not None
                    q.set_number_of_shards(shards)

                    if p.use_per_host_infeed:

                        # TODO(ylc/zhifengc): Add this to a policy module and test it.
                        def _tpu_ordinal_function(shard_index_in_host):
                            device_assignment = py_utils.GetTpuDeviceAssignment(
                            )
                            if device_assignment:
                                # We put both enqueue/dequeue ops at core 0 in each replica.
                                replica = device_assignment.lookup_replicas(
                                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                                return device_assignment.tpu_ordinal(
                                    replica=replica)
                            else:
                                return shard_index_in_host

                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                            tpu_ordinal_function=_tpu_ordinal_function)
                    else:
                        input_ops = q.split_inputs_and_generate_enqueue_ops(
                            [v for _, v in flat_batch],
                            device_assignment=py_utils.GetTpuDeviceAssignment(
                            ))

                    input_ops_list += input_ops
            tf.logging.info('input_ops_list %s', input_ops_list)
            tpu_infeed_op = tf.group(*input_ops_list)
        self._made_tpu_infeed = True
        # Let trainer.py use multiple threads to drive the infeed op.
        for _ in range(p.tpu_infeed_parallism):
            tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

        with tf.device(tf.compat.v1.tpu.core(0)):
            tensors = queues[0].generate_dequeue_op()
        return first_batch.Pack(tensors)
Exemple #24
0
def _is_running_on_cpu():
  """Returns True when the TPU context has no shard count set (CPU mode)."""
  shard_count = tpu_function.get_tpu_context().number_of_shards
  return shard_count is None
  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps.

    Relies on `body`, `input_arity`, `input_types` and `infeed_queue`, which
    are not defined here and are taken from the surrounding scope.

    Args:
      *inputs: Loop-carried values passed to `body`. Discarded when
        `input_arity` is 0.

    Returns:
      The output tensors of `body`, or a dummy constant when there are none.
      When `body` returned Operations, or dequeue ops were generated, the
      result comes from `control_flow_ops.tuple` with those ops as control
      inputs.

    Raises:
      ValueError: If an infeed queue is used without a shard count in the TPU
        context, or if `body` returns Operations before Tensors.
      TypeError: If the dtypes of the returned tensors differ from
        `input_types`.
    """
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Operations are kept as they are; everything else becomes a Tensor.
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # The ordering check: re-joining the two partitions must reproduce the
    # original list, which only holds when all Tensors precede all Operations.
    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    # Note: this rebinds output_tensors to a single tensor, not a list.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      return control_flow_ops.tuple(output_tensors,
                                    control_inputs=output_operations)
    else:
      return output_tensors
Exemple #26
0
def standardize_batch(inputs,
                      is_training,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format="NHWC",
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
    """Adds TPU-enabled batch normalization layer.

    This version does not apply trainable scale or offset!
    It normalizes a tensor by mean and variance.

    Details on Batch Normalization can be found in "Batch Normalization:
    Accelerating Deep Network Training by Reducing Internal Covariate Shift",
    Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].

    Note #1: This method computes the batch statistic across all TPU
    replicas, thus simulating the true batch norm in the distributed setting.
    To avoid the cross-replica communication, set
    use_cross_replica_mean=False.

    Note #2: When is_training is True the moving_mean and moving_variance
    need to be updated in each training step. By default, the update_ops are
    placed in `tf.GraphKeys.UPDATE_OPS` and they need to be added as a
    dependency to the `train_op`. For example:

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if update_ops:
          updates = tf.group(*update_ops)
          total_loss = control_flow_ops.with_dependencies([updates],
                                                          total_loss)

    Note #3: Reasonable values for `decay` are close to 1.0, typically in the
    multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value
    (trying `decay`=0.9) if the model experiences reasonably good training
    performance but poor validation and/or test performance.

    Args:
      inputs: A tensor with 2 or 4 dimensions, where the first dimension is
        `batch_size`. The normalization is over all but the last dimension
        if `data_format` is `NHWC`, and over all but the second dimension if
        `data_format` is `NCHW`. The channel dimension must be statically
        known.
      is_training: Whether or not the layer is in training mode. It is
        forwarded to the helper that produces the moments used for
        normalization (moving averages or accumulators).
      decay: Decay for the moving averages. See notes above for reasonable
        values. Only used when `use_moving_averages` is True.
      epsilon: Small float added to variance to avoid dividing by zero.
      data_format: Input data format. NHWC or NCHW.
      use_moving_averages: If True keep moving averages of mean and variance
        that are used during inference. Otherwise use accumulators.
      use_cross_replica_mean: If True add operations that compute the batch
        norm statistics across all TPU cores. These ops are not compatible
        with other platforms. The default (None) will only add the
        operations if running on TPU.

    Returns:
      The normalized tensor with the same type and shape as `inputs`.

    Raises:
      ValueError: If `data_format` is not NCHW/NHWC, if the rank of `inputs`
        is unknown or not 2 or 4, or if the channel dimension is unknown.
    """
    if data_format not in {"NCHW", "NHWC"}:
        raise ValueError(
            "Invalid data_format {}. Allowed: NCHW, NHWC.".format(data_format))
    if use_cross_replica_mean is None:
        # Default to global batch norm only on TPUs.
        use_cross_replica_mean = (
            tpu_function.get_tpu_context().number_of_shards is not None)
        logging.debug("Automatically determined use_cross_replica_mean=%s.",
                      use_cross_replica_mean)

    inputs = tf.convert_to_tensor(inputs)
    inputs_dtype = inputs.dtype
    inputs_shape = inputs.get_shape()

    # Validate the rank first: the position of the channel dimension depends
    # on it, and an unknown rank deserves its own error message.
    inputs_rank = inputs_shape.ndims
    if inputs_rank is None:
        raise ValueError("Inputs %s has undefined rank" % inputs.name)
    if inputs_rank not in (2, 4):
        raise ValueError("Inputs %s has unsupported rank."
                         " Expected 2 or 4 but got %d" %
                         (inputs.name, inputs_rank))

    # The channel dimension is the second one for NCHW and the last one for
    # NHWC (for 2-D inputs both coincide). Checking the last dimension
    # unconditionally would validate the width of 4-D NCHW inputs instead.
    channel_axis = 1 if data_format == "NCHW" else inputs_rank - 1
    num_channels = inputs_shape[channel_axis].value
    if num_channels is None:
        raise ValueError("`C` dimension must be known but is None")

    # Bring 2-D inputs into 4-D format.
    if inputs_rank == 2:
        new_shape = [-1, 1, 1, num_channels]
        if data_format == "NCHW":
            new_shape = [-1, num_channels, 1, 1]
        inputs = tf.reshape(inputs, new_shape)

    # Execute a distributed batch normalization. Statistics are computed in
    # float32 and reduced over every axis except the channel axis.
    axis = 1 if data_format == "NCHW" else 3
    inputs = tf.cast(inputs, tf.float32)
    reduction_axes = [i for i in range(4) if i != axis]
    if use_cross_replica_mean:
        mean, variance = tpu_ops.cross_replica_moments(inputs, reduction_axes)
    else:
        counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
            inputs, reduction_axes, keep_dims=False)
        mean, variance = tf.nn.normalize_moments(counts,
                                                 mean_ss,
                                                 variance_ss,
                                                 shift=None)

    if use_moving_averages:
        mean, variance = _moving_moments_for_inference(mean=mean,
                                                       variance=variance,
                                                       is_training=is_training,
                                                       decay=decay)
    else:
        mean, variance = _accumulated_moments_for_inference(
            mean=mean, variance=variance, is_training=is_training)

    outputs = tf.nn.batch_normalization(inputs,
                                        mean=mean,
                                        variance=variance,
                                        offset=None,
                                        scale=None,
                                        variance_epsilon=epsilon)
    outputs = tf.cast(outputs, inputs_dtype)

    # Bring 2-D inputs back into 2-D format.
    if inputs_rank == 2:
        outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
    outputs.set_shape(inputs_shape)
    return outputs
def BuildOpt(hparams):
  """Constructs the optimizer.

  The learning rate follows a staircase exponential decay that drops by
  `learning_rate_decay_factor` every `num_epochs_per_decay` epochs, where an
  epoch is `examples_per_epoch / batch_size` global steps.

  Args:
    hparams: An instance of tf.HParams, with these parameters:
    - batch_size
    - examples_per_epoch
    - learning_rate
    - learning_rate_decay_factor
    - model_weights_averaging
    - momentum
    - num_epochs_per_decay
    - optimizer
    - rmsprop_decay
    - rmsprop_epsilon
    - use_avg_model_params
    and, optionally (read with `hparams.get`):
    - adam_beta2 (defaults to 0.999)
    - adam_epsilon (defaults to 1e-8)

  Returns:
    opt: The optimizer. `optimizer` values 'momentum' and 'rmsprop' select
      the corresponding tf.train optimizer; any other value selects Adam,
      with `momentum` used as beta1. When `use_avg_model_params` is set the
      optimizer is wrapped in a MovingAverageOptimizer.
  """
  logging.info('Hyperparameters: %s', hparams)
  batch_size = hparams.batch_size
  examples_per_epoch = hparams.examples_per_epoch
  learning_rate_decay_factor = hparams.learning_rate_decay_factor
  learning_rate = hparams.learning_rate
  model_weights_averaging = hparams.model_weights_averaging
  momentum = hparams.momentum
  num_epochs_per_decay = hparams.num_epochs_per_decay
  optimizer = hparams.optimizer
  rmsprop_decay = hparams.rmsprop_decay
  rmsprop_epsilon = hparams.rmsprop_epsilon
  adam_beta2 = hparams.get('adam_beta2', 0.999)
  adam_epsilon = hparams.get('adam_epsilon', 1e-8)
  use_avg_model_params = hparams.use_avg_model_params

  global_step = tf.train.get_or_create_global_step()

  # Configure the learning rate using an exponential decay.
  decay_steps = int(examples_per_epoch / batch_size *
                    num_epochs_per_decay)

  learning_rate = tf.train.exponential_decay(
      learning_rate,
      global_step,
      decay_steps,
      learning_rate_decay_factor,
      staircase=True)
  # Only emit the summary when not building a TPU computation.
  # get_tpu_context() always returns a context object, so its truthiness says
  # nothing about being on TPU; the shard count is None outside of one.
  if tpu_function.get_tpu_context().number_of_shards is None:
    tf.summary.scalar('Learning Rate', learning_rate)

  if optimizer == 'momentum':
    opt = tf.train.MomentumOptimizer(learning_rate, momentum)
  elif optimizer == 'rmsprop':
    opt = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=rmsprop_decay,
        momentum=momentum,
        epsilon=rmsprop_epsilon)
  else:
    # Any other optimizer name falls back to Adam.
    opt = tf.train.AdamOptimizer(
        learning_rate,
        beta1=momentum,
        beta2=adam_beta2,
        epsilon=adam_epsilon)

  if use_avg_model_params:
    # Callers of BuildOpt() with use_avg_model_params=True expect the
    # MovingAverageOptimizer to be the last optimizer returned by this function
    # so that the swapping_saver can be constructed from it.
    return contrib_opt.MovingAverageOptimizer(
        opt, average_decay=model_weights_averaging)

  return opt
Exemple #28
0
  def body_wrapper(*inputs):
    """Calls `body` with infeed dequeue ops appended and orders its outputs.

    The values returned by `body` are validated (tensors first, then
    operations; tensor dtypes must equal `input_types`), and any returned
    operations plus the dequeue ops become control inputs of the result.
    """
    # Arity-0 loops carry a dummy value; `body` must not receive it.
    loop_args = [] if input_arity == 0 else list(inputs)

    # Build the dequeue ops that get appended to the arguments of `body`.
    dequeue_ops = []
    if infeed_queue:
      shard_count = tpu_function.get_tpu_context().number_of_shards
      if shard_count is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(shard_count)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())

    outputs = body(*(loop_args + dequeue_ops))

    # A single returned value is treated as a one-element tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    # Convert non-Operation outputs to tensors (in order) while splitting the
    # outputs into tensors and operations.
    converted = []
    output_tensors = []
    output_operations = []
    for item in outputs:
      if isinstance(item, ops.Operation):
        output_operations.append(item)
      else:
        item = ops.convert_to_tensor(item)
        output_tensors.append(item)
      converted.append(item)

    # Tensors must all come before operations in what `body` returned.
    if converted != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [tensor.dtype for tensor in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # The dequeue ops must run on every iteration, even when the user's loop
    # body ignores them, so they are treated like returned operations.
    output_operations.extend(dequeue_ops)

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if not output_operations:
      return output_tensors

    # TODO(phawkins): in principle this is too restrictive since it serializes
    # the training loop steps. In practice it does not matter since this loop
    # will be compiled by XLA.
    return control_flow_ops.tuple(output_tensors,
                                  control_inputs=output_operations)
Exemple #29
0
def on_tpu():
  """Reports whether the current TPU context has a shard count set.

  Returns:
    True if `number_of_shards` of the current TPU context is not None.
  """
  shard_count = tpu_function.get_tpu_context().number_of_shards
  return shard_count is not None
Exemple #30
0
def on_tpu():
    """Reports whether the current TPU context has a shard count set.

    Returns:
      True if `number_of_shards` of the current TPU context is not None.
    """
    shard_count = tpu_function.get_tpu_context().number_of_shards
    return shard_count is not None
Exemple #31
0
 def get_gradients(self, loss, params):
     """Returns the parent optimizer's gradients averaged over TPU shards.

     Each gradient is summed across replicas with `cross_replica_sum` and
     divided by the shard count of the current TPU context.
     """
     shard_count = tpu_function.get_tpu_context().number_of_shards
     parent = super(KerasCrossShardOptimizer, self)
     local_grads = parent.get_gradients(loss, params)
     return [
         tpu_ops.cross_replica_sum(local_grad) / shard_count
         for local_grad in local_grads
     ]
Exemple #32
0
def standardize_batch(inputs,
                      decay=0.999,
                      epsilon=1e-3,
                      data_format="NHWC",
                      use_moving_averages=True,
                      use_cross_replica_mean=None):
  """Adds TPU-enabled batch normalization layer.

  This version does not apply trainable scale or offset!
  It normalizes a tensor by mean and variance.

  Details on Batch Normalization can be found in "Batch Normalization:
  Accelerating Deep Network Training by Reducing Internal Covariate Shift",
  Ioffe S. and Szegedy C. 2015 [http://arxiv.org/abs/1502.03167].

  Note #1: This method computes the batch statistic across all TPU replicas,
  thus simulating the true batch norm in the distributed setting. To avoid
  the cross-replica communication, set use_cross_replica_mean=False.

  Note #2: During training this will estimate the mean and variance using the
  current batch. For inference there are two options:
  a) Keep moving averages of the mean and variance during training by
     setting use_moving_averages=True.
  b) Set use_moving_averages=False to create accumulators that will have to be
     filled with values for mean and variance after training. This can be done
     by doing forward passes and recording the mean/variance vectors.
  In both cases the inference behavior is activated when the current
  `NormModes`, as returned by `get_norm_modes()`, sets update_bn_stats=False.

  Note #3: Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.99, 0.9, etc. Lower the `decay` value (trying
  `decay`=0.9) if the model experiences reasonably good training performance
  but poor validation and/or test performance.

  Args:
    inputs: A tensor with 2 or 4 dimensions, where the first dimension is
      `batch_size`. The normalization is over all but the last dimension if
      `data_format` is `NHWC`, and over all but the second dimension if
      `data_format` is `NCHW`. The channel dimension must be statically
      known.
    decay: Decay rate to use for moving averages. Only used when
      `use_moving_averages` is True.
    epsilon: Small float added to variance to avoid dividing by zero.
    data_format: Input data format. NHWC or NCHW.
    use_moving_averages: If True keep moving averages of mean and variance that
      are used during inference. Otherwise use accumulators.
    use_cross_replica_mean: If True add operations that compute the batch norm
      statistics across all TPU cores. These ops are not compatible with other
      platforms. The default (None) will only add the operations if running
      on TPU.

  Returns:
    The normalized tensor with the same type and shape as `inputs`.

  Raises:
    ValueError: If `data_format` is not NCHW/NHWC, if the rank of `inputs` is
      unknown or not 2 or 4, or if the channel dimension is unknown.
  """
  if data_format not in {"NCHW", "NHWC"}:
    raise ValueError(
        "Invalid data_format {}. Allowed: NCHW, NHWC.".format(data_format))
  if use_cross_replica_mean is None:
    # Default to global batch norm only on TPUs.
    use_cross_replica_mean = (get_tpu_context().number_of_shards is not None)

  inputs = tf.convert_to_tensor(inputs)
  inputs_dtype = inputs.dtype
  inputs_shape = inputs.get_shape()

  # Validate the rank first: the position of the channel dimension depends on
  # it, and an unknown rank deserves its own error message.
  inputs_rank = inputs_shape.ndims
  if inputs_rank is None:
    raise ValueError("Inputs %s has undefined rank" % inputs.name)
  if inputs_rank not in (2, 4):
    raise ValueError(
        "Inputs %s has unsupported rank."
        " Expected 2 or 4 but got %d" % (inputs.name, inputs_rank))

  # The channel dimension is the second one for NCHW and the last one for
  # NHWC (for 2-D inputs both coincide). Checking the last dimension
  # unconditionally would validate the width of 4-D NCHW inputs instead.
  channel_axis = 1 if data_format == "NCHW" else inputs_rank - 1
  num_channels = inputs_shape[channel_axis].value
  if num_channels is None:
    raise ValueError("`C` dimension must be known but is None")

  # Bring 2-D inputs into 4-D format.
  if inputs_rank == 2:
    new_shape = [-1, 1, 1, num_channels]
    if data_format == "NCHW":
      new_shape = [-1, num_channels, 1, 1]
    inputs = tf.reshape(inputs, new_shape)

  # Execute a distributed batch normalization. Statistics are computed in
  # float32 and reduced over every axis except the channel axis.
  axis = 1 if data_format == "NCHW" else 3
  inputs = tf.cast(inputs, tf.float32)
  reduction_axes = [i for i in range(4) if i != axis]
  if use_cross_replica_mean:
    mean, variance = cross_replica_moments(inputs, reduction_axes)
  else:
    counts, mean_ss, variance_ss, _ = tf.nn.sufficient_statistics(
        inputs, reduction_axes, keep_dims=False)
    mean, variance = tf.nn.normalize_moments(
        counts, mean_ss, variance_ss, shift=None)

  if use_moving_averages:
    mean, variance = _moving_means_of_moments_for_inference(
        mean=mean, variance=variance, decay=decay)
  else:
    mean, variance = _accumulated_moments_for_inference(
        mean=mean, variance=variance)

  outputs = tf.nn.batch_normalization(
      inputs,
      mean=mean,
      variance=variance,
      offset=None,
      scale=None,
      variance_epsilon=epsilon)
  outputs = tf.cast(outputs, inputs_dtype)

  # Bring 2-D inputs back into 2-D format.
  if inputs_rank == 2:
    outputs = tf.reshape(outputs, [-1] + inputs_shape[1:].as_list())
  outputs.set_shape(inputs_shape)
  return outputs
Exemple #33
0
 def get_gradients(self, loss, params):
   """Returns the parent optimizer's gradients averaged over TPU shards.

   Each gradient is summed across replicas with `cross_replica_sum` and
   divided by the shard count of the current TPU context.
   """
   shard_count = tpu_function.get_tpu_context().number_of_shards
   parent = super(KerasCrossShardOptimizer, self)
   local_grads = parent.get_gradients(loss, params)
   return [
       tpu_ops.cross_replica_sum(local_grad) / shard_count
       for local_grad in local_grads
   ]
Exemple #34
0
def get_num_tpu_shards():
    """Returns `number_of_shards` of the current TPU context."""
    tpu_context = tpu_function.get_tpu_context()
    return tpu_context.number_of_shards
Exemple #35
0
 def _fn(*args, **kwargs):
     """Calls `fn` unless the TPU context has a truthy shard count.

     Returns:
       None when `number_of_shards` is truthy, else `fn(*args, **kwargs)`.
     """
     shard_count = tpu_function.get_tpu_context().number_of_shards
     return None if shard_count else fn(*args, **kwargs)
Exemple #36
0
  def CreateTpuFeeds(self):
    """Creates the TPU infeed queue from preprocessed batch.

    One `InfeedQueue` is built per infeed host (every TPU host when
    `params.use_per_host_infeed` is set, otherwise a single host). The enqueue
    ops of all queues are grouped into one op that is added to the
    `py_utils.ENQUEUE_OPS` collection `params.tpu_infeed_parallism` times.
    Sets `self._made_tpu_infeed` to True.

    Returns:
      The first host's preprocessed batch, re-packed with the tensors dequeued
      from the first infeed queue.
    """
    p = self.params
    cluster = cluster_factory.Current()
    num_tpu_hosts = cluster.num_tpu_hosts
    assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
    num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

    with py_utils.outside_all_rewrites():
      assert py_utils.use_tpu()
      # The infeed may only be created once per instance.
      assert not self._made_tpu_infeed

      # Number of shards each infeed host is responsible for.
      shards = tpu_function.get_tpu_context(
      ).number_of_shards // num_infeed_hosts
      input_ops_list = []
      queues = []
      first_batch = None
      for task_id in range(num_infeed_hosts):
        # Build the input batch and its queue on the CPU of this host.
        host_device = '/task:{}/device:CPU:0'.format(task_id)
        with tf.device(host_device):
          batch = self.GetPreprocessedInputBatch()
          if first_batch is None:
            first_batch = batch
          flat_batch = batch.FlattenItems()

          # The infeed queue needs the static shape and dtype of every tensor.
          shapes, types = [], []
          for k, x in flat_batch:
            assert x.shape.is_fully_defined(), (
                'Shape must be fully defined: %s: %s' % (k, x))
            # TODO(cwhipkey): if it's a string (or other type not supported on
            # TPU), drop it from feeding and on the other end add in an op that
            # fails if used.
            shapes.append(x.shape)
            types.append(x.dtype)
          q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
          queues.append(q)
          # NOTE(review): a None shard count would already have failed at the
          # `//` above, so this assert cannot trigger for that case.
          assert shards is not None
          q.set_number_of_shards(shards)

          if p.use_per_host_infeed:

            # TODO(ylc/zhifengc): Add this to a policy module and test it.
            def _tpu_ordinal_function(shard_index_in_host):
              # Maps a shard index on this host to a TPU ordinal, using the
              # device assignment when one is available.
              device_assignment = py_utils.GetTpuDeviceAssignment()
              if device_assignment:
                # We put both enqueue/dequeue ops at core 0 in each replica.
                replica = device_assignment.lookup_replicas(
                    task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
                return device_assignment.tpu_ordinal(replica=replica)
              else:
                return shard_index_in_host

            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
                tpu_ordinal_function=_tpu_ordinal_function)
          else:
            input_ops = q.split_inputs_and_generate_enqueue_ops(
                [v for _, v in flat_batch],
                device_assignment=py_utils.GetTpuDeviceAssignment())

          input_ops_list += input_ops
      tf.logging.info('input_ops_list %s', input_ops_list)
      # A single op that runs the enqueue ops of every host.
      tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    # Dequeue from the first queue only, placed on TPU core 0.
    with tf.device(tf.contrib.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
    return first_batch.Pack(tensors)
Exemple #37
0
import sys
import os
import argparse
import json
import re

import tensorflow as tf
from tensorflow.contrib import tpu
from tensorflow.contrib.cluster_resolver import TPUClusterResolver

# Get the TPU's location
tpu_cluster = TPUClusterResolver(tpu=['albert2']).get_master()
import numpy as np

from tensorflow.contrib.tpu.python.tpu import tpu_function
tpu_function.get_tpu_context().set_number_of_shards(1)

from train.modeling import GroverModel, GroverConfig, sample
from tokenization import tokenization

##### ignore tf deprecated warning temporarily
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
# tf.logging.set_verbosity(tf.logging.DEBUG)
# from tensorflow.python.util import deprecation
# deprecation._PRINT_DEPRECATION_WARNINGS = False
# try:
#     from tensorflow.python.util import module_wrapper as deprecation
# except ImportError:
#     from tensorflow.python.util import deprecation_wrapper as deprecation
# deprecation._PER_MODULE_WARNING_LIMIT = 0