コード例 #1
0
ファイル: style2_model.py プロジェクト: swift-n-brutal/syntex
def SphericalAdd(x1,
                 x2,
                 theta_mean=0.,
                 theta_std=0.,
                 use_wscale=True,
                 lrmul=1.,
                 adaptive_lr=True,
                 channelwise=True):
    """Mix two tensors with an angle: ``y = x1 * cos(theta) + x2 * sin(theta)``.

    ``theta`` is ``theta_mean`` plus the value returned by
    ``get_bias(..., name="theta")``.

    Special cases:
        y = x1 if theta == 0
        y = x2 if theta == np.pi/2

    Args:
        x1, x2: the tensors to mix.
        theta_mean: constant offset added to the ``get_bias`` result.
        theta_std: forwarded to ``get_bias`` as ``base_std``.
        use_wscale, lrmul, adaptive_lr: forwarded unchanged to ``get_bias``.
        channelwise: if True, one angle per entry of the last axis of ``x1``;
            otherwise a single angle.

    Returns:
        A tensor named ``output`` whose ``variables`` attribute is a
        ``VariableHolder`` exposing ``theta`` (the raw ``get_bias`` result,
        before ``theta_mean`` is added).
    """
    if channelwise:
        num_angles = x1.get_shape().as_list()[-1]
    else:
        num_angles = 1

    raw_theta = get_bias(num_angles,
                         base_std=theta_std,
                         use_wscale=use_wscale,
                         lrmul=lrmul,
                         adaptive_lr=adaptive_lr,
                         name="theta")
    holder = VariableHolder(theta=raw_theta)

    # Ops are created in the same order as before so that auto-generated
    # graph names stay stable.
    angle = raw_theta + theta_mean
    cos_weight = tf.math.cos(angle, name="s1")
    sin_weight = tf.math.sin(angle, name="s2")
    first_term = x1 * cos_weight
    second_term = x2 * sin_weight
    mixed = tf.identity(tf.add(first_term, second_term), name="output")
    mixed.variables = holder
    return mixed
コード例 #2
0
def InstanceNorm5d(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
    """
    Instance Normalization, as in the paper:
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`_.

    Each sample is normalized per channel over its spatial axes.

    Args:
        x (tf.Tensor): a 4D or 5D tensor with a statically known channel size.
        epsilon (float): avoid divide-by-zero
        use_affine (bool): whether to apply learnable affine transformation
        gamma_init: initializer for ``gamma``. Defaults to constant 1.
        data_format (str): layout of ``x``; passed through ``get_data_format``.

    Returns:
        tf.Tensor: a tensor named ``output``. When ``use_affine`` is True it
        carries a ``variables`` attribute holding ``gamma`` and ``beta``;
        when False no ``variables`` attribute is attached.
    """
    data_format = get_data_format(data_format, keras_mode=True)
    # The helper may hand back either the TF spelling ('NHWC') or the Keras
    # spelling ('channels_last') depending on its mode; accept both so that
    # channels-last input is never mistaken for channels-first.
    channels_last = data_format in ('NHWC', 'channels_last')

    shape = x.get_shape().as_list()
    assert len(shape) in (4, 5), "Input of InstanceNorm has to be 4D or 5D!"
    num_spatial = len(shape) - 2
    if channels_last:
        axis = list(range(1, num_spatial + 1))
        ch = shape[-1]
        new_shape = [1] * (num_spatial + 1) + [ch]
    else:
        axis = list(range(2, num_spatial + 2))
        ch = shape[1]
        new_shape = [1, ch] + [1] * num_spatial
    assert ch is not None, "Input of InstanceNorm require known channel!"

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    # beta/gamma are stored as [ch] vectors and reshaped so they broadcast
    # along the channel axis of ``x``.
    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    # use_affine is necessarily True here (the other case returned above).
    vh = ret.variables = VariableHolder()
    vh.gamma = gamma
    vh.beta = beta
    return ret
コード例 #3
0
ファイル: custom_ops.py プロジェクト: messiah1999/BraTS19
def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None, data_format='channels_last'):
    """
    Instance normalization for 4D or 5D inputs.

    Mean and variance are computed per sample and per channel over the
    spatial axes, then used to normalize ``x``.

    Args:
        x (tf.Tensor): input tensor; 5D inputs use three spatial axes, any
            other rank is handled as 4D.
        epsilon (float): added to the variance to avoid divide-by-zero.
        use_affine (bool): whether to apply a learnable scale (``gamma``)
            and shift (``beta``).
        gamma_init: initializer for ``gamma``. Defaults to constant 1.
        data_format (str): layout of ``x``; passed through ``get_data_format``.

    Returns:
        tf.Tensor: a tensor named ``output``. When ``use_affine`` is True it
        carries a ``variables`` attribute holding ``gamma`` and ``beta``;
        when False no ``variables`` attribute is attached.
    """
    data_format = get_data_format(data_format, tfmode=False)
    shape = x.get_shape().as_list()
    if len(shape) == 5:
        if data_format == 'NHWC':
            axis = [1, 2, 3]
            ch = shape[4]
            new_shape = [1, 1, 1, 1, ch]
        else:
            axis = [2, 3, 4]
            ch = shape[1]
            new_shape = [1, ch, 1, 1, 1]
    else:
        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
    # The variable shapes below need a static channel count.
    assert ch is not None, "Input of InstanceNorm require known channel!"

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    # beta/gamma are stored as [ch] vectors and reshaped so they broadcast
    # along the channel axis of ``x``.
    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    # use_affine is necessarily True here (the other case returned above).
    vh = ret.variables = VariableHolder()
    vh.gamma = gamma
    vh.beta = beta
    return ret
コード例 #4
0
def InstanceNorm5d(x,
                   epsilon=1e-5,
                   use_affine=True,
                   gamma_init=None,
                   data_format='channels_last'):
    """
    Instance normalization for 5D inputs.

    Mean and variance are computed per sample and per channel over the three
    spatial axes, then used to normalize ``x``.

    Args:
        x (tf.Tensor): a 5D tensor with a statically known channel size.
        epsilon (float): added to the variance to avoid divide-by-zero.
        use_affine (bool): whether to apply a learnable scale (``gamma``)
            and shift (``beta``).
        gamma_init: initializer for ``gamma``. Defaults to constant 1.
        data_format (str): ``'channels_last'`` (default; also ``'NHWC'`` /
            ``'NDHWC'``) or ``'channels_first'`` (also ``'NCHW'`` /
            ``'NCDHW'``).

    Returns:
        tf.Tensor: a tensor named ``output``. When ``use_affine`` is True it
        carries a ``variables`` attribute holding ``gamma`` and ``beta``;
        when False no ``variables`` attribute is attached.

    Raises:
        ValueError: if ``data_format`` is not one of the accepted names.
    """
    shape = x.get_shape().as_list()
    assert len(shape) == 5, "Input of InstanceNorm5d has to be 5D!"

    # Previously the layout was hard-coded to channels-last and the
    # ``data_format`` argument was ignored; the default behaves as before.
    if data_format in ('channels_last', 'NHWC', 'NDHWC'):
        axis = [1, 2, 3]
        ch = shape[4]
        new_shape = [1, 1, 1, 1, ch]
    elif data_format in ('channels_first', 'NCHW', 'NCDHW'):
        axis = [2, 3, 4]
        ch = shape[1]
        new_shape = [1, ch, 1, 1, 1]
    else:
        raise ValueError("Unknown data_format: {}".format(data_format))
    # The variable shapes below need a static channel count.
    assert ch is not None, "Input of InstanceNorm require known channel!"

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    # beta/gamma are stored as [ch] vectors and reshaped so they broadcast
    # along the channel axis of ``x``.
    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x,
                                    mean,
                                    var,
                                    beta,
                                    gamma,
                                    epsilon,
                                    name='output')

    # use_affine is necessarily True here (the other case returned above).
    vh = ret.variables = VariableHolder()
    vh.gamma = gamma
    vh.beta = beta
    return ret
コード例 #5
0
def BatchNorm3d(inputs,
                axis=None,
                training=None,
                momentum=0.9,
                epsilon=1e-5,
                center=True,
                scale=True,
                beta_initializer=tf.zeros_initializer(),
                gamma_initializer=tf.ones_initializer(),
                virtual_batch_size=None,
                data_format='channels_last',
                internal_update=False,
                sync_statistics=None):
    """
    Batch normalization for 2D, 4D and 5D inputs.

    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:
    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input, this argument will be ignored.
    2. Default value for `momentum` and `epsilon` is different.
    3. Default value for `training` is automatically obtained from tensorpack's `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which enables the use of BatchNorm layer inside conditionals.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.
    Args:
        inputs: the tensor to normalize; its static rank is used to pick the channel axis.
        axis (int or None): the channel axis. If None it is derived from `data_format` and the
          rank of `inputs`; if given, `data_format` is overwritten from it.
        training (bool or None): if None, taken from the current tower context's `is_training`.
        momentum, epsilon, center, scale, beta_initializer, gamma_initializer:
          forwarded to `tf.layers.BatchNormalization` in the non-synchronized path, or to
          `get_bn_variables` / `tf.nn.batch_normalization` / `update_bn_ema` in the synchronized path.
        virtual_batch_size: forwarded to `tf.layers.BatchNormalization` only when the TF version
          check (>= 1.5) passes; otherwise it must be None.
        data_format (str): passed through `get_data_format`; only used when `axis` is None.
        internal_update (bool): if False, add EMA update ops to
          `tf.GraphKeys.UPDATE_OPS`. If True, update EMA inside the layer by control dependencies.
          They are very similar in speed, but `internal_update=True` can be used
          when you have conditionals in your model, or when you have multiple networks to train.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics: None, "nccl" or "horovod" (case-insensitive). By default (None), it uses
          statistics of the input tensor to normalize.
          When set to "nccl", this layer must be used under tensorpack multi-gpu trainers,
          and it then uses per-machine (multiple GPU) statistics to normalize.
          "horovod" averages the statistics with `hvd.allreduce` and is marked in the code as a
          proof-of-concept.
          Note that this implementation averages the per-tower E[x] and E[x^2] among towers to compute
          global mean&variance. The result is the global mean&variance only if each tower has the same batch size.
          This option has no effect when not training.
          This option is also known as "Cross-GPU BatchNorm" as mentioned in https://arxiv.org/abs/1711.07240.
          Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222
    Returns:
        The normalized tensor, with a `variables` attribute (a `VariableHolder`) exposing
        `moving_mean`/`mean`, `moving_variance`/`variance`, plus `gamma` when `scale` is set
        and `beta` when `center` is set.
    Variable Names:
    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.
    Note:
        Combinations of ``training`` and ``ctx.is_training``:
        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
          and used during inference. This is the default.
        * ``training and not ctx.is_training``: still use batch statistics in inference.
        * ``not training and ctx.is_training``: use EMA to normalize in
          training. This is useful when you load a pre-trained BN and
          don't want to fine tune the EMA. EMA will not be updated in
          this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    # in 3d conv, we have 5d dim [batch, c, d, h, w]
    # assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    # Resolve the channel axis: explicit `axis` wins, otherwise it follows
    # data_format and the input rank.
    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        elif ndims == 5:
            axis = 1 if data_format == 'NCHW' else 4
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    # NOTE(review): TF_version is compared numerically against 1.4 / 1.5 / 1.8
    # below. If get_tf_version_number() returns "major.minor" parsed as a
    # float, versions such as 1.10+ would compare as smaller than 1.4 --
    # confirm against the helper's implementation.
    TF_version = get_tf_version_number()
    if not training and ctx.is_training:
        assert TF_version >= 1.4, \
            "Fine tuning a BatchNorm model with fixed statistics is only " \
            "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
        if ctx.is_main_training_tower:  # only warn in first tower
            logger.warn(
                "[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tuning the affine part.

    # Path 1: no cross-device sync requested, or we are not actually training
    # with batch statistics -- delegate to tf.layers.BatchNormalization.
    if sync_statistics is None or not (training and ctx.is_training):
        # Snapshot UPDATE_OPS so the ops added by the layer can be discarded
        # below when they should not be registered.
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable({
                'moving_mean': 'mean/EMA',
                'moving_variance': 'variance/EMA'
        }):
            tf_args = dict(axis=axis,
                           momentum=momentum,
                           epsilon=epsilon,
                           center=center,
                           scale=scale,
                           beta_initializer=beta_initializer,
                           gamma_initializer=gamma_initializer,
                           fused=True,
                           _reuse=tf.get_variable_scope().reuse)
            if TF_version >= 1.5:
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs,
                             training=training,
                             scope=tf.get_variable_scope())

        # maintain EMA only on one GPU is OK, even in replicated mode.
        # because during training, EMA isn't used
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        # Drop the layer's UPDATE_OPS entries for non-main towers, or when the
        # update is attached via control dependencies instead.
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            # Force the EMA updates to run whenever the output is evaluated.
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        # Path 2: synchronized statistics. Compute E[x] and E[x^2] over every
        # axis except the channel axis, average them across devices, then
        # normalize manually.
        red_axis = [0] if ndims == 2 else (
            [0, 2, 3] if axis == 1 else [0, 1, 2])
        if ndims == 5:
            red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]
        new_shape = None  # don't need to reshape unless ...
        # ... the channel axis is not last, in which case the per-channel
        # vectors must be reshaped to broadcast against `inputs`.
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]
        if ndims == 5 and axis == 1:
            new_shape = [1, num_chan, 1, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= 1.8 and ctx.is_main_training_tower:
                logger.warn(
                    "A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
                    "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360"
                )

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            # Strip the per-tower prefix so the same layer in every tower
            # uses the same shared_name for the all-reduce.
            shared_name = re.sub('tower[0-9]+/', '',
                                 tf.get_variable_scope().name)
            num_dev = ctx.total
            # Sum across devices, then divide by the device count to average.
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 /
                                                                  num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            # Proof-of-concept, not ready yet.
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        # Var[x] = E[x^2] - E[x]^2
        batch_var = batch_mean_square - tf.square(batch_mean)
        # Keep the un-reshaped per-channel vectors for the EMA update.
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                           tf.reshape(beta, new_shape),
                                           tf.reshape(gamma, new_shape),
                                           epsilon)
        else:
            xn = tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta,
                                           gamma, epsilon)

        # Only the main training tower goes through update_bn_ema; the other
        # towers just name their output.
        if ctx.is_main_training_tower:
            ret = update_bn_ema(xn, batch_mean_vec, batch_var_vec, moving_mean,
                                moving_var, momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
コード例 #6
0
def Conv3D(
        inputs,
        filters,
        kernel_size,
        strides=(1, 1, 1),
        padding='same',
        data_format='channels_last',
        dilation_rate=(1, 1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        split=1):
    """
    A wrapper around `tf.layers.Conv3D`.
    Some differences to maintain backward-compatibility:
    1. Default kernel initializer is variance_scaling_initializer(2.0).
    2. Default padding is 'same'.
    3. Support 'split' argument to do group conv.

    Args:
        inputs: the 5D input tensor.
        filters (int): number of output channels.
        kernel_size, strides, dilation_rate: spatial kernel size, stride and
            dilation.
        padding (str): 'same' or 'valid'; upper-cased for the group-conv path.
        data_format (str): 'channels_last' (default) or 'channels_first'.
        activation: callable applied to the result, or None for identity.
        use_bias (bool): whether to add a bias term.
        kernel_initializer, bias_initializer: initializers for ``W`` / ``b``.
        kernel_regularizer, bias_regularizer, activity_regularizer:
            only supported when ``split == 1``.
        split (int): number of groups. Must divide both the input channel
            count and ``filters``.

    Returns:
        A tensor named ``output`` whose ``variables`` attribute holds ``W``
        and, when ``use_bias`` is set, ``b``.

    Variable Names:
    * ``W``: weights
    * ``b``: bias
    """
    if split == 1:
        # tf.layers expects the Keras spelling. Honor a channels-first request
        # instead of always forcing 'channels_last'; anything else keeps the
        # previous channels-last behavior.
        channels_first = data_format in ('channels_first', 'NCDHW', 'NCHW')
        layer_format = 'channels_first' if channels_first else 'channels_last'
        with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
            layer = tf.layers.Conv3D(filters,
                                     kernel_size,
                                     strides=strides,
                                     padding=padding,
                                     data_format=layer_format,
                                     dilation_rate=dilation_rate,
                                     activation=activation,
                                     use_bias=use_bias,
                                     kernel_initializer=kernel_initializer,
                                     bias_initializer=bias_initializer,
                                     kernel_regularizer=kernel_regularizer,
                                     bias_regularizer=bias_regularizer,
                                     activity_regularizer=activity_regularizer)
            ret = layer.apply(inputs, scope=tf.get_variable_scope())
            ret = tf.identity(ret, name='output')

        ret.variables = VariableHolder(W=layer.kernel)
        if use_bias:
            ret.variables.b = layer.bias

    else:
        # group conv implementation
        data_format = get_data_format3d(data_format, tfmode=False)
        in_shape = inputs.get_shape().as_list()
        channel_axis = 4 if data_format == 'NDHWC' else 1
        in_channel = in_shape[channel_axis]
        assert in_channel is not None, "[Conv3D] Input cannot have unknown channel!"
        assert in_channel % split == 0

        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
            "Not supported by group conv now!"

        out_channel = filters
        assert out_channel % split == 0
        assert dilation_rate == (1, 1, 1) or get_tf_version_number(
        ) >= 1.5, 'TF>=1.5 required for group dilated conv'

        kernel_shape = shape3d(kernel_size)
        # Integer division: a float here is not a valid variable dimension
        # on Python 3.
        filter_shape = kernel_shape + [in_channel // split, out_channel]
        stride = shape5d(strides, data_format=data_format)

        kwargs = dict(data_format=data_format)
        if get_tf_version_number() >= 1.5:
            # tf.nn.conv3d takes a length-5 dilations list, like strides.
            kwargs['dilations'] = shape5d(dilation_rate,
                                          data_format=data_format)

        W = tf.get_variable('W', filter_shape, initializer=kernel_initializer)

        if use_bias:
            b = tf.get_variable('b', [out_channel],
                                initializer=bias_initializer)

        # Each group convolves its slice of the input channels with its slice
        # of the output filters (last axis of W).
        inputs = tf.split(inputs, split, channel_axis)
        kernels = tf.split(W, split, 4)

        outputs = [
            tf.nn.conv3d(i, k, stride, padding.upper(), **kwargs)
            for i, k in zip(inputs, kernels)
        ]
        conv = tf.concat(outputs, channel_axis)
        if activation is None:
            activation = tf.identity
        # bias_add is given the portable 'NHWC'/'NCHW' names rather than the
        # 5D-specific 'NDHWC'/'NCDHW'.
        bias_format = 'NHWC' if channel_axis == 4 else 'NCHW'
        ret = activation(tf.nn.bias_add(conv, b, data_format=bias_format)
                         if use_bias else conv,
                         name='output')

        ret.variables = VariableHolder(W=W)
        if use_bias:
            ret.variables.b = b
    return ret
コード例 #7
0
def Deconv3D(x,
             out_shape,
             kernel_shape,
             stride,
             padding='SAME',
             W_init=None,
             b_init=None,
             nl=tf.identity,
             use_bias=True,
             data_format='NDHWC'):
    """
    3D deconvolution on 5D inputs.

    Args:
        x (tf.Tensor): a 5D tensor laid out according to ``data_format``.
            Must have known number of channels, but can have other unknown dimensions.
        out_shape: a sequence of ints giving the output shape without the
            batch dimension, in ``data_format`` order (e.g. (d, h, w, channel)
            for NDHWC), or just an integer channel, in which case (d, h, w)
            is calculated as input_shape * stride.
        kernel_shape: (d, h, w) tuple or a int.
        stride: (d, h, w) tuple or a int.
        padding (str): 'valid' or 'same'. Case insensitive.
        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
        b_init: initializer for b. Defaults to zero.
        nl: a nonlinearity function.
        use_bias (bool): whether to use bias.
        data_format (str): 'NDHWC' (default); any other value is treated as
            channels-first.

    Returns:
        tf.Tensor: a tensor named ``output`` with attribute `variables`.

    Raises:
        ValueError: if ``out_shape`` is a sequence containing a non-int.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    in_shape = x.get_shape().as_list()
    channel_axis = 4 if data_format == 'NDHWC' else 1
    in_channel = in_shape[channel_axis]
    assert in_channel is not None, "[Deconv3D] Input cannot have unknown channel!"
    kernel_shape = shape3d(kernel_shape)
    stride3d = shape3d(stride)
    stride5d = shape5d(stride, data_format=data_format)
    padding = padding.upper()
    in_shape_dyn = tf.shape(x)

    if isinstance(out_shape, int):
        # Only the channel count is given: each spatial size is
        # input size * stride, tracked both statically and dynamically.
        out_channel = out_shape
        if data_format == 'NDHWC':
            shp3_0 = StaticDynamicAxis(
                in_shape[1], in_shape_dyn[1]).apply(lambda v: stride3d[0] * v)
            shp3_1 = StaticDynamicAxis(
                in_shape[2], in_shape_dyn[2]).apply(lambda v: stride3d[1] * v)
            shp3_2 = StaticDynamicAxis(
                in_shape[3], in_shape_dyn[3]).apply(lambda v: stride3d[2] * v)
            shp3_dyn = [
                shp3_0.dynamic, shp3_1.dynamic, shp3_2.dynamic, out_channel
            ]
            shp3_static = [
                shp3_0.static, shp3_1.static, shp3_2.static, out_channel
            ]
        else:
            shp3_0 = StaticDynamicAxis(
                in_shape[2], in_shape_dyn[2]).apply(lambda v: stride3d[0] * v)
            shp3_1 = StaticDynamicAxis(
                in_shape[3], in_shape_dyn[3]).apply(lambda v: stride3d[1] * v)
            shp3_2 = StaticDynamicAxis(
                in_shape[4], in_shape_dyn[4]).apply(lambda v: stride3d[2] * v)
            shp3_dyn = [
                out_channel, shp3_0.dynamic, shp3_1.dynamic, shp3_2.dynamic
            ]
            shp3_static = [
                out_channel, shp3_0.static, shp3_1.static, shp3_2.static
            ]
    else:
        # Normalize to a list: it is concatenated with other lists below, and
        # a tuple would raise TypeError there.
        out_shape = list(out_shape)
        for k in out_shape:
            if not isinstance(k, int):
                raise ValueError(
                    "[Deconv3D] out_shape {} is invalid!".format(k))
        out_channel = out_shape[channel_axis -
                                1]  # out_shape doesn't have batch
        shp3_static = shp3_dyn = out_shape
    # conv3d_transpose filters are [d, h, w, out_channel, in_channel].
    filter_shape = kernel_shape + [out_channel, in_channel]

    if W_init is None:
        W_init = tf.contrib.layers.variance_scaling_initializer(
        )  # xavier_initializer_conv2d()
    if b_init is None:
        b_init = tf.constant_initializer()
    W = tf.get_variable('W', filter_shape, initializer=W_init)
    if use_bias:
        b = tf.get_variable('b', [out_channel], initializer=b_init)

    out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn)
    conv = tf.nn.conv3d_transpose(x,
                                  W,
                                  out_shape_dyn,
                                  stride5d,
                                  padding=padding,
                                  data_format=data_format)
    # Re-attach whatever static shape information is known.
    conv.set_shape(tf.TensorShape([None] + shp3_static))
    # Follow the actual layout instead of always assuming channels-last, and
    # use the portable 'NHWC'/'NCHW' names for bias_add.
    bias_format = 'NHWC' if channel_axis == 4 else 'NCHW'
    ret = nl(
        tf.nn.bias_add(conv, b, data_format=bias_format) if use_bias else conv,
        name='output')

    ret.variables = VariableHolder(W=W)
    if use_bias:
        ret.variables.b = b
    return ret
コード例 #8
0
ファイル: custom_ops.py プロジェクト: messiah1999/BraTS19
def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
              center=True, scale=True,
              beta_initializer=tf.zeros_initializer(),
              gamma_initializer=tf.ones_initializer(),
              virtual_batch_size=None,
              data_format='channels_last',
              internal_update=False,
              sync_statistics=None):
    """
    Batch normalization for 2D, 4D and 5D inputs, with optional cross-device
    synchronization of the batch statistics.

    Args:
        inputs: the tensor to normalize; its static rank is used to pick the
            channel axis.
        axis (int or None): the channel axis. If None it is derived from
            `data_format` and the rank of `inputs`.
        training (bool or None): if None, taken from the current tower
            context's `is_training`.
        momentum, epsilon, center, scale, beta_initializer, gamma_initializer:
            forwarded to `tf.layers.BatchNormalization` in the
            non-synchronized path, or to `get_bn_variables` /
            `tf.nn.batch_normalization` / `update_bn_ema` otherwise.
        virtual_batch_size: only supported with TF >= 1.5; must be None on
            older versions.
        data_format (str): passed through `get_data_format`; only used when
            `axis` is None.
        internal_update (bool): if True, EMA updates are attached to the
            output through control dependencies instead of being left in
            `tf.GraphKeys.UPDATE_OPS`.
        sync_statistics: None, "nccl" or "horovod" (case-insensitive). Only
            takes effect when training with batch statistics.

    Returns:
        The normalized tensor, with a `variables` attribute (a
        `VariableHolder`) exposing `moving_mean`/`mean`,
        `moving_variance`/`variance`, plus `gamma` when `scale` is set and
        `beta` when `center` is set.
    """
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    # Resolve the channel axis: explicit `axis` wins, otherwise it follows
    # data_format and the input rank.
    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        elif ndims == 5:
            axis = 1 if data_format == 'NCHW' else 4
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    # The version is a tuple, so it must be compared against tuples:
    # comparing a tuple with a float raises TypeError on Python 3.
    TF_version = get_tf_version_tuple()
    if not training and ctx.is_training:
        assert TF_version >= (1, 4)
        if ctx.is_main_training_tower:
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")

    if sync_statistics is None or not (training and ctx.is_training):
        # Snapshot UPDATE_OPS so the ops added by the layer can be discarded
        # below when they should not be registered.
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable(
                {'moving_mean': 'mean/EMA',
                 'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                fused=True,
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # maintain EMA only on one GPU
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            # Force the EMA updates to run whenever the output is evaluated.
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        # Synchronized statistics: compute E[x] and E[x^2] over every axis
        # except the channel axis, average them across devices, then
        # normalize manually.
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
        if ndims == 5:
            red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]
        # Reshape is only needed when the channel axis is not last.
        new_shape = None
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]
        if ndims == 5 and axis == 1:
            new_shape = [1, num_chan, 1, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= (1, 8) and ctx.is_main_training_tower:
                logger.warn("A TensorFlow bug cusing cross-GPU BatchNorm to fail")

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            # Strip the per-tower prefix so the same layer in every tower
            # uses the same shared_name for the all-reduce.
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            num_dev = ctx.total
            # Sum across devices, then divide by the device count to average.
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        # Var[x] = E[x^2] - E[x]^2
        batch_var = batch_mean_square - tf.square(batch_mean)
        # Keep the un-reshaped per-channel vectors for the EMA update.
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                beta, gamma, epsilon)

        # Only the main training tower goes through update_bn_ema; the other
        # towers just name their output.
        if ctx.is_main_training_tower:
            ret = update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
                momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,
            moving_variance=moving_var,
            variance=moving_var)
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret