def SphericalAdd(x1, x2, theta_mean=0., theta_std=0., use_wscale=True,
                 lrmul=1., adaptive_lr=True, channelwise=True):
    """y = x1 * cos(theta) + x2 * sin(theta)

    Special cases:
        y = x1 if theta == 0
        y = x2 if theta == np.pi/2
    """
    chan = x1.get_shape().as_list()[-1] if channelwise else 1
    theta = get_bias(chan, base_std=theta_std, use_wscale=use_wscale,
                     lrmul=lrmul, adaptive_lr=adaptive_lr, name="theta")
    vh = VariableHolder(theta=theta)
    theta = theta + theta_mean
    s1 = tf.math.cos(theta, name="s1")
    s2 = tf.math.sin(theta, name="s2")
    ret = tf.identity(tf.add(x1 * s1, x2 * s2), name="output")
    ret.variables = vh
    return ret
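# Hypothetical usage sketch (not part of the original file): blend two feature maps with
# a learnable per-channel angle. With theta_mean=0 and theta_std=0 the output starts out
# equal to x1 (cos(0)=1, sin(0)=0); theta_mean=np.pi/2 would start it at x2.
# Shapes and scope names are illustrative; `get_bias`/`VariableHolder` come from this repo.
def _demo_spherical_add(x1, x2):
    # x1, x2: tensors of the same channels-last shape
    with tf.variable_scope('sph_add'):
        y = SphericalAdd(x1, x2, theta_mean=0., channelwise=True)
    return y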
def InstanceNorm5d(x, epsilon=1e-5, use_affine=True, gamma_init=None,
                   data_format='channels_last'):
    """
    Instance Normalization, as in the paper:
    `Instance Normalization: The Missing Ingredient for Fast Stylization
    <https://arxiv.org/abs/1607.08022>`_.

    Args:
        x (tf.Tensor): a 4D or 5D tensor.
        epsilon (float): avoid divide-by-zero
        use_affine (bool): whether to apply learnable affine transformation
    """
    # Resolve to 'NHWC'/'NCHW' so the comparisons below work.
    data_format = get_data_format(data_format, keras_mode=False)
    shape = x.get_shape().as_list()
    if len(shape) == 5:
        if data_format == 'NHWC':
            axis = [1, 2, 3]
            ch = shape[4]
            new_shape = [1, 1, 1, 1, ch]
        else:
            axis = [2, 3, 4]
            ch = shape[1]
            new_shape = [1, ch, 1, 1, 1]
    else:
        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
    assert ch is not None, "Input of InstanceNorm require known channel!"

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    vh = ret.variables = VariableHolder()
    if use_affine:
        vh.gamma = gamma
        vh.beta = beta
    return ret
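# Hypothetical usage sketch (not part of the original file): instance-normalize a 5D
# volumetric feature map. Shapes and scope names are illustrative only.
def _demo_instancenorm5d(x):
    # x: [batch, d, h, w, c], channels_last; stats are computed per sample and per channel
    with tf.variable_scope('inorm'):
        y = InstanceNorm5d(x, use_affine=True)
    return y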
def InstanceNorm(x, epsilon=1e-5, use_affine=True, gamma_init=None,
                 data_format='channels_last'):
    data_format = get_data_format(data_format, tfmode=False)
    shape = x.get_shape().as_list()
    if len(shape) == 5:
        if data_format == 'NHWC':
            axis = [1, 2, 3]
            ch = shape[4]
            new_shape = [1, 1, 1, 1, ch]
        else:
            axis = [2, 3, 4]
            ch = shape[1]
            new_shape = [1, ch, 1, 1, 1]
    else:
        if data_format == 'NHWC':
            axis = [1, 2]
            ch = shape[3]
            new_shape = [1, 1, 1, ch]
        else:
            axis = [2, 3]
            ch = shape[1]
            new_shape = [1, ch, 1, 1]
    assert ch is not None, "Input of InstanceNorm require known channel!"

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    vh = ret.variables = VariableHolder()
    if use_affine:
        vh.gamma = gamma
        vh.beta = beta
    return ret
def InstanceNorm5d(x, epsilon=1e-5, use_affine=True, gamma_init=None,
                   data_format='channels_last'):
    # This variant assumes a 5D channels_last input; `data_format` is kept only for API symmetry.
    shape = x.get_shape().as_list()
    axis = [1, 2, 3]
    ch = shape[4]
    new_shape = [1, 1, 1, 1, ch]

    mean, var = tf.nn.moments(x, axis, keep_dims=True)

    if not use_affine:
        return tf.divide(x - mean, tf.sqrt(var + epsilon), name='output')

    beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    if gamma_init is None:
        gamma_init = tf.constant_initializer(1.0)
    gamma = tf.get_variable('gamma', [ch], initializer=gamma_init)
    gamma = tf.reshape(gamma, new_shape)
    ret = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name='output')

    vh = ret.variables = VariableHolder()
    if use_affine:
        vh.gamma = gamma
        vh.beta = beta
    return ret
def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
                center=True, scale=True,
                beta_initializer=tf.zeros_initializer(),
                gamma_initializer=tf.ones_initializer(),
                virtual_batch_size=None,
                data_format='channels_last',
                internal_update=False,
                sync_statistics=None):
    """
    Almost equivalent to `tf.layers.batch_normalization`, but different (and more powerful)
    in the following:

    1. Accepts an alternative `data_format` option when `axis` is None. For 2D input,
       this argument will be ignored.
    2. Default values for `momentum` and `epsilon` are different.
    3. Default value for `training` is automatically obtained from tensorpack's
       `TowerContext`, but can be overwritten.
    4. Support the `internal_update` option, which enables the use of a BatchNorm layer
       inside conditionals.
    5. Support the `sync_statistics` option, which is very useful in small-batch models.

    Args:
        internal_update (bool): if False, add EMA update ops to `tf.GraphKeys.UPDATE_OPS`.
            If True, update EMA inside the layer by control dependencies.
            They are very similar in speed, but `internal_update=True` can be used
            when you have conditionals in your model, or when you have multiple networks to train.
            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/14699
        sync_statistics: either None or "nccl". By default (None), it uses statistics of the
            input tensor to normalize. When set to "nccl", this layer must be used under
            tensorpack multi-gpu trainers, and it then uses per-machine (multiple GPU)
            statistics to normalize.

            Note that this implementation averages the per-tower E[x] and E[x^2] among towers
            to compute the global mean & variance. The result is the global mean & variance
            only if each tower has the same batch size.
            This option has no effect when not training.
            This option is also known as "Cross-GPU BatchNorm" as mentioned in
            https://arxiv.org/abs/1711.07240.
            Corresponding TF issue: https://github.com/tensorflow/tensorflow/issues/18222

    Variable Names:

    * ``beta``: the bias term. Will be zero-inited by default.
    * ``gamma``: the scale term. Will be one-inited by default.
    * ``mean/EMA``: the moving average of mean.
    * ``variance/EMA``: the moving average of variance.

    Note:
        Combinations of ``training`` and ``ctx.is_training``:

        * ``training == ctx.is_training``: standard BN, EMA are maintained during training
          and used during inference. This is the default.
        * ``training and not ctx.is_training``: still use batch statistics in inference.
        * ``not training and ctx.is_training``: use EMA to normalize in training.
          This is useful when you load a pre-trained BN and don't want to fine-tune the EMA.
          EMA will not be updated in this case.
    """
    # parse shapes
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    # in 3d conv, we have a 5D input [batch, c, d, h, w]
    # assert ndims in [2, 4], ndims
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        elif ndims == 5:
            axis = 1 if data_format == 'NCHW' else 4
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    # parse training/ctx
    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_number()
    if not training and ctx.is_training:
        assert TF_version >= 1.4, \
            "Fine tuning a BatchNorm model with fixed statistics is only " \
            "supported after https://github.com/tensorflow/tensorflow/pull/12580 "
        if ctx.is_main_training_tower:  # only warn in the first tower
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
        # Using moving_mean/moving_variance in training, which means we
        # loaded a pre-trained BN and only fine-tune the affine part.

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable({
                'moving_mean': 'mean/EMA',
                'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                fused=True,
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= 1.5:
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None, "Feature not supported in this version of TF!"
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # Maintaining the EMA on only one GPU is OK, even in replicated mode,
        # because the EMA isn't used during training.
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # for backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # for backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
        if ndims == 5:
            red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]

        new_shape = None  # don't need to reshape unless ...
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]
        if ndims == 5 and axis == 1:
            new_shape = [1, num_chan, 1, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= 1.8 and ctx.is_main_training_tower:
                logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail. "
                            "Apply this patch: https://github.com/tensorflow/tensorflow/pull/20360")

            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            num_dev = ctx.total
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            # Require https://github.com/uber/horovod/pull/331
            # Proof-of-concept, not ready yet.
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            # Using fused_batch_norm(is_training=False) is actually slightly faster,
            # but hopefully this call will be JITed in the future.
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var, beta, gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
                momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,  # for backward-compatibility
            moving_variance=moving_var,
            variance=moving_var)  # for backward-compatibility
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret
def Conv3D(
        inputs,
        filters,
        kernel_size,
        strides=(1, 1, 1),
        padding='same',
        data_format='channels_last',
        dilation_rate=(1, 1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=tf.contrib.layers.variance_scaling_initializer(2.0),
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        split=1):
    """
    A wrapper around `tf.layers.Conv3D`.
    Some differences to maintain backward-compatibility:

    1. Default kernel initializer is variance_scaling_initializer(2.0).
    2. Default padding is 'same'.
    3. Support the 'split' argument to do group convolution.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    if split == 1:
        with rename_get_variable({'kernel': 'W', 'bias': 'b'}):
            layer = tf.layers.Conv3D(
                filters,
                kernel_size,
                strides=strides,
                padding=padding,
                data_format=data_format,
                dilation_rate=dilation_rate,
                activation=activation,
                use_bias=use_bias,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer,
                kernel_regularizer=kernel_regularizer,
                bias_regularizer=bias_regularizer,
                activity_regularizer=activity_regularizer)
            ret = layer.apply(inputs, scope=tf.get_variable_scope())
            ret = tf.identity(ret, name='output')

        ret.variables = VariableHolder(W=layer.kernel)
        if use_bias:
            ret.variables.b = layer.bias
    else:
        # group conv implementation
        data_format = get_data_format3d(data_format, tfmode=False)
        in_shape = inputs.get_shape().as_list()
        channel_axis = 4 if data_format == 'NDHWC' else 1
        in_channel = in_shape[channel_axis]
        assert in_channel is not None, "[Conv3D] Input cannot have unknown channel!"
        assert in_channel % split == 0

        assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
            "Not supported by group conv now!"

        out_channel = filters
        assert out_channel % split == 0
        assert dilation_rate == (1, 1, 1) or get_tf_version_number() >= 1.5, \
            'TF>=1.5 required for group dilated conv'

        kernel_shape = shape3d(kernel_size)
        filter_shape = kernel_shape + [in_channel // split, out_channel]
        stride = shape5d(strides, data_format=data_format)

        kwargs = dict(data_format=data_format)
        if get_tf_version_number() >= 1.5:
            # conv3d expects a length-5 dilation spec
            kwargs['dilations'] = shape5d(dilation_rate, data_format=data_format)

        W = tf.get_variable('W', filter_shape, initializer=kernel_initializer)
        if use_bias:
            b = tf.get_variable('b', [out_channel], initializer=bias_initializer)

        # tf.split(value, num_or_size_splits, axis=0, num=None, name='split')
        inputs = tf.split(inputs, split, channel_axis)
        kernels = tf.split(W, split, 4)
        outputs = [tf.nn.conv3d(i, k, stride, padding.upper(), **kwargs)
                   for i, k in zip(inputs, kernels)]
        conv = tf.concat(outputs, channel_axis)
        if activation is None:
            activation = tf.identity
        ret = activation(tf.nn.bias_add(conv, b, data_format=data_format)
                         if use_bias else conv, name='output')

        ret.variables = VariableHolder(W=W)
        if use_bias:
            ret.variables.b = b
    return ret
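# Hypothetical usage sketch (not part of the original file): a grouped 3D convolution.
# Shapes and scope names are illustrative; with split=4, each of the 4 groups convolves
# 8 of the 32 input channels and contributes 16 of the 64 output channels.
def _demo_conv3d_group(x):
    # x: [batch, d, h, w, 32], channels_last
    with tf.variable_scope('conv_group'):
        y = Conv3D(x, filters=64, kernel_size=3, split=4)
    return y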
def Deconv3D(x, out_shape, kernel_shape, stride, padding='SAME',
             W_init=None, b_init=None, nl=tf.identity, use_bias=True,
             data_format='NDHWC'):
    """
    3D deconvolution on 5D inputs.

    Args:
        x (tf.Tensor): a tensor of shape NDHWC.
            Must have known number of channels, but can have other unknown dimensions.
        out_shape: (d, h, w, channel) tuple, or just an integer channel,
            then (d, h, w) will be calculated by input_shape * stride.
        kernel_shape: (d, h, w) tuple or an int.
        stride: (d, h, w) tuple or an int.
        padding (str): 'valid' or 'same'. Case insensitive.
        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
        b_init: initializer for b. Defaults to zero.
        nl: a nonlinearity function.
        use_bias (bool): whether to use bias.

    Returns:
        tf.Tensor: a NDHWC tensor named ``output`` with attribute `variables`.

    Variable Names:

    * ``W``: weights
    * ``b``: bias
    """
    in_shape = x.get_shape().as_list()
    channel_axis = 4 if data_format == 'NDHWC' else 1
    in_channel = in_shape[channel_axis]
    assert in_channel is not None, "[Deconv3D] Input cannot have unknown channel!"
    kernel_shape = shape3d(kernel_shape)
    stride3d = shape3d(stride)
    stride5d = shape5d(stride, data_format=data_format)
    padding = padding.upper()
    in_shape_dyn = tf.shape(x)

    if isinstance(out_shape, int):
        out_channel = out_shape
        if data_format == 'NDHWC':
            shp3_0 = StaticDynamicAxis(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride3d[0] * x)
            shp3_1 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride3d[1] * x)
            shp3_2 = StaticDynamicAxis(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride3d[2] * x)
            shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, shp3_2.dynamic, out_channel]
            shp3_static = [shp3_0.static, shp3_1.static, shp3_2.static, out_channel]
        else:
            shp3_0 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride3d[0] * x)
            shp3_1 = StaticDynamicAxis(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride3d[1] * x)
            shp3_2 = StaticDynamicAxis(in_shape[4], in_shape_dyn[4]).apply(lambda x: stride3d[2] * x)
            shp3_dyn = [out_channel, shp3_0.dynamic, shp3_1.dynamic, shp3_2.dynamic]
            shp3_static = [out_channel, shp3_0.static, shp3_1.static, shp3_2.static]
    else:
        for k in out_shape:
            if not isinstance(k, int):
                raise ValueError("[Deconv3D] out_shape {} is invalid!".format(k))
        out_channel = out_shape[channel_axis - 1]  # out_shape doesn't have batch
        shp3_static = shp3_dyn = out_shape

    filter_shape = kernel_shape + [out_channel, in_channel]

    if W_init is None:
        W_init = tf.contrib.layers.variance_scaling_initializer()  # xavier_initializer_conv2d()
    if b_init is None:
        b_init = tf.constant_initializer()
    W = tf.get_variable('W', filter_shape, initializer=W_init)
    if use_bias:
        b = tf.get_variable('b', [out_channel], initializer=b_init)

    out_shape_dyn = tf.stack([tf.shape(x)[0]] + shp3_dyn)
    conv = tf.nn.conv3d_transpose(
        x, W, out_shape_dyn, stride5d,
        padding=padding, data_format=data_format)
    conv.set_shape(tf.TensorShape([None] + shp3_static))

    ret = nl(tf.nn.bias_add(conv, b, data_format='NDHWC') if use_bias else conv,
             name='output')

    ret.variables = VariableHolder(W=W)
    if use_bias:
        ret.variables.b = b
    return ret
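# Hypothetical usage sketch (not part of the original file): upsample a 5D feature map
# by 2x in each spatial dimension. Shapes and scope names are illustrative only.
def _demo_deconv3d(x):
    # x: [batch, d, h, w, 64], NDHWC
    with tf.variable_scope('up1'):
        # Integer out_shape: spatial output size = input size * stride, with 32 output channels.
        y = Deconv3D(x, out_shape=32, kernel_shape=4, stride=2, nl=tf.nn.relu)
    return y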
def BatchNorm3d(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
                center=True, scale=True,
                beta_initializer=tf.zeros_initializer(),
                gamma_initializer=tf.ones_initializer(),
                virtual_batch_size=None,
                data_format='channels_last',
                internal_update=False,
                sync_statistics=None):
    data_format = get_data_format(data_format, tfmode=False)
    shape = inputs.get_shape().as_list()
    ndims = len(shape)
    if sync_statistics is not None:
        sync_statistics = sync_statistics.lower()
    assert sync_statistics in [None, 'nccl', 'horovod'], sync_statistics

    if axis is None:
        if ndims == 2:
            data_format = 'NHWC'
            axis = 1
        elif ndims == 5:
            axis = 1 if data_format == 'NCHW' else 4
        else:
            axis = 1 if data_format == 'NCHW' else 3
    else:
        data_format = 'NCHW' if axis == 1 else 'NHWC'
    num_chan = shape[axis]

    ctx = get_current_tower_context()
    if training is None:
        training = ctx.is_training
    training = bool(training)
    TF_version = get_tf_version_tuple()
    if not training and ctx.is_training:
        assert TF_version >= (1, 4)
        if ctx.is_main_training_tower:
            logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")

    if sync_statistics is None or not (training and ctx.is_training):
        coll_bk = backup_collection([tf.GraphKeys.UPDATE_OPS])
        with rename_get_variable(
                {'moving_mean': 'mean/EMA', 'moving_variance': 'variance/EMA'}):
            tf_args = dict(
                axis=axis,
                momentum=momentum, epsilon=epsilon,
                center=center, scale=scale,
                beta_initializer=beta_initializer,
                gamma_initializer=gamma_initializer,
                fused=True,
                _reuse=tf.get_variable_scope().reuse)
            if TF_version >= (1, 5):
                tf_args['virtual_batch_size'] = virtual_batch_size
            else:
                assert virtual_batch_size is None
            layer = tf.layers.BatchNormalization(**tf_args)
            xn = layer.apply(inputs, training=training, scope=tf.get_variable_scope())

        # maintain EMA only on one GPU
        if ctx.is_main_training_tower:
            for v in layer.non_trainable_variables:
                add_model_variable(v)
        if not ctx.is_main_training_tower or internal_update:
            restore_collection(coll_bk)

        if training and internal_update:
            assert layer.updates
            with tf.control_dependencies(layer.updates):
                ret = tf.identity(xn, name='output')
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=layer.moving_mean,
            mean=layer.moving_mean,  # backward-compatibility
            moving_variance=layer.moving_variance,
            variance=layer.moving_variance)  # backward-compatibility
        if scale:
            vh.gamma = layer.gamma
        if center:
            vh.beta = layer.beta
    else:
        red_axis = [0] if ndims == 2 else ([0, 2, 3] if axis == 1 else [0, 1, 2])
        if ndims == 5:
            red_axis = [0, 2, 3, 4] if axis == 1 else [0, 1, 2, 3]

        new_shape = None
        if ndims == 4 and axis == 1:
            new_shape = [1, num_chan, 1, 1]
        if ndims == 5 and axis == 1:
            new_shape = [1, num_chan, 1, 1, 1]

        batch_mean = tf.reduce_mean(inputs, axis=red_axis)
        batch_mean_square = tf.reduce_mean(tf.square(inputs), axis=red_axis)

        if sync_statistics == 'nccl':
            if six.PY3 and TF_version <= (1, 8) and ctx.is_main_training_tower:
                logger.warn("A TensorFlow bug will cause cross-GPU BatchNorm to fail")
            from tensorflow.contrib.nccl.ops import gen_nccl_ops
            shared_name = re.sub('tower[0-9]+/', '', tf.get_variable_scope().name)
            num_dev = ctx.total
            batch_mean = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean') * (1.0 / num_dev)
            batch_mean_square = gen_nccl_ops.nccl_all_reduce(
                input=batch_mean_square,
                reduction='sum',
                num_devices=num_dev,
                shared_name=shared_name + '_NCCL_mean_square') * (1.0 / num_dev)
        elif sync_statistics == 'horovod':
            import horovod.tensorflow as hvd
            batch_mean = hvd.allreduce(batch_mean, average=True)
            batch_mean_square = hvd.allreduce(batch_mean_square, average=True)
        batch_var = batch_mean_square - tf.square(batch_mean)
        batch_mean_vec = batch_mean
        batch_var_vec = batch_var

        beta, gamma, moving_mean, moving_var = get_bn_variables(
            num_chan, scale, center, beta_initializer, gamma_initializer)
        if new_shape is not None:
            batch_mean = tf.reshape(batch_mean, new_shape)
            batch_var = tf.reshape(batch_var, new_shape)
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var,
                tf.reshape(beta, new_shape),
                tf.reshape(gamma, new_shape), epsilon)
        else:
            xn = tf.nn.batch_normalization(
                inputs, batch_mean, batch_var, beta, gamma, epsilon)

        if ctx.is_main_training_tower:
            ret = update_bn_ema(
                xn, batch_mean_vec, batch_var_vec, moving_mean, moving_var,
                momentum, internal_update)
        else:
            ret = tf.identity(xn, name='output')

        vh = ret.variables = VariableHolder(
            moving_mean=moving_mean,
            mean=moving_mean,
            moving_variance=moving_var,
            variance=moving_var)
        if scale:
            vh.gamma = gamma
        if center:
            vh.beta = beta
    return ret