Code Example #1
File: batch_norm.py  Project: lloydjie1/tensorpack
def get_sync_bn_mean_var(inputs, red_axis, sync_statistics):
    ctx = get_current_tower_context()
    batch_mean = tf.reduce_mean(inputs, axis=red_axis)
    batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

    TF_version = get_tf_version_tuple()

    if sync_statistics == 'nccl':
        num_dev = ctx.total
        if num_dev == 1:
            logger.warn(
                "BatchNorm(sync_statistics='nccl') is used with only one tower!"
            )
        else:
            assert TF_version >= (1, 10), \
                "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
                "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"

            if TF_version <= (1, 12):
                try:
                    from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so  # deprecated
                except Exception:
                    pass
                else:
                    _validate_and_load_nccl_so()
                from tensorflow.contrib.nccl.ops import gen_nccl_ops  # deprecated
            else:
                from tensorflow.python.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '',
                                 tf.get_variable_scope().name)
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                  num_dev)
    elif sync_statistics == 'horovod':
        # Require https://github.com/uber/horovod/pull/331
        import horovod.tensorflow as hvd
        if hvd.size() == 1:
            logger.warn(
                "BatchNorm(sync_statistics='horovod') is used with only one process!"
            )
        else:
            import horovod
            hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
            assert hvd_version >= (
                0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"

            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
    batch_var = batch_mean_square - tf.square(batch_mean)
    return batch_mean, batch_var
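Every sync-BN variant in this listing derives the variance from the two all-reduced moments via Var[x] = E[x^2] - (E[x])^2. A minimal standalone NumPy sketch (not part of tensorpack; shapes are illustrative) confirms that identity per channel:

import numpy as np

x = np.random.randn(32, 8, 8, 16)                       # toy NHWC batch
red_axis = (0, 1, 2)                                     # reduce over N, H, W

mean = x.mean(axis=red_axis)
mean_square = np.square(x).mean(axis=red_axis)
var_from_moments = mean_square - np.square(mean)         # E[x^2] - (E[x])^2

np.testing.assert_allclose(var_from_moments, x.var(axis=red_axis), rtol=1e-10)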
Code Example #2
def _all_sum_grad(op, grad):
    """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the output of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
    if op.get_attr('reduction') != 'sum':
        raise LookupError('No gradient defined for NcclAllReduce except sum.')

    _check_device(grad, expected=op.device)
    num_devices = op.get_attr('num_devices')
    shared_name = op.get_attr('shared_name') + '_grad'

    with ops.device(op.device):
        return gen_nccl_ops.nccl_all_reduce(input=grad,
                                            reduction='sum',
                                            num_devices=num_devices,
                                            shared_name=shared_name)
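In the TensorFlow source, a gradient function like the one above is attached to its op with `ops.RegisterGradient`. A hedged sketch of that registration pattern, using a hypothetical op type name so it does not clash with the real `NcclAllReduce` registration:

from tensorflow.python.framework import ops

@ops.RegisterGradient('MyAllReduceSum')        # hypothetical op type name
def _my_all_reduce_sum_grad(op, grad):
    # The gradient of a sum all-reduce is itself a sum all-reduce of the
    # incoming gradient, which is exactly what _all_sum_grad above returns.
    return _all_sum_grad(op, grad)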
Code Example #3
File: nccl_ops.py  Project: AnishShah/tensorflow
def _all_sum_grad(op, grad):
  """The gradients for `all_sum`.

  Args:
    op: The `all_sum` `Operation` that we are differentiating.
    grad: Gradient with respect to the output of the `all_sum` op.

  Returns:
    The gradient with respect to the output of `all_sum`.

  Raises:
    LookupError: If `reduction` is not `sum`.
  """
  if op.get_attr('reduction') != b'sum':
    raise LookupError('No gradient defined for NcclAllReduce except sum.')

  _check_device(grad, expected=op.device)
  num_devices = op.get_attr('num_devices')
  shared_name = op.get_attr('shared_name') + b'_grad'

  with ops.device(op.device):
    return gen_nccl_ops.nccl_all_reduce(
        input=grad,
        reduction='sum',
        num_devices=num_devices,
        shared_name=shared_name)
Code Example #4
File: nccl_utils.py  Project: skang29/GANs
def nccl_device_sum(input_, tower_config):
    nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"

    shared_name = input_.name.replace(tower_config.name,
                                      tower_config.prefix.format(nccl_name))

    output_ = gen_nccl_ops.nccl_all_reduce(
        input=input_,
        reduction="sum",
        num_devices=tower_config.num_devices,
        shared_name=shared_name)

    return output_
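A hedged usage sketch for the helper above: `tower_config` is assumed to expose `name`, `prefix`, `num_devices` and `is_test` as in the project this snippet comes from, and `inputs` is an illustrative per-tower feature map.

import tensorflow as tf

# Per-GPU mean of an NHWC feature map, placed on this tower's device.
per_gpu_mean = tf.reduce_mean(inputs, axis=[0, 1, 2])
summed = nccl_device_sum(per_gpu_mean, tower_config)
cross_gpu_mean = summed / float(tower_config.num_devices)   # sum -> mean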
Code Example #5
def get_sync_bn_mean_var(x, axis, num_dev):
    coef = tf.constant(np.float32(1.0 / num_dev), name="coef")
    shared_name = tf.get_variable_scope().name
    shared_name = '_'.join(shared_name.split('/')[-2:])
    with tf.device(x.device):
        batch_mean = tf.reduce_mean(x, axis=axis)
        batch_mean = gen_nccl_ops.nccl_all_reduce(
            input=batch_mean,
            reduction='sum',
            num_devices=num_dev,
            shared_name=shared_name + '_NCCL_mean') * coef
    with tf.device(x.device):
        batch_mean_square = tf.reduce_mean(tf.square(x), axis=axis)
        batch_mean_square = gen_nccl_ops.nccl_all_reduce(
            input=batch_mean_square,
            reduction='sum',
            num_devices=num_dev,
            shared_name=shared_name + '_NCCL_mean_square') * coef

    batch_var = batch_mean_square - tf.square(batch_mean)

    return batch_mean, batch_var
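The shared_name construction in Example #5 keeps only the last two components of the variable scope, so towers whose scopes differ only in an outer prefix agree on the NCCL rendezvous name. A standalone illustration (scope names are hypothetical):

for scope in ('tower0/resnet/block1/bn', 'tower1/resnet/block1/bn'):
    shared_name = '_'.join(scope.split('/')[-2:])
    assert shared_name == 'block1_bn'   # identical across towers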
Code Example #6
        def branchTrue():
            '''
            Update the batch mean and batch variance.
            '''
            # only one GPU
            if GPUNumber == 1:
                batch_mean = tf.reduce_mean(inputs,
                                            axis=axes,
                                            name="batch_mean")
                batch_mean_square = tf.reduce_mean(tf.square(inputs),
                                                   axis=axes)
            # multi GPUs
            else:
                # average the batch statistics across multiple GPUs
                shared_name = re.sub('tower[0-9]+/', '',
                                     tf.get_variable_scope().name)
                batch_mean = tf.reduce_mean(inputs, axis=axes)

                # Utilize NCCL
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=GPUNumber,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / GPUNumber)
                batch_mean_square = tf.reduce_mean(tf.square(inputs),
                                                   axis=axes)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=GPUNumber,
                    shared_name=shared_name +
                    '_NCCL_mean_square') * (1.0 / GPUNumber)

            batch_var = batch_mean_square - tf.square(batch_mean)

            outputs = tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                                beta, gamma, epsilon)

            return outputs, batch_mean, batch_var
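Several examples in this listing instead build the shared name by stripping the per-tower prefix from the variable scope with a regex, so every GPU's copy of the layer rendezvouses under one name. A standalone illustration (scope name is hypothetical):

import re

scope = 'tower3/resnet/bn1'                  # hypothetical per-tower scope
print(re.sub('tower[0-9]+/', '', scope))     # -> 'resnet/bn1'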
Code Example #7
File: nccl_ops.py  Project: ExpLife0011/JuusanKoubou
def _apply_all_reduce(reduction_op, tensors):
    if not tensors:
        raise ValueError('Must pass >0 tensors to all reduce operations')
    shared_name = _get_shared_name()
    res = []
    for t in tensors:
        if not device.canonical_name(t.device):
            raise ValueError(
                'Device assignment required for nccl collective ops')
        with ops.device(t.device):
            res.append(
                gen_nccl_ops.nccl_all_reduce(t,
                                             reduction=reduction_op,
                                             num_devices=len(tensors),
                                             shared_name=shared_name))
    return res
Code Example #8
File: nccl_ops.py  Project: AlbertXiebnu/tensorflow
def _apply_all_reduce(reduction_op, tensors):
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')
  shared_name = _get_shared_name()
  res = []
  for t in tensors:
    if not device.canonical_name(t.device):
      raise ValueError('Device assignment required for nccl collective ops')
    with ops.device(t.device):
      res.append(
          gen_nccl_ops.nccl_all_reduce(
              t,
              reduction=reduction_op,
              num_devices=len(tensors),
              shared_name=shared_name))
  return res
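A hedged usage sketch for the helper above (TensorFlow 1.x graph mode, assuming two visible GPUs; at runtime every listed device must actually participate):

import tensorflow as tf

per_device = []
for i in range(2):                           # assumes two visible GPUs
    with tf.device('/gpu:%d' % i):
        per_device.append(tf.random_normal([4]))

# Returns one tensor per device, each holding the element-wise sum of all inputs.
summed = _apply_all_reduce('sum', per_device)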
Code Example #9
def _apply_all_reduce(reduction, tensors):
    """Helper function for all_* functions."""
    if not tensors:
        raise ValueError('Must pass >0 tensors to all reduce operations')
    _check_graph_mode()

    shared_name = _get_shared_name()
    res = []

    for t in tensors:
        _check_device(t)
        with ops.device(t.device):
            res.append(
                gen_nccl_ops.nccl_all_reduce(input=t,
                                             reduction=reduction,
                                             num_devices=len(tensors),
                                             shared_name=shared_name))

    return res
Code Example #10
File: nccl_ops.py  Project: AnishShah/tensorflow
def _apply_all_reduce(reduction, tensors):
  """Helper function for all_* functions."""
  if not tensors:
    raise ValueError('Must pass >0 tensors to all reduce operations')
  _validate_and_load_nccl_so()

  shared_name = _get_shared_name()
  res = []

  for t in tensors:
    _check_device(t)
    with ops.device(t.device):
      res.append(
          gen_nccl_ops.nccl_all_reduce(
              input=t,
              reduction=reduction,
              num_devices=len(tensors),
              shared_name=shared_name))

  return res
Code Example #11
    def call(self, inputs):
        do_sync = self.sync and self.training
        if not do_sync:
            with tf.variable_scope(self.sc,
                                   values=[inputs],
                                   reuse=True,
                                   auxiliary_name_scope=False) as sc:
                layer = tf.layers.BatchNormalization(
                    momentum=self.decay,
                    epsilon=self.epsilon,
                    center=self.center,
                    scale=self.scale,
                    beta_initializer=self.beta_initializer,
                    gamma_initializer=self.gamma_initializer,
                    moving_mean_initializer=self.moving_mean_initializer,
                    moving_variance_initializer=self.moving_variance_initializer,
                    gamma_regularizer=self.gamma_regularizer,
                    beta_regularizer=self.beta_regularizer,
                    trainable=self.trainable,
                    name=sc.name,
                    dtype=inputs.dtype.base_dtype,
                    _scope=sc,
                    _reuse=True)
                ret = layer.apply(inputs, training=self.training)
        else:
            num_dev = len(tf_utils.get_available_gpus())
            if tf_utils.get_tf_version_tuple() <= (1, 12):
                try:
                    from tensorflow.contrib.nccl.python.ops.nccl_ops \
                        import _validate_and_load_nccl_so
                except Exception:
                    pass
                else:
                    _validate_and_load_nccl_so()
                from tensorflow.contrib.nccl.ops import gen_nccl_ops
            else:
                from tensorflow.python.ops import gen_nccl_ops
            batch_mean = tf.reduce_mean(inputs, axis=[0, 1, 2])
            batch_mean_square = tf.reduce_mean(tf.square(inputs),
                                               axis=[0, 1, 2])
            shared_name = re.sub('tower[0-9]+/', '',
                                 tf.get_variable_scope().name)
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_var') * (1.0 / num_dev)
            batch_var = batch_mean_square - tf.square(batch_mean)

            ret = tf.nn.batch_normalization(inputs,
                                            mean=batch_mean,
                                            variance=batch_var,
                                            offset=self.beta,
                                            scale=self.gamma,
                                            variance_epsilon=self.epsilon)

            update_moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, batch_mean, self.decay, zero_debias=False)
            update_moving_var = moving_averages.assign_moving_average(
                self.moving_variance, batch_var, self.decay, zero_debias=False)
            with tf.control_dependencies(
                [update_moving_mean, update_moving_var]):
                ret = tf.identity(ret)

        if self.activation is not None:
            ret = self.activation(ret)
        return slim.utils.collect_named_outputs(self.outputs_collections,
                                                self.sc.name, ret)
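The moving-average updates at the end of the method rely on `assign_moving_average`, which (with `zero_debias=False`) applies ema <- ema * decay + value * (1 - decay). A minimal standalone sketch of that update rule:

def ema_step(ema, value, decay):
    """One exponential-moving-average step, as used for mean/variance above."""
    return ema * decay + value * (1.0 - decay)

ema = 0.0
for value in (1.0, 1.0, 1.0):    # the EMA creeps toward the batch statistic
    ema = ema_step(ema, value, decay=0.9)
print(ema)                       # ~0.271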
Code Example #12
File: ops.py  Project: q7800067/DSNet
def sync_batch_norm(inputs,
                    decay=0.999,
                    center=True,
                    scale=False,
                    epsilon=0.001,
                    activation_fn=None,
                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                    is_training=True,
                    reuse=None,
                    variables_collections=None,
                    outputs_collections=None,
                    trainable=True,
                    scope=None,
                    num_dev=1):
  '''
  num_dev is how many gpus you use.
  '''
  

  from tensorflow.contrib.nccl.ops import gen_nccl_ops
  from tensorflow.contrib.framework import add_model_variable

  red_axises = [0, 1, 2]
  num_outputs = inputs.get_shape().as_list()[-1]

  if scope is None:
    scope = 'BatchNorm'

  layer_variable_getter = _build_variable_getter()
  with variable_scope.variable_scope(
      scope,
      'BatchNorm',
      reuse=reuse,
      custom_getter=layer_variable_getter) as sc:

    gamma = tf.get_variable(name='gamma', shape=[num_outputs], dtype=tf.float32,
                            initializer=tf.constant_initializer(1.0), trainable=trainable,
                            collections=variables_collections)

    beta  = tf.get_variable(name='beta', shape=[num_outputs], dtype=tf.float32,
                            initializer=tf.constant_initializer(0.0), trainable=trainable,
                            collections=variables_collections)

    moving_mean = tf.get_variable(name='moving_mean', shape=[num_outputs], dtype=tf.float32,
                                initializer=tf.constant_initializer(0.0), trainable=False,
                                collections=variables_collections)
                                
    moving_var = tf.get_variable(name='moving_variance', shape=[num_outputs], dtype=tf.float32,
                                initializer=tf.constant_initializer(1.0), trainable=False,
                                collections=variables_collections)

    if is_training and trainable:
      
      if num_dev == 1:
        mean, var = tf.nn.moments(inputs, red_axises)
      else:
        shared_name = tf.get_variable_scope().name
        batch_mean        = tf.reduce_mean(inputs, axis=red_axises)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axises)
        batch_mean        = gen_nccl_ops.nccl_all_reduce(
          input=batch_mean,
          reduction='sum',
          num_devices=num_dev,
          shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
        batch_mean_square = gen_nccl_ops.nccl_all_reduce(
          input=batch_mean_square,
          reduction='sum',
          num_devices=num_dev,
          shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        mean              = batch_mean
        var               = batch_mean_square - tf.square(batch_mean)
      outputs = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, epsilon)

      if int(outputs.device[-1]) == 0:
        update_moving_mean_op = tf.assign(moving_mean, moving_mean * decay + mean * (1 - decay))
        update_moving_var_op  = tf.assign(moving_var,  moving_var  * decay + var  * (1 - decay))
        add_model_variable(moving_mean)
        add_model_variable(moving_var)
        
        if updates_collections is None:
          with tf.control_dependencies([update_moving_mean_op, update_moving_var_op]):
            outputs = tf.identity(outputs)
        else:
          ops.add_to_collections(updates_collections, update_moving_mean_op)
          ops.add_to_collections(updates_collections, update_moving_var_op)
          outputs = tf.identity(outputs)
      else:
        outputs = tf.identity(outputs)

    else:
      outputs,_,_ = nn.fused_batch_norm(inputs, gamma, beta, mean=moving_mean, variance=moving_var, epsilon=epsilon, is_training=False)

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    return utils.collect_named_outputs(outputs_collections, sc.name, outputs)
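The `int(outputs.device[-1]) == 0` check above makes only the tower placed on GPU 0 write the moving averages, so the EMA is updated exactly once per step. A standalone illustration of what the check inspects (device string is hypothetical):

device_str = '/job:localhost/replica:0/task:0/device:GPU:0'   # hypothetical
owns_ema_update = int(device_str[-1]) == 0
print(owns_ema_update)                                        # True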
Code Example #13
File: batch_norm.py  Project: yulinliu101/tensorpack
def BatchNorm(inputs,
              axis=None,
              training=None,
              momentum=0.9,
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              internal_update=False,
              sync_statistics=None):
    """
    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:

    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which covers more use cases than the standard collection-based update.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.

    Args:
        internal_update (bool): if False, add EMA update ops to
          `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
          They are very similar in speed, but `internal_update=True` is recommended and can be helpful when:

          1. BatchNorm is used inside dynamic control flow.
             The collection-based update does not support dynamic control flows.
          2. A BatchNorm layer is sometimes unused (e.g., when you have two networks to train alternately).
             Putting all update ops into a single collection will waste a lot of compute.

          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics (str or None): one of None, "nccl", or "horovod".

          By default (None), it uses statistics of the input tensor to normalize.
          This is the standard way BatchNorm was done in most frameworks.

          When set to "nccl", this layer must be used under tensorpack's multi-GPU trainers.
          It uses the aggregated statistics of the whole batch (across all GPUs) to normalize.

          When set to "horovod", this layer must be used under tensorpack's :class:`HorovodTrainer`.
          It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
          Note that on a single machine this is significantly slower than the "nccl" implementation.

          If not None, per-GPU E[x] and E[x^2] among all GPUs are averaged to compute
          global mean & variance. Therefore each GPU needs to have the same batch size.

          The synchronization is based on the current variable scope + the name of the layer
          (`BatchNorm('name', input)`). Therefore, you need to make sure that:

          1. The BatchNorm layer on different GPUs needs to have the same name, so that
             statistics can be synchronized. If names do not match, this layer will hang.
          2. Different BatchNorm layers in one tower cannot share the same name.
          3. A BatchNorm layer needs to be executed for the same number of times by all GPUs.
             If different GPUs execute one BatchNorm layer a different number of times
             (e.g., if some GPUs do not execute it), this layer may hang.

          This option only has effect in standard training mode.

          This option is also known as "Cross-GPU BatchNorm" as mentioned in:
          `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

          When `sync_statistics` is enabled, `internal_update` will be set to True automatically.
          This is to avoid running `UPDATE_OPS`, which requires synchronization.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        Combinations of ``training`` and ``ctx.is_training``:

        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
          and used during inference. This is the default.
        * ``training and not ctx.is_training``: still use batch statistics in inference.
        * ``not training and ctx.is_training``: use EMA to normalize in
          training. This is useful when you load a pre-trained BN and
          don't want to fine tune the EMA. EMA will not be updated in
          this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, keras_mode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            axis = 1
        else:
            axis = 1 if data_format == 'NCHW' else 3
    assert axis in [1, 3], axis
    num_chan = shape[axis]

    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_tuple()
    freeze_bn_backward = not training and ctx.is_training
    if freeze_bn_backward:
        assert TF_version >= (1, 4), \
            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
        if ctx.is_main_training_tower:  # only warn in first tower
            logger.warn(
                "[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable({
                'moving_mean': 'mean/EMA',
                'moving_variance': 'variance/EMA'
        }):
            tf_args = dict(
                axis=axis,
                momentum=momentum,
                epsilon=epsilon,
                center=center,
                scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
                fused=(ndims == 4 and axis in [1, 3]
                       and not freeze_bn_backward),
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            use_fp16 = inputs.dtype == tf.float16
            if use_fp16:
                # non-fused does not support fp16; fused does not support all layouts.
                # we made our best guess here
                tf_args['fused'] = True
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs,
                             training=training,
                             scope=tf.get_variable_scope())

        # maintain EMA only on one GPU is OK, even in replicated mode.
        # because during training, EMA isn't used
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                if isinstance(v, tf.Variable):
                    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else (
            [0, 2, 3] if axis == 1 else [0, 1, 2])

        new_shape = None  # don't need to reshape unless ...
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            num_dev = ctx.total
            if num_dev == 1:
                logger.warn(
                    "BatchNorm(sync_statistics='nccl') is used with only one tower!"
                )
            else:
                assert six.PY2 or TF_version >= (1, 10), \
                    "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
                    "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"

                if TF_version <= (1, 12):
                    try:
                        from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
                    except Exception:
                        pass
                    else:
                        _validate_and_load_nccl_so()
                    from tensorflow.contrib.nccl.ops import gen_nccl_ops
                else:
                    from tensorflow.python.ops import gen_nccl_ops
                shared_name = re.sub('tower[0-9]+/', '',
                                     tf.get_variable_scope().name)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                      num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            import horovod.tensorflow as hvd
            if hvd.size() == 1:
                logger.warn(
                    "BatchNorm(sync_statistics='horovod') is used with only one process!"
                )
            else:
                import horovod
                hvd_version = tuple(map(int, horovod.__version__.split('.')))
                assert hvd_version >= (
                    0, 13,
                    6), "sync_statistics=horovod needs horovod>=0.13.6 !"

                batch_mean = hvd.allreduce(batch_mean, average=True)
                batch_mean_square = hvd.allreduce(batch_mean_square,
                                                  average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                           tf.reshape(beta, new_shape),
                                           tf.reshape(gamma, new_shape),
                                           epsilon)
        else:
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta,
                                           gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(xn, batch_mean_vec, batch_var_vec, moving_mean,
                                moving_var, momentum)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
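A hedged usage sketch of the layer above, called inside a tower function under one of tensorpack's multi-GPU trainers. The layer name and tensor are illustrative; the leading name argument is added by tensorpack's layer registry, matching the `BatchNorm('name', input)` form in the docstring.

# Per-GPU E[x] and E[x^2] are NCCL-all-reduced before the variance is derived.
x = BatchNorm('bn1', x, sync_statistics='nccl')
x = tf.nn.relu(x)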
Code Example #14
File: batch_norm.py  Project: yulilili/tensorpack
def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              ema_update='default',
              sync_statistics=None,
              internal_update=None):
    """
    A more powerful version of `tf.layers.batch_normalization`. It differs from
    the official one in the following aspects:

    1. Accepts an alternative ``data_format`` option when ``axis`` is None. For 2D input, this argument will be ignored.
    2. Default value for ``momentum`` and ``epsilon`` is different.
    3. Default value for ``training`` is automatically obtained from tensorpack's ``TowerContext``.
       User-provided value can overwrite this behavior.
    4. Support the ``ema_update`` option, which covers broader use cases than the standard EMA update.
    5. Support the ``sync_statistics`` option, which implements "SyncBN" and is very useful in small-batch models.

    Args:
        training (bool): if True, use per-batch statistics to normalize. Otherwise, use stored EMA
            to normalize. By default, it is equal to `get_current_tower_context().is_training`.
            This is not a good argument name, but it is what the TensorFlow layer uses.
        ema_update (str): Only effective when ``training=True``. It has the following options:

          * "default": same as "collection". Because this is the default behavior in tensorflow.
          * "skip": do not update EMA. This can be useful when you reuse a batch norm layer in several places
            but do not want them to all update your EMA.
          * "collection": Add EMA update ops to collection `tf.GraphKeys.UPDATE_OPS`.
            The ops in the collection will be run automatically by the callback :class:`RunUpdateOps`, along with
            your training iterations. This can waste compute if your training iterations do not always depend
            on the BatchNorm layer.
          * "internal": EMA is updated inside this layer itself by control dependencies.
            In common cases, it has similar speed to "collection". But it covers more cases, e.g.:

            1. BatchNorm is used inside dynamic control flow.
               The collection-based update does not support dynamic control flows.
            2. A BatchNorm layer is sometimes unused (e.g., in GANs you have two networks to train alternately).
               Putting all update ops into a single collection will waste a lot of compute.
            3. Other part of the model relies on the "updated" EMA. The collection-based method does not update
               EMA immediately.

            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics (str or None): one of None, "nccl", or "horovod". It determines how to compute the
          "per-batch statistics" when ``training==True``.

          * None: it uses statistics of the input tensor to normalize during training.
            This is the standard way BatchNorm was implemented in most frameworks.

          * "nccl": this layer must be used under tensorpack's multi-GPU trainers.
            It uses the aggregated statistics of the whole batch (across all GPUs) to normalize.

          * "horovod": this layer must be used under tensorpack's :class:`HorovodTrainer`.
            It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
            Note that on a single machine this is significantly slower than the "nccl" implementation.

          When not None, each GPU computes its own E[x] and E[x^2],
          which are then averaged among all GPUs to compute global mean & variance.
          Therefore each GPU needs to have the same batch size.

          The synchronization is based on the current variable scope + the name of the layer
          (`BatchNorm('name', input)`). Therefore, you need to make sure that:

          1. The BatchNorm layer on different GPUs needs to have the same name, so that
             statistics can be synchronized. If names do not match, this layer will hang.
          2. A BatchNorm layer cannot be reused within one tower.
          3. A BatchNorm layer needs to be executed for the same number of times by all GPUs.
             If different GPUs execute one BatchNorm layer a different number of times
             (e.g., if some GPUs do not execute it), this layer may hang.

          This option is also known as "SyncBN" or "Cross-GPU BatchNorm" as mentioned in:
          `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

          When `sync_statistics` is enabled, `ema_update` is set to "internal" automatically.
          This is to avoid running `UPDATE_OPS`, which requires synchronization.

        internal_update: deprecated option. Don't use.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        This layer is more flexible than the standard "BatchNorm" layer and provides more features:

        1. No matter whether you're doing training or not, you can set the ``training`` argument
           to use batch statistics or EMA statistics.
           i.e., you can use batch statistics during inference, or use EMA statistics during training.
           Using EMA statistics in training is useful when you load a pre-trained BN and
           don't want to update it.
        2. As long as `training=True`, the `sync_statistics` and `ema_update` options will take effect.
    """
    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)

    # parse shapes
    data_format = get_data_format(data_format, keras_mode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    assert ema_update in ["default", "collection", "internal", "skip"]
    if internal_update is not None:
        log_deprecated("BatchNorm(internal_update=)", "Use ema_update='internal' instead!", "2020-01-01")
        assert ema_update == 'default', \
            "Do not use internal_update and ema_update together! internal_update is deprecated"
        ema_update = "internal" if internal_update else "collection"
    if ema_update == "default":
        ema_update = "collection"
    # Logic:
    # 1. EMA update is possible only when we compute batch statistics (training=True)
    # 2. We know that in training, non-main training towers do not need the EMA update.
    #    We don't know what to do in a prediction context, so be conservative and do the update.
    # 3. The user can explicitly disable the update with "skip".
    do_ema_update = training and \
        (ctx.is_main_training_tower or not ctx.is_training) \
        and (ema_update != "skip")

    if axis is None:
        if ndims == 2:
            axis = 1
        else:
            axis = 1 if data_format == 'NCHW' else 3
    assert axis in [1, 3], axis
    num_chan = shape[axis]

    TF_version = get_tf_version_tuple()

    freeze_bn_backward = not training and ctx.is_training
    if freeze_bn_backward:
        assert TF_version >= (1, 4), \
            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
        if ctx.is_main_training_tower:  # only warn in first tower
            log_once("Some BatchNorm layer uses moving_mean/moving_variance in training.", func='warn')
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.

    do_sync_bn = (sync_statistics is not None) and training

    if not do_sync_bn:
        # Use the builtin layer for anything except for sync-bn
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable(
                {'moving_mean': 'mean/EMA',
                 'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
                fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            use_fp16 = inputs.dtype == tf.float16
            if use_fp16:
                # non-fused does not support fp16; fused does not support all layouts.
                # we made our best guess here
                tf_args['fused'] = True
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # Add EMA variables to the correct collection
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                if isinstance(v, tf.Variable):
                    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)

        if not do_ema_update:
            restore_collection(coll_bk)
        if do_ema_update and ema_update == "internal":
            # Implement "internal" update.
            restore_collection(coll_bk)
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])

        new_shape = None  # don't need to reshape unless ...
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            num_dev = ctx.total
            if num_dev == 1:
                logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
            else:
                assert six.PY2 or TF_version >= (1, 10), \
                    "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
                    "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"

                if TF_version <= (1, 12):
                    try:
                        from tensorflow.contrib.nccl.python.ops.nccl_ops import _validate_and_load_nccl_so
                    except Exception:
                        pass
                    else:
                        _validate_and_load_nccl_so()
                    from tensorflow.contrib.nccl.ops import gen_nccl_ops
                else:
                    from tensorflow.python.ops import gen_nccl_ops
                shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            import horovod.tensorflow as hvd
            if hvd.size() == 1:
                logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
            else:
                import horovod
                hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
                assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"

                batch_mean = hvd.allreduce(batch_mean, average=True)
                batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                beta, gamma, epsilon)

        if do_ema_update:
            ret = internal_update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var, momentum)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
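Example #14, like Example #1, keeps only the first three dot-separated components of the Horovod version string before the minimum-version comparison. A standalone illustration (version string is hypothetical):

version = '0.19.5'                                    # hypothetical
hvd_version = tuple(map(int, version.split('.')[:3]))
assert hvd_version >= (0, 13, 6)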
Code Example #15
def c_batch_norm(inputs,
                 scope,
                 training=None,
                 is_main_training_tower=True,
                 axis=None,
                 momentum=0.9,
                 epsilon=1e-5,
                 center=True,
                 scale=True,
                 beta_initializer=tf.zeros_initializer(),
                 gamma_initializer=tf.ones_initializer(),
                 virtual_batch_size=None,
                 data_format='NCHW',
                 internal_update=False,
                 sync_statistics='nccl'):
    """
    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:

    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which enables the use of a BatchNorm layer inside conditionals.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.

    Args:
        internal_update (bool): if False, add EMA update ops to
            `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer
            by control dependencies.
            They are very similar in speed, but `internal_update=True` can be used
            when you have conditionals in your model, or when you have multiple networks to train.
        sync_statistics: either None or "nccl". By default (None), it uses statistics of the input tensor to normalize.
            When set to "nccl", this layer must be used under tensorpack multi-gpu trainers,
            and it then uses per-machine (multiple GPU) statistics to normalize.

            This option has no effect when not training.
            The option is also known as "Cross-GPU BatchNorm" as mentioned in https://arxiv.org/abs/1711.07240.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default. Input will be transformed by ``x * gamma + beta``.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        1. Combinations of ``training`` and ``ctx.is_training``:
            * ``training == ctx.is_training``: standard BN, EMA are
                maintained during training and used during inference. This is
                the default.
            * ``training and not ctx.is_training``: still use batch statistics in inference.
            * ``not training and ctx.is_training``: use EMA to normalize in
                training. This is useful when you load a pre-trained BN and
                don't want to fine tune the EMA. EMA will not be updated in
                this case.
    """
    # parse shapes

    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    if sync_statistics is None:

        raise ValueError
    else:
        red_axis = [0] if ndims == 2 else (
            [0, 2, 3] if axis == 1 else [0, 1, 2])

        new_shape = None
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)
        # for debugging cgbn
        # tower_number = is_main_training_tower
        #is_main_training_tower = (is_main_training_tower == 0)
        # batch_mean =tf.Print(batch_mean, [batch_mean], 'batch_norm_mean %s' %tower_number)
        # batch_mean_square =tf.Print(batch_mean_square, [batch_mean_square], 'batch_norm_var %s' %tower_number)

        if sync_statistics == 'nccl':
            if six.PY3 and is_main_training_tower:
                logging.warn(
                    "A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
                    "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360"
                )

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            with tf.variable_scope(scope):
                shared_name = re.sub('tower[0-9]+/', '',
                                     tf.get_variable_scope().name)
            num_dev = 4
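            # NOTE: the device count is hard-coded to 4 in this snippet; it must
            # match the number of towers that actually run this op, otherwise the
            # NCCL all-reduce below will stall waiting for missing participants.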
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                  num_dev)
            # if is_main_training_tower:
            #     batch_mean=tf.Print(batch_mean, [batch_mean], 'batch_norm_mean' )
            #     batch_mean_square =tf.Print(batch_mean_square, [batch_mean_square], 'batch_norm_var')

        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            # Proof-of-concept, not ready yet.
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            r_gamma = tf.reshape(gamma, new_shape)
            r_beta = tf.reshape(beta, new_shape)
        else:
            r_gamma, r_beta = gamma, beta
        xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var, r_beta,
                                       r_gamma, epsilon)
        if is_main_training_tower:
            ret = update_bn_ema(xn, batch_mean_vec, batch_var_vec, moving_mean,
                                moving_var, momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')
    return ret
Code Example #16
File: tf_util.py  Project: zouwenqin/GADH_Net_EA
def sync_batch_norm(inputs,
                    is_training=True,
                    scope=None,
                    red_axises=[0, 1, 2],
                    bn_decay=0.999,
                    epsilon=0.001,
                    activation_fn=None,
                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                    reuse=None,
                    variables_collections=None,
                    is_trainable=True,
                    num_dev=3):
    '''
    num_dev is how many gpus you use.
    '''
    # red_axises = [0, 1, 2]
    num_outputs = inputs.get_shape().as_list()[-1]

    if scope is None:
        scope = 'BatchNorm'

    with tf.variable_scope(scope, 'BatchNorm', reuse=reuse):

        gamma = tf.get_variable(name='gamma',
                                shape=[num_outputs],
                                dtype=tf.float32,
                                initializer=tf.constant_initializer(1.0),
                                trainable=is_trainable,
                                collections=variables_collections)

        beta = tf.get_variable(name='beta',
                               shape=[num_outputs],
                               dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0),
                               trainable=is_trainable,
                               collections=variables_collections)

        moving_mean = tf.get_variable(name='moving_mean',
                                      shape=[num_outputs],
                                      dtype=tf.float32,
                                      initializer=tf.constant_initializer(0.0),
                                      trainable=False,
                                      collections=variables_collections)

        moving_var = tf.get_variable(name='moving_variance',
                                     shape=[num_outputs],
                                     dtype=tf.float32,
                                     initializer=tf.constant_initializer(1.0),
                                     trainable=False,
                                     collections=variables_collections)

        if is_training is not None and is_trainable is not None:
            if num_dev == 1:
                mean, var = tf.nn.moments(inputs, red_axises)
            else:
                shared_name = tf.get_variable_scope().name
                batch_mean = tf.reduce_mean(inputs, axis=red_axises)
                batch_mean_square = tf.reduce_mean(tf.square(inputs),
                                                   axis=red_axises)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                      num_dev)
                mean = batch_mean
                var = batch_mean_square - tf.square(batch_mean)
            outputs = tf.nn.batch_normalization(inputs, mean, var, beta, gamma,
                                                epsilon)

            if int(outputs.device[-1]) == 0:
                update_moving_mean_op = tf.assign(
                    moving_mean,
                    moving_mean * bn_decay + mean * (1 - bn_decay))
                update_moving_var_op = tf.assign(
                    moving_var, moving_var * bn_decay + var * (1 - bn_decay))
                add_model_variable(moving_mean)
                add_model_variable(moving_var)

                if updates_collections is None:
                    with tf.control_dependencies(
                        [update_moving_mean_op, update_moving_var_op]):
                        outputs = tf.identity(outputs)
                else:
                    tf.add_to_collections(updates_collections,
                                          update_moving_mean_op)
                    tf.add_to_collections(updates_collections,
                                          update_moving_var_op)
                    outputs = tf.identity(outputs)
            else:
                outputs = tf.identity(outputs)

        else:
            outputs, _, _ = tf.nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   mean=moving_mean,
                                                   variance=moving_var,
                                                   epsilon=epsilon,
                                                   is_training=False)

        #if activation_fn is not None:
        #   outputs = activation_fn(outputs)
    return outputs
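
A quick NumPy check (not part of the snippet above; equal per-device batch sizes assumed) of the identity the NCCL branch relies on: averaging the per-device E[x] and E[x^2] and taking Var = E[x^2] - E[x]^2 reproduces the mean and variance of the full, concatenated batch.

import numpy as np

np.random.seed(0)
towers = [np.random.randn(8, 16) for _ in range(4)]                 # 4 towers, equal batch size

mean = np.mean([t.mean(axis=0) for t in towers], axis=0)            # averaged E[x]
mean_sq = np.mean([(t ** 2).mean(axis=0) for t in towers], axis=0)  # averaged E[x^2]
var = mean_sq - mean ** 2

full = np.concatenate(towers, axis=0)
assert np.allclose(mean, full.mean(axis=0))
assert np.allclose(var, full.var(axis=0))
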
Code example #17
    def moments(
            x,
            axes,
            tower_config,
            shift=None,  # pylint: disable=unused-argument
            name=None,
            keep_dims=False):
        """Calculate the mean and variance of `x`.
        The mean and variance are calculated by aggregating the contents of `x`
        across `axes`.  If `x` is 1-D and `axes = [0]` this is just the mean
        and variance of a vector.
        Note: shift is currently not used; the true mean is computed and used.
        When using these moments for batch normalization (see
        `tf.nn.batch_normalization`):
         * for so-called "global normalization", used with convolutional filters with
           shape `[batch, height, width, depth]`, pass `axes=[0, 1, 2]`.
         * for simple batch normalization pass `axes=[0]` (batch only).
        Args:
          x: A `Tensor`.
          axes: Array of ints.  Axes along which to compute mean and
            variance.
          shift: Not used in the current implementation
          name: Name used to scope the operations that compute the moments.
          keep_dims: produce moments with the same dimensionality as the input.
        Returns:
          Two `Tensor` objects: `mean` and `variance`.
        """
        with ops.name_scope(name, "moments", [x, axes]):
            nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"

            # The dynamic range of fp16 is too limited to support the collection of
            # sufficient statistics. As a workaround we simply perform the operations
            # on 32-bit floats before converting the mean and variance back to fp16
            y = math_ops.cast(
                x, dtypes.float32) if x.dtype == dtypes.float16 else x
            # Compute true mean while keeping the dims for proper broadcasting.

            # Original Code: mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")

            device_mean = math_ops.reduce_mean(y,
                                               axes,
                                               keepdims=True,
                                               name="mean")

            shared_name = device_mean.name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))

            mean = gen_nccl_ops.nccl_all_reduce(
                input=device_mean,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            # sample variance, not unbiased variance
            # Note: stop_gradient does not change the gradient that gets
            #       backpropagated to the mean from the variance calculation,
            #       because that gradient is zero

            # Original Code: variance = math_ops.reduce_mean(
            #     math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
            #     axes,
            #     keepdims=True,
            #     name="variance")

            device_variance = math_ops.reduce_mean(math_ops.squared_difference(
                y, array_ops.stop_gradient(mean)),
                                                   axes,
                                                   keepdims=True,
                                                   name="variance")

            shared_name = device_variance.name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))

            variance = gen_nccl_ops.nccl_all_reduce(
                input=device_variance,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            if not keep_dims:
                mean = array_ops.squeeze(mean, axes)
                variance = array_ops.squeeze(variance, axes)
            if x.dtype == dtypes.float16:
                return (math_ops.cast(mean, dtypes.float16),
                        math_ops.cast(variance, dtypes.float16))
            else:
                return (mean, variance)
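
The moments() above assumes a tower_config object exposing name, prefix, num_devices and is_test. Below is a hypothetical plain-Python illustration (TowerConfig and the "tower_{}" naming are assumptions for illustration, not any library's API) of why the name replacement matters: every tower must derive the same shared_name so that nccl_all_reduce can match the participating ops.

class TowerConfig:
    def __init__(self, idx, num_devices, is_test=False):
        self.name = "tower_{}".format(idx)   # per-tower scope name (assumed format)
        self.prefix = "{}"                   # template later filled with "NCCL"/"NCCL_TEST"
        self.num_devices = num_devices
        self.is_test = is_test

cfg0, cfg1 = TowerConfig(0, 4), TowerConfig(1, 4)
nccl_name = "NCCL" if not cfg0.is_test else "NCCL_TEST"
name0 = "{}/bn/mean:0".format(cfg0.name).replace(cfg0.name, cfg0.prefix.format(nccl_name))
name1 = "{}/bn/mean:0".format(cfg1.name).replace(cfg1.name, cfg1.prefix.format(nccl_name))
assert name0 == name1 == "NCCL/bn/mean:0"   # identical on every tower
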
Code example #18
    def weighted_moments(x,
                         axes,
                         frequency_weights,
                         tower_config,
                         name=None,
                         keep_dims=False):
        """Returns the frequency-weighted mean and variance of `x`.
        Args:
          x: A tensor.
          axes: 1-d tensor of int32 values; these are the axes along which
            to compute mean and variance.
          frequency_weights: A tensor of positive weights which can be
            broadcast with x.
          name: Name used to scope the operation.
          keep_dims: Produce moments with the same dimensionality as the input.
        Returns:
          Two tensors: `weighted_mean` and `weighted_variance`.
        """
        with ops.name_scope(name, "weighted_moments",
                            [x, frequency_weights, axes]):
            x = ops.convert_to_tensor(x, name="x")
            frequency_weights = ops.convert_to_tensor(frequency_weights,
                                                      name="frequency_weights")

            # Unlike moments(), this just uses a simpler two-pass method.

            # See comment in moments() WRT precision; it applies here too.
            needs_cast = x.dtype == dtypes.float16
            if needs_cast:
                x = math_ops.cast(x, dtypes.float32)

            if frequency_weights.dtype != x.dtype:
                frequency_weights = math_ops.cast(frequency_weights, x.dtype)

            # Note that we use keep_dims=True for our reductions regardless of the arg;
            # this is so that the results remain broadcast-compatible with the inputs.

            # Original Code: weighted_input_sum = math_ops.reduce_sum(
            #     frequency_weights * x, axes, name="weighted_input_sum", keepdims=True)

            nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
            shared_name = tf.get_variable_scope().name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))
            device_weighted_input_sum = math_ops.reduce_sum(
                frequency_weights * x,
                axes,
                name="weighted_input_sum",
                keepdims=True)
            weighted_input_sum = gen_nccl_ops.nccl_all_reduce(
                input=device_weighted_input_sum,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            # The shape of the weights isn't necessarily the same as x's
            # shape, just broadcast-compatible with it -- so this expression
            # performs broadcasting to give a per-item weight, with the same
            # shape as (frequency_weights * x). This avoids having to reason
            # through all the broadcast logic to compute a correct
            # sum_of_weights.
            broadcasted_weights = frequency_weights + array_ops.zeros_like(x)

            sum_of_weights = math_ops.reduce_sum(broadcasted_weights,
                                                 axes,
                                                 name="sum_of_weights",
                                                 keepdims=True)

            divisor = math_ops.reciprocal(sum_of_weights,
                                          name="inv_weight_sum")

            weighted_mean = math_ops.multiply(weighted_input_sum, divisor)

            # Have the weighted mean; now on to variance:
            # weighted_distsq = math_ops.reduce_sum(
            #     frequency_weights * math_ops.squared_difference(x, weighted_mean),
            #     axes,
            #     name="weighted_distsq",
            #     keepdims=True)

            nccl_name = "NCCL" if not tower_config.is_test else "NCCL_TEST"
            shared_name = tf.get_variable_scope().name. \
                replace(tower_config.name, tower_config.prefix.format(nccl_name))
            device_weighted_distsq = math_ops.reduce_sum(
                frequency_weights *
                math_ops.squared_difference(x, weighted_mean),
                axes,
                name="weighted_distsq",
                keepdims=True)
            weighted_distsq = gen_nccl_ops.nccl_all_reduce(
                input=device_weighted_distsq,
                reduction="sum",
                num_devices=tower_config.num_devices,
                shared_name=shared_name) / (1.0 * tower_config.num_devices)

            weighted_variance = math_ops.multiply(weighted_distsq, divisor)

            if not keep_dims:
                weighted_mean = array_ops.squeeze(weighted_mean, axis=axes)
                weighted_variance = array_ops.squeeze(weighted_variance,
                                                      axis=axes)

            if needs_cast:
                weighted_mean = math_ops.cast(weighted_mean, dtypes.float16)
                weighted_variance = math_ops.cast(weighted_variance,
                                                  dtypes.float16)

            return weighted_mean, weighted_variance
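
A small NumPy sanity check (illustrative only, single-device quantities) of the formulas weighted_moments() aggregates: weighted_mean = sum(w*x)/sum(w) and weighted_variance = sum(w*(x - mean)^2)/sum(w).

import numpy as np

np.random.seed(0)
x = np.random.randn(6, 3)
w = np.random.rand(6, 1)                      # frequency weights, broadcast against x

wmean = (w * x).sum(axis=0) / w.sum(axis=0)
wvar = (w * (x - wmean) ** 2).sum(axis=0) / w.sum(axis=0)

assert np.allclose(wmean, np.average(x, axis=0, weights=w[:, 0]))
assert np.allclose(wvar, np.average((x - wmean) ** 2, axis=0, weights=w[:, 0]))
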
Code example #19
File: sy_batchnorm.py Project: sfxz035/Dose-predict
def sync_batch_norm(inputs,
                    decay=0.999,
                    epsilon=0.001,
                    activation_fn=None,
                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                    is_training=True,
                    variables_collections=None,
                    trainable=True,
                    scope=None,
                    num_dev=1):
    '''
    num_dev is the number of GPUs to use.
    '''

    red_axises = [0, 1, 2]
    num_outputs = inputs.get_shape().as_list()[-1]

    # if scope is None:
    #     scope = inputs.name.split(':')[0].replace(tf.get_variable_scope().name, '') + '/BatchNorm'
    # print(inputs.name, tf.get_variable_scope().name, scope)
    with variable_scope.variable_scope(scope, 'BatchNorm', [inputs]):
        gamma = tf.get_variable(name='gamma', shape=[num_outputs], dtype=tf.float32,
                                initializer=tf.constant_initializer(1.0), trainable=trainable,
                                 collections=[tf.GraphKeys.TRAINABLE_VARIABLES,
                                              tf.GraphKeys.MODEL_VARIABLES,
                                              tf.GraphKeys.GLOBAL_VARIABLES])
        beta = tf.get_variable(name='beta', shape=[num_outputs], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0), trainable=trainable,
                               collections=[tf.GraphKeys.TRAINABLE_VARIABLES,
                                            tf.GraphKeys.MODEL_VARIABLES,
                                            tf.GraphKeys.GLOBAL_VARIABLES])
        # print(gamma.name)
        moving_mean = tf.get_variable(name='moving_mean', shape=[num_outputs], dtype=tf.float32,
                                      initializer=tf.constant_initializer(0.0), trainable=False,
                                      collections=[tf.GraphKeys.MODEL_VARIABLES,
                                                   tf.GraphKeys.GLOBAL_VARIABLES])

        moving_var = tf.get_variable(name='moving_variance', shape=[num_outputs], dtype=tf.float32,
                                     initializer=tf.constant_initializer(1.0), trainable=False,
                                     collections=[tf.GraphKeys.MODEL_VARIABLES,
                                                  tf.GraphKeys.GLOBAL_VARIABLES])

        if is_training and trainable:

            if num_dev == 1:
                mean, var = tf.nn.moments(inputs, red_axises)
            else:
                shared_name = re.sub('Model[0-9]+/', '', tf.get_variable_scope().name)

                # print('shared name', shared_name)
                batch_mean = tf.reduce_mean(inputs, axis=red_axises)
                batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axises)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
                mean = batch_mean
                var = batch_mean_square - tf.square(batch_mean)
            outputs = tf.nn.batch_normalization(inputs, mean, var, beta, gamma, epsilon)
            # print(outputs.device)
            if int(outputs.device[-1]) == 0:
                update_moving_mean_op = tf.assign(moving_mean, moving_mean * decay + mean * (1 - decay))
                update_moving_var_op = tf.assign(moving_var, moving_var * decay + var * (1 - decay))
                # add_model_variable(moving_mean)
                # add_model_variable(moving_var)

                if updates_collections is None:
                    with tf.control_dependencies([update_moving_mean_op, update_moving_var_op]):
                        outputs = tf.identity(outputs)
                else:
                    tf.add_to_collections(updates_collections, update_moving_mean_op)
                    tf.add_to_collections(updates_collections, update_moving_var_op)
                    outputs = tf.identity(outputs)
            else:
                outputs = tf.identity(outputs)

        else:
            outputs, _, _ = tf.nn.fused_batch_norm(inputs, gamma, beta, mean=moving_mean, variance=moving_var,
                                                   epsilon=epsilon, is_training=False)

        if activation_fn is not None:
            outputs = activation_fn(outputs)

        return outputs
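
A plain-Python illustration (not TensorFlow) of what update_moving_mean_op / update_moving_var_op above do on each step: with decay=0.999 the running statistic drifts slowly toward the batch statistic, which is why many updates are needed before the EMA is usable for inference.

decay = 0.999
moving_mean, batch_mean = 0.0, 1.0
for _ in range(5000):
    moving_mean = moving_mean * decay + batch_mean * (1 - decay)
print(round(moving_mean, 3))   # ~0.993, i.e. 1 - 0.999**5000, approaching 1.0
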
Code example #20
File: batch_norm.py Project: quanlzheng/tensorpack
def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              internal_update=False,
              sync_statistics=None):
    """
    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:

    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.

    Args:
        internal_update (bool): if False, add EMA update ops to
          `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
          They are very similar in speed, but `internal_update=True` can be used
          when you have conditionals in your model, or when you have multiple networks to train.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics (str or None): one of None, "nccl", or "horovod".

          By default (None), it uses statistics of the input tensor to normalize.
          This is the standard way BatchNorm was done in most frameworks.

          When set to "nccl", this layer must be used under tensorpack's multi-GPU trainers.
          It uses the aggregated statistics of the whole batch (across all GPUs) to normalize.

          When set to "horovod", this layer must be used under tensorpack's :class:`HorovodTrainer`.
          It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
          Note that on single machine this is significantly slower than the "nccl" implementation.

          This implementation averages the per-GPU E[x] and E[x^2] among GPUs to compute
          global mean & variance. Therefore each GPU needs to have the same batch size.

          This option has no effect when not training.

          This option is also known as "Cross-GPU BatchNorm" as mentioned in:
          `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        Combinations of ``training`` and ``ctx.is_training``:

        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
          and used during inference. This is the default.
        * ``training and not ctx.is_training``: still use batch statistics in inference.
        * ``not training and ctx.is_training``: use EMA to normalize in
          training. This is useful when you load a pre-trained BN and
          don't want to fine tune the EMA. EMA will not be updated in
          this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_tuple()
    freeze_bn_backward = not training and ctx.is_training
    if freeze_bn_backward:
        assert TF_version >= (1, 4), \
            "Fine tuning a BatchNorm model with fixed statistics needs TF>=1.4!"
        if ctx.is_main_training_tower:  # only warn in first tower
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable(
                {'moving_mean': 'mean/EMA',
                    'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                # https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
                fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # maintain EMA only on one GPU is OK, even in replicated mode.
        # because during training, EMA isn't used
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                if isinstance(v, tf.Variable):
                    tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])

        new_shape = None  # don't need to reshape unless ...
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            num_dev = ctx.total
            if num_dev == 1:
                logger.warn("BatchNorm(sync_statistics='nccl') is used with only one tower!")
            else:
                assert six.PY2 or TF_version >= (1, 10), \
                    "Cross-GPU BatchNorm is only supported in TF>=1.10 ." \
                    "Upgrade TF or apply this patch manually: https://github.com/tensorflow/tensorflow/pull/20360"

                from tensorflow.contrib.nccl.ops import gen_nccl_ops
                shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            import horovod.tensorflow as hvd
            if hvd.size() == 1:
                logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
            else:
                import horovod
                hvd_version = tuple(map(int, horovod.__version__.split('.')))
                assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"

                batch_mean = hvd.allreduce(batch_mean, average=True)
                batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                beta, gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
                momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
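
A short plain-Python sketch mirroring the axis selection in the sync branch above (the helper name pick_red_axis is made up for illustration): it shows which axes get reduced and when the per-channel statistics need reshaping for broadcasting.

def pick_red_axis(ndims, axis, num_chan):
    red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
    new_shape = [1, num_chan, 1, 1] if (ndims == 4 and axis == 1) else None
    return red_axis, new_shape

print(pick_red_axis(4, 1, 64))   # NCHW: ([0, 2, 3], [1, 64, 1, 1])
print(pick_red_axis(4, 3, 64))   # NHWC: ([0, 1, 2], None)
print(pick_red_axis(2, 1, 64))   # 2-D:  ([0], None)
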
Code example #21
def sync_batch_norm(inputs,
                    decay=0.999,
                    axis=-1,
                    epsilon=0.001,
                    activation_fn=None,
                    updates_collections=tf.GraphKeys.UPDATE_OPS,
                    is_training=True,
                    reuse=None,
                    variables_collections=None,
                    trainable=True,
                    scope=None,
                    num_dev=1):
    '''
    num_dev is the number of GPUs to use.
    This function is from https://github.com/jianlong-yuan/syncbn-tensorflow/blob/master/syncbn.py
    '''
    # shape of inputs is [batch, height, width, depth]
    num_outputs = inputs.get_shape().as_list()[-1]
    # print (f"num_outputs = {num_outputs}")	# 3

    if scope is None:
        scope = 'batch_normalization'

    with tf.variable_scope(scope, reuse=reuse):
        # initializer, gamma and beta is trainable, moving_mean and moving_var is not
        gamma = tf.get_variable(name='gamma',
                                shape=[num_outputs],
                                dtype=tf.float32,
                                initializer=tf.constant_initializer(1.0),
                                trainable=trainable,
                                collections=variables_collections)

        beta = tf.get_variable(name='beta',
                               shape=[num_outputs],
                               dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0),
                               trainable=trainable,
                               collections=variables_collections)

        moving_mean = tf.get_variable(name='moving_mean',
                                      shape=[num_outputs],
                                      dtype=tf.float32,
                                      initializer=tf.constant_initializer(0.0),
                                      trainable=False,
                                      collections=variables_collections)

        moving_var = tf.get_variable(name='moving_variance',
                                     shape=[num_outputs],
                                     dtype=tf.float32,
                                     initializer=tf.constant_initializer(1.0),
                                     trainable=False,
                                     collections=variables_collections)

        # `is_training and trainable` is a logical AND, the same check as the
        # math_ops.logical_and() used in Keras' normalization layer:
        # https://github.com/tensorflow/tensorflow/blob/508f76b1d9925304cedd56d51480ec380636cb82/tensorflow/python/keras/layers/normalization.py#L621
        if is_training and trainable:
            # only one GPU
            if num_dev == 1:
                mean, var = tf.nn.moments(inputs, axes=axis)
            # multi GPUs
            else:
                # average batch_mean and batch_mean_square across GPUs
                shared_name = tf.get_variable_scope().name
                batch_mean = tf.reduce_mean(inputs, axis=axis)
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = tf.reduce_mean(tf.square(inputs),
                                                   axis=axis)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                      num_dev)
                mean = batch_mean
                var = batch_mean_square - tf.square(batch_mean)
            outputs = tf.nn.batch_normalization(inputs, mean, var, beta, gamma,
                                                epsilon)

            # print (outputs.device)	# /device:GPU:1

            # this block runs on every GPU,
            # but only GPU:0 updates moving_mean and moving_var
            if int(outputs.device[-1]) == 0:
                update_moving_mean_op = tf.assign(
                    moving_mean, moving_mean * decay + mean * (1 - decay))
                update_moving_var_op = tf.assign(
                    moving_var, moving_var * decay + var * (1 - decay))
                add_model_variable(moving_mean)
                add_model_variable(moving_var)

                if updates_collections is None:
                    with tf.control_dependencies(
                        [update_moving_mean_op, update_moving_var_op]):
                        outputs = tf.identity(outputs)
                else:
                    tf.add_to_collections(updates_collections,
                                          update_moving_mean_op)
                    tf.add_to_collections(updates_collections,
                                          update_moving_var_op)
                    outputs = tf.identity(outputs)
            else:
                outputs = tf.identity(outputs)
        else:
            outputs, _, _ = tf.nn.fused_batch_norm(inputs,
                                                   gamma,
                                                   beta,
                                                   mean=moving_mean,
                                                   variance=moving_var,
                                                   epsilon=epsilon,
                                                   is_training=False)

        if activation_fn is not None:
            outputs = activation_fn(outputs)

        return outputs
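
The int(outputs.device[-1]) == 0 test above keys on the last character of the device string, which works for '/device:GPU:0' through 'GPU:9'. Below is a plain-string alternative shown only for illustration (is_first_gpu is not part of the snippet) that also handles indices of 10 and above.

def is_first_gpu(device_str):
    # read the device index after the final ':' instead of the last character
    return device_str.rsplit(':', 1)[-1] == '0'

print(is_first_gpu('/device:GPU:0'))    # True  -> this replica updates the EMA
print(is_first_gpu('/device:GPU:10'))   # False (the last-character check would say True)
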
Code example #22
File: batch_norm.py Project: yt-oh96/tensorpack
def BatchNorm(inputs,
              axis=None,
              training=None,
              momentum=0.9,
              epsilon=1e-5,
              center=True,
              scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              internal_update=False,
              sync_statistics=None):
    """
    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:

    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.

    Args:
        internal_update (bool): if False, add EMA update ops to
          `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
          They are very similar in speed, but `internal_update=True` can be used
          when you have conditionals in your model, or when you have multiple networks to train.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics (str or None): one of None, "nccl", or "horovod".

          By default (None), it uses statistics of the input tensor to normalize.
          This is the standard way BatchNorm was done in most frameworks.

          When set to "nccl", this layer must be used under tensorpack's multi-GPU trainers.
          It uses the aggregated statistics of the whole batch (across all GPUs) to normalize.

          When set to "horovod", this layer must be used under tensorpack's :class:`HorovodTrainer`.
          It uses the aggregated statistics of the whole batch (across all MPI ranks) to normalize.
          Note that on single machine this is significantly slower than the "nccl" implementation.

          This implementation averages the per-GPU E[x] and E[x^2] among GPUs to compute
          global mean & variance. Therefore each GPU needs to have the same batch size.

          This option has no effect when not training.

          This option is also known as "Cross-GPU BatchNorm" as mentioned in:
          `MegDet: A Large Mini-Batch Object Detector <https://arxiv.org/abs/1711.07240>`_.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222.

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        Combinations of ``training`` and ``ctx.is_training``:

        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
          and used during inference. This is the default.
        * ``training and not ctx.is_training``: still use batch statistics in inference.
        * ``not training and ctx.is_training``: use EMA to normalize in
          training. This is useful when you load a pre-trained BN and
          don't want to fine tune the EMA. EMA will not be updated in
          this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_tuple()
    if not training and ctx.is_training:
        assert TF_version >= (1, 4), \
            "Fine tuning a BatchNorm model with fixed statistics is only " \
            "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
        if ctx.is_main_training_tower:  # only warn in first tower
            logger.warn(
                "[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable({
                'moving_mean': 'mean/EMA',
                'moving_variance': 'variance/EMA'
        }):
            tf_args = dict(axis=axis,
                           momentum=momentum,
                           epsilon=epsilon,
                           center=center,
                           scale=scale,
                           beta_initializer=beta_initializer,
                           gamma_initializer=gamma_initializer,
                           fused=(ndims == 4 and axis in [1, 3]),
                           _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs,
                             training=training,
                             scope=tf.get_variable_scope())

        # maintain EMA only on one GPU is OK, even in replicated mode.
        # because during training, EMA isn't used
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else (
            [0, 2, 3] if axis == 1 else [0, 1, 2])

        new_shape = None  # don't need to reshape unless ...
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= (1, 9) and ctx.is_main_training_tower:
                logger.warn(
                    "A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
                    "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360"
                )

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '',
                                 tf.get_variable_scope().name)
            num_dev = ctx.total
            if num_dev == 1:
                logger.warn(
                    "BatchNorm(sync_statistics='nccl') is used with only one tower!"
                )
            else:
                batch_mean = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
                batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                    input=batch_mean_square,
                    reduction='sum',
                    num_devices=num_dev,
                    shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                      num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                           tf.reshape(beta, new_shape),
                                           tf.reshape(gamma, new_shape),
                                           epsilon)
        else:
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta,
                                           gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(xn, batch_mean_vec, batch_var_vec, moving_mean,
                                moving_var, momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
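
A standalone TF 1.x sketch (plain tf.layers, not tensorpack's BatchNorm; a TF 1.x environment is assumed) of the convention the internal_update flag switches between: with internal_update=False the EMA update ops live in tf.GraphKeys.UPDATE_OPS and must be wired into the train op by hand, as below, while internal_update=True folds them into the layer output through control dependencies instead.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
h = tf.layers.batch_normalization(x, training=True)   # registers EMA updates in UPDATE_OPS
loss = tf.reduce_mean(tf.square(tf.layers.dense(h, 1)))

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):              # run EMA updates with each train step
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
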
Code example #23
File: custom_ops.py Project: messiah1999/BraTS19
def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              internal_update=False,
              sync_statistics=None):


    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        elif ndims == 5:
            axis = 1 if data_format == 'NCHW' else 4
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_tuple()
    if not training and ctx.is_training:
        assert TF_version >= (1, 4)
        if ctx.is_main_training_tower: 
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable(
                {'moving_mean': 'mean/EMA',
                 'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                fused=True,
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # maintain EMA only on one GPU
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  #backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  #backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
        if ndims == 5:
            red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]
        new_shape = None 
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]
        if ndims == 5 and axis == 1:
            new_shape = [1, num_chan, 1, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= (1, 8) and ctx.is_main_training_tower:
                logger.warn("A TensorFlow bug causing cross-GPU BatchNorm to fail")

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            num_dev = ctx.total
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                beta, gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
                momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  
            moving_variance=moving_var,
            variance=moving_var)  
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
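
A NumPy illustration (not part of the snippet) of why BatchNorm3d reshapes the per-channel statistics to [1, C, 1, 1, 1] for 5-D NCDHW input: the reshape lets the channel-wise mean and variance broadcast over the batch and the three spatial dimensions.

import numpy as np

x = np.random.randn(2, 16, 4, 4, 4)                    # NCDHW
mean = x.mean(axis=(0, 2, 3, 4)).reshape(1, 16, 1, 1, 1)
var = x.var(axis=(0, 2, 3, 4)).reshape(1, 16, 1, 1, 1)
xn = (x - mean) / np.sqrt(var + 1e-5)
print(xn.shape)                                        # (2, 16, 4, 4, 4)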