def testSplitWithNonConstAxis(self):
    """Check transposes and outputs when the split axis is fed at run time.

    Builds a two-layer model, splits its output along a placeholder axis,
    applies fused batch norm to each half and adds the results. The graph is
    run once in a default session and once with ``_get_config()``; the second
    run must contain exactly two transpose nodes and produce the same values
    (``atol=1e-3``). The test body is skipped silently without a CUDA GPU.
    """
    if not test.is_gpu_available(cuda_only=True):
        return

    random_seed.set_random_seed(0)
    x = random_ops.truncated_normal([1, 784], seed=0)
    conv = _two_layer_model(x)
    # The axis is a placeholder, so it is only known when the graph runs.
    dim = array_ops.placeholder(dtype='int32')
    halves = array_ops.split(conv, 2, axis=dim)
    scale = constant_op.constant(0.1, shape=[32])
    offset = constant_op.constant(0.3, shape=[32])
    bn0 = nn.fused_batch_norm(halves[0], scale, offset)
    bn1 = nn.fused_batch_norm(halves[1], scale, offset)
    total = bn0[0] + bn1[0]
    output = array_ops.identity(total)

    feed = {dim: 3}
    with session.Session() as sess:
        output_val_ref = sess.run(output, feed_dict=feed)

    with session.Session(config=_get_config()) as sess:
        metadata = config_pb2.RunMetadata()
        output_val = sess.run(output, run_metadata=metadata, feed_dict=feed)

    nodes = [node.name for node in metadata.cost_graph.node]
    num_transposes = sum(1 for name in nodes if _is_transpose(name))

    expected_num_transposes = 2
    self.assertEqual(expected_num_transposes, num_transposes)
    self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes)
    self._assert_trans_nchw_to_nhwc('add_2-0-0', nodes)
    self._assert_map_nhwc_to_nchw('split-0', nodes)
    self.assertAllClose(output_val_ref, output_val, atol=1e-3)
def call(self, inputs, training=None):
    """Apply fused batch norm and route statistic deltas through the gradient.

    Inputs are cast to the global mixed-precision policy's variable dtype,
    normalized with ``nn.fused_batch_norm`` (batch statistics when training,
    moving statistics otherwise) and cast to the policy's compute dtype.

    The result passes through a ``tf.custom_gradient`` identity whose backward
    function returns ``dx`` for the outputs and ``moving - batch`` differences
    for the moving mean and variance.
    NOTE(review): presumably a gradient-descent step on the moving statistics
    is meant to act as the moving-average update - confirm against the
    optimizer setup.
    """
    training = self._get_training_value(training)
    inputs = tf.cast(
        inputs, tf.keras.mixed_precision.global_policy().variable_dtype)

    bn_kwargs = {"epsilon": self.epsilon}
    if not training:
        # Inference: normalize with the stored moving statistics.
        bn_kwargs.update(
            mean=self.moving_mean,
            variance=self.moving_variance,
            is_training=False)
    outputs, mean, variance = nn.fused_batch_norm(
        inputs, self.gamma, self.beta, **bn_kwargs)

    outputs = tf.cast(
        outputs, tf.keras.mixed_precision.global_policy().compute_dtype)

    @tf.custom_gradient
    def moving_avg_updates(x, moving_m, moving_v):
        def bw(dx):
            return dx, moving_m - mean, moving_v - variance

        return x, bw

    return moving_avg_updates(outputs, self.moving_mean, self.moving_variance)
def testInference(self):
    """Compare inference-mode fused_batch_norm output with the reference.

    The mean and variance produced by ``self._reference_training`` are passed
    to ``nn.fused_batch_norm`` with ``is_training=False``; only the normalized
    output is checked (``atol=1e-3``).
    """
    num_channels = 3
    input_shape = [2, 2, 6, num_channels]
    param_shape = [num_channels]
    x_val = np.random.random_sample(input_shape).astype(np.float32)
    scale_val = np.random.random_sample(param_shape).astype(np.float32)
    offset_val = np.random.random_sample(param_shape).astype(np.float32)
    data_format = "NHWC"
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
        # Placeholders rather than constants, to avoid constant folding.
        x_ph = array_ops.placeholder(np.float32, shape=input_shape, name="x")
        scale_ph = array_ops.placeholder(
            np.float32, shape=param_shape, name="scale")
        offset_ph = array_ops.placeholder(
            np.float32, shape=param_shape, name="offset")

        y_ref, mean_ref, var_ref = self._reference_training(
            x_val, scale_val, offset_val, epsilon, data_format)
        y, mean, variance = nn.fused_batch_norm(
            x_ph,
            scale_ph,
            offset_ph,
            mean=mean_ref,
            variance=var_ref,
            epsilon=epsilon,
            data_format=data_format,
            is_training=False)

        feed = {x_ph: x_val, scale_ph: scale_val, offset_ph: offset_val}
        y_val, _, _ = sess.run([y, mean, variance], feed)
        self.assertAllClose(y_val, y_ref, atol=1e-3)
def testBasic(self):
    """Check training-mode fused_batch_norm against the reference.

    Random NHWC inputs are run through ``nn.fused_batch_norm`` with
    ``is_training=True``; the output, mean and variance are compared with
    ``self._reference_training`` to within ``atol=1e-3``.
    """
    x_shape = [2, 2, 6, 2]
    scale_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
        # Placeholders rather than constants, to avoid constant folding.
        t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        y, mean, var = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=None,
            variance=None,
            epsilon=epsilon,
            data_format=data_format,
            is_training=True)

        y_val, mean_val, var_val = sess.run(
            [y, mean, var],
            {t_val: x_val, scale: scale_val, offset: offset_val})
        y_ref, mean_ref, var_ref = self._reference_training(
            x_val, scale_val, offset_val, epsilon, data_format)

        self.assertAllClose(mean_val, mean_ref, atol=1e-3)
        self.assertAllClose(y_val, y_ref, atol=1e-3)
        self.assertAllClose(var_val, var_ref, atol=1e-3)
def _testLearning(self, use_gradient_checker, data_format):
    """Check training-mode fused_batch_norm in the given data format.

    Args:
      use_gradient_checker: When true, also bound the numeric-vs-analytic
        gradient error of the output w.r.t. the input by 1e-3.
      data_format: Layout the op is run in. Inputs and reference outputs are
        generated as NHWC and converted with
        ``test_utils.ConvertBetweenDataFormats``.
    """
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    y_ref, mean_ref, var_ref = self._reference_training(
        x_val, scale_val, offset_val, epsilon, data_format_src)

    with self.cached_session() as sess, self.test_scope():
        x_val_converted = test_utils.ConvertBetweenDataFormats(
            x_val, data_format_src, data_format)
        y_ref_converted = test_utils.ConvertBetweenDataFormats(
            y_ref, data_format_src, data_format)

        # Placeholders rather than constants, to avoid constant folding.
        t_val = array_ops.placeholder(
            np.float32, shape=x_val_converted.shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        y, mean, var = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=None,
            variance=None,
            epsilon=epsilon,
            data_format=data_format,
            is_training=True)

        feed = {t_val: x_val_converted, scale: scale_val, offset: offset_val}

        # Check gradient.
        if use_gradient_checker:
            err = gradient_checker.compute_gradient_error(
                t_val,
                x_val_converted.shape,
                y,
                x_val_converted.shape,
                extra_feed_dict=feed)
            self.assertLess(err, 1e-3)

        y_val, mean_val, var_val = sess.run([y, mean, var], feed)
        self.assertAllClose(mean_val, mean_ref, atol=1e-3)
        self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
        self.assertAllClose(var_val, var_ref, atol=1e-3)
def loop_fn(i):
    """Run fused batch norm on slice ``i`` of ``x`` and return outputs + grads.

    Returns the three fused_batch_norm outputs followed by three gradient
    entries. When ``is_training`` is false the batch mean/variance outputs
    and all gradients are replaced by zero constants.
    """
    with g:
        x1 = array_ops.gather(x, i)
        outputs = list(
            nn.fused_batch_norm(
                x1,
                scale,
                offset,
                mean=mean,
                variance=variance,
                epsilon=0.01,
                data_format=data_format,
                is_training=is_training))
        # We only test the first value of outputs when is_training is
        # False. It looks like CPU and GPU have different outputs for
        # batch_mean and batch_variance for this case.
        if not is_training:
            outputs[1] = constant_op.constant(0.)
            outputs[2] = constant_op.constant(0.)
        loss = nn.l2_loss(outputs[0])

    if not is_training:
        gradients = [constant_op.constant(0.)] * 3
    else:
        gradients = g.gradient(loss, [x1, scale, offset])
    return outputs + gradients
def my_graph(a):
    """Build inference-mode fused batch norm of ``a`` on ``/device:IPU:0``.

    Creates float32 resource variables "x" (beta, initialized to 0) and "y"
    (gamma, initialized to 1), computes moments of ``a`` over axes 0-2 and
    returns the full result of ``nn.fused_batch_norm`` with
    ``is_training=False``.
    """
    with ops.device("/device:IPU:0"):
        with variable_scope.variable_scope("", use_resource=True):
            beta = variable_scope.get_variable(
                "x",
                dtype=np.float32,
                shape=[4],
                initializer=init_ops.constant_initializer(0.0))
            gamma = variable_scope.get_variable(
                "y",
                dtype=np.float32,
                shape=[4],
                initializer=init_ops.constant_initializer(1.0))

            batch_mean, batch_var = nn.moments(a, [0, 1, 2], name='moments')
            return nn.fused_batch_norm(
                a, gamma, beta, batch_mean, batch_var, is_training=False)
def testInference(self):
    """Verify fused_batch_norm with ``is_training=False`` on NHWC input.

    Reference statistics from ``self._reference_training`` are supplied as
    the op's mean and variance, and the normalized output is compared with
    the reference output (``atol=1e-3``).
    """
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]

    def _random(shape):
        return np.random.random_sample(shape).astype(np.float32)

    x_val = _random(x_shape)
    scale_val = _random(scale_shape)
    offset_val = _random(scale_shape)
    data_format = "NHWC"

    with self.test_session() as sess, self.test_scope():
        # Feed values through placeholders to avoid constant folding.
        t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        epsilon = 0.001
        y_ref, mean_ref, var_ref = self._reference_training(
            x_val, scale_val, offset_val, epsilon, data_format)
        fetches = list(
            nn.fused_batch_norm(
                t_val,
                scale,
                offset,
                mean=mean_ref,
                variance=var_ref,
                epsilon=epsilon,
                data_format=data_format,
                is_training=False))
        y_val, _, _ = sess.run(
            fetches, {t_val: x_val, scale: scale_val, offset: offset_val})
        self.assertAllClose(y_val, y_ref, atol=1e-3)
def _ComputeBN(self, inputs, paddings, gamma, beta, norm_mean, norm_variance):
    """Normalize ``inputs`` with the given statistics and parameters.

    The computation is guarded by assertions that ``norm_variance`` is
    non-negative and that ``norm_mean`` / ``norm_variance`` match the last
    dimension of ``inputs``. The fused kernel (``is_training=False``) is used
    when ``p.use_fused_batch_norm_for_eval`` is set and the layer is in eval
    mode or has frozen statistics; otherwise ``tf.nn.batch_normalization`` is
    used. If ``p.set_padded_output_to_zero`` is set, ``paddings`` is applied
    to the result via ``py_utils.ApplyPadding``.

    Returns:
      The batch-normalized tensor.
    """
    p = self.params
    sanity_checks = [
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]
    with tf.control_dependencies(sanity_checks):
        use_fused = p.use_fused_batch_norm_for_eval and (
            self.do_eval or p.freeze_bn_stats)
        if use_fused:
            bn_output, _, _ = nn.fused_batch_norm(
                inputs,
                gamma,
                beta,
                norm_mean,
                norm_variance,
                self._epsilon,
                is_training=False)
        else:
            bn_output = tf.nn.batch_normalization(
                inputs, norm_mean, norm_variance, beta, gamma, self._epsilon)

        if p.set_padded_output_to_zero:
            bn_output = py_utils.ApplyPadding(paddings, bn_output)

    return bn_output
def _fused_batch_norm_training():
    """Call fused batch norm on the enclosing ``inputs``/``gamma``/``beta``.

    No mean or variance is supplied and ``is_training`` is left at the op's
    default.
    """
    bn_kwargs = {
        "epsilon": self.epsilon,
        "data_format": self._data_format,
    }
    return nn.fused_batch_norm(inputs, gamma, beta, **bn_kwargs)
def loop_fn(i):
    """Apply fused batch norm to the ``i``-th slice of ``x``.

    Returns a list holding the three op outputs followed by three gradient
    entries (w.r.t. the slice, ``scale`` and ``offset``). In inference mode
    the statistics outputs and the gradients are zero constants.
    """
    with g:
        x1 = array_ops.gather(x, i)
        bn_result = nn.fused_batch_norm(
            x1,
            scale,
            offset,
            mean=mean,
            variance=variance,
            epsilon=0.01,
            data_format=data_format,
            is_training=is_training)
        outputs = list(bn_result)
        # We only test the first value of outputs when is_training is False.
        # It looks like CPU and GPU have different outputs for batch_mean
        # and batch_variance for this case.
        if not is_training:
            outputs[1] = constant_op.constant(0.)
            outputs[2] = constant_op.constant(0.)
        loss = nn.l2_loss(outputs[0])

    if is_training:
        return outputs + g.gradient(loss, [x1, scale, offset])
    return outputs + [constant_op.constant(0.)] * 3
def _testLearning(self, use_gradient_checker, data_format):
    """Check training-mode fused_batch_norm, using the corrected variance.

    Args:
      use_gradient_checker: When true, also bound the numeric-vs-analytic
        gradient error of the output w.r.t. the input by 1e-3.
      data_format: Layout the op is run in. Inputs and reference outputs are
        generated as NHWC and converted with
        ``test_utils.ConvertBetweenDataFormats``.
    """
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    # When in training mode, fused_batchnorm applies an implicit Bessel's
    # correction. So we have to use the corrected variance here, as well.
    y_ref, mean_ref, _, var_ref_corr = self._reference_training(
        x_val, scale_val, offset_val, epsilon, data_format_src)

    with self.cached_session() as sess, self.test_scope():
        x_val_converted = test_utils.ConvertBetweenDataFormats(
            x_val, data_format_src, data_format)
        y_ref_converted = test_utils.ConvertBetweenDataFormats(
            y_ref, data_format_src, data_format)

        # Placeholders rather than constants, to avoid constant folding.
        t_val = array_ops.placeholder(
            np.float32, shape=x_val_converted.shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        y, mean, var = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=None,
            variance=None,
            epsilon=epsilon,
            data_format=data_format,
            is_training=True)

        feed = {t_val: x_val_converted, scale: scale_val, offset: offset_val}

        # Check gradient.
        if use_gradient_checker:
            err = gradient_checker.compute_gradient_error(
                t_val,
                x_val_converted.shape,
                y,
                x_val_converted.shape,
                extra_feed_dict=feed)
            self.assertLess(err, 1e-3)

        y_val, mean_val, var_val = sess.run([y, mean, var], feed)
        self.assertAllClose(mean_val, mean_ref, atol=1e-3)
        self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
        self.assertAllClose(var_val, var_ref_corr, atol=1e-3)
def _model_with_second_port():
    """Build a small graph that consumes two outputs of fused_batch_norm.

    The normalized output and the mean output of the op are added together
    and the sum is returned through an identity op.
    """
    random_seed.set_random_seed(0)
    x = random_ops.truncated_normal([2, 5, 5, 4], seed=0)
    scale = constant_op.constant(0.1, shape=[4])
    offset = constant_op.constant(0.3, shape=[4])
    bn_outputs = nn.fused_batch_norm(x, scale, offset)
    # Uses output ports 0 (normalized values) and 1 (mean) of the op.
    y_plus_mean = math_ops.add(bn_outputs[0], bn_outputs[1])
    return array_ops.identity(y_plus_mean)
def _fused_batch_norm_inference():
    """Run fused batch norm in inference mode with the moving statistics."""
    inference_kwargs = dict(
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format,
    )
    return nn.fused_batch_norm(inputs, gamma, beta, **inference_kwargs)
def my_batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
    """Applies batch normalization on x given mean, var, beta and gamma.

    For 4-D inputs normalized along a channel axis (NHWC, or NCHW when
    `_has_nchw_support()` is true) the fused kernel is used in inference
    mode. In every other case `nn.batch_normalization` is applied to each
    slice of `x` along its first axis via `tf.map_fn`.

    Arguments:
        x: Input tensor or variable.
        mean: Mean of batch.
        var: Variance of batch.
        beta: Tensor with which to center the input. On the fused path `None`
            is replaced by zeros.
        gamma: Tensor by which to scale the input. On the fused path `None`
            is replaced by ones.
        axis: Integer, the axis that should be normalized.
            (typically the features axis). Only consulted for 4-D inputs.
        epsilon: Fuzz factor, forwarded to the TensorFlow batch-norm op.

    Returns:
        A tensor.
    """
    if K.ndim(x) == 4:
        # The CPU implementation of `fused_batch_norm` only supports NHWC
        if axis == 1 or axis == -3:
            tf_data_format = 'NCHW'
        elif axis == 3 or axis == -1:
            tf_data_format = 'NHWC'
        else:
            tf_data_format = None

        # `and` binds tighter than `or`: NHWC always qualifies, NCHW only
        # when the backend reports support for it.
        if (tf_data_format == 'NHWC' or
                tf_data_format == 'NCHW' and _has_nchw_support()):
            # The mean / var / beta / gamma tensors may be broadcasted
            # so they may have extra axes of size 1, which should be squeezed.
            if K.ndim(mean) > 1:
                mean = array_ops.reshape(mean, [-1])
            if K.ndim(var) > 1:
                var = array_ops.reshape(var, [-1])
            if beta is None:
                beta = zeros_like(mean)
            elif K.ndim(beta) > 1:
                beta = array_ops.reshape(beta, [-1])
            if gamma is None:
                gamma = ones_like(mean)
            elif K.ndim(gamma) > 1:
                gamma = array_ops.reshape(gamma, [-1])
            y, _, _ = nn.fused_batch_norm(
                x,
                gamma,
                beta,
                epsilon=epsilon,
                mean=mean,
                variance=var,
                data_format=tf_data_format,
                is_training=False)
            return y

    return tf.map_fn(
        lambda xx: nn.batch_normalization(xx, mean, var, beta, gamma, epsilon),
        x)
def _fused_batch_norm_inference():
    """Normalize the enclosing ``inputs`` with the layer's moving statistics.

    Calls ``nn.fused_batch_norm`` with ``is_training=False`` and returns its
    result unchanged.
    """
    moving_stats = {
        "mean": self.moving_mean,
        "variance": self.moving_variance,
    }
    result = nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format,
        **moving_stats)
    return result
def template(x_shape=(2, 3, 4, 5), data_format="NHWC", description: str = ""):
    """Generate a kernel test case for inference-mode FusedBatchNorm.

    Builds a TensorFlow graph around ``nn.fused_batch_norm`` with
    ``is_training=False``, evaluates it on random inputs, converts the graph
    with ``TensorFlowConverter`` and registers the expected output through
    ``generate_kernel_test_case``.

    Args:
      x_shape: Shape of the 4-D input. Any sequence of ints is accepted.
      data_format: "NHWC" takes the channel count from the last axis; any
        other value takes it from axis 1.
      description: Suffix appended to the generated test case description.
    """
    from tensorflow.python.ops import nn

    x_shape = list(x_shape)
    channels = x_shape[-1] if data_format == "NHWC" else x_shape[1]

    x = tf.placeholder(np.float32, x_shape)
    scale = tf.placeholder(np.float32, channels)
    bias = tf.placeholder(np.float32, channels)
    mean = tf.placeholder(np.float32, channels)
    variance = tf.placeholder(np.float32, channels)
    y, _, _ = nn.fused_batch_norm(
        x,
        scale,
        bias,
        mean,
        variance,
        data_format=data_format,
        is_training=False)

    vx = np.random.rand(*x_shape).astype(np.float32)
    vs = np.random.rand(channels).astype(np.float32)
    vb = np.random.rand(channels).astype(np.float32)
    vm = np.random.rand(channels).astype(np.float32)
    vv = np.random.rand(channels).astype(np.float32)

    with tf.Session() as sess:
        vy, = sess.run([y], {
            x: vx,
            scale: vs,
            bias: vb,
            mean: vm,
            variance: vv,
        })
        graph = TensorFlowConverter(sess, batch_size=2).convert(
            [x, scale, bias, mean, variance], [y])

    generate_kernel_test_case(
        description=f"[TensorFlow] FusedBatchNorm {description}",
        graph=graph,
        inputs={
            graph.inputs[0]: vx,
            graph.inputs[1]: vs,
            graph.inputs[2]: vb,
            graph.inputs[3]: vm,
            graph.inputs[4]: vv,
        },
        expected={graph.outputs[0]: vy},
    )
def _fused_batch_norm_training():
    """Run fused batch norm in training mode, passing the moving statistics.

    The moving variance is first passed through
    ``_maybe_add_or_remove_bessels_correction(..., remove=False)`` and the
    enclosing ``exponential_avg_factor`` is forwarded to the op.
    """
    variance_in = _maybe_add_or_remove_bessels_correction(
        self.moving_variance, remove=False)
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=variance_in,
        epsilon=self.epsilon,
        is_training=True,
        data_format=self._data_format,
        exponential_avg_factor=exponential_avg_factor)
def model(x, y, z):
    """Normalize ``x`` and ``y`` with shared parameters and add the results.

    ``z`` is broadcast to 65536 elements and used as both scale and offset.
    The moments of ``x`` (axes 0-2) serve as mean and variance for both
    inference-mode fused batch norm calls, named "a" and "b".
    """
    scale = gen_array_ops.broadcast_to(z, shape=[65536])
    offset = scale
    b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')

    def _normalize(tensor, op_name):
        return nn.fused_batch_norm(
            tensor,
            scale,
            offset,
            b_mean,
            b_var,
            1e-3,
            is_training=False,
            name=op_name)

    a = _normalize(x, "a")
    b = _normalize(y, "b")
    return a[0] + b[0]
def _fused_batch_norm_training():
    """Fused batch norm on batch statistics, with an optional renorm step.

    When the enclosing ``renorm`` flag is set, the output is rescaled by
    ``r`` and shifted by ``d``, both computed from the batch and moving
    statistics, clipped to ``[1/RMAX, RMAX]`` and ``[-DMAX, DMAX]`` and
    excluded from the gradient.

    Returns:
      Tuple ``(outputs, mean, variance)``; mean and variance are the values
      returned by the fused op.
    """
    outputs, mean, variance = nn.fused_batch_norm(
        inputs, gamma, beta, epsilon=epsilon, data_format=data_format)
    if not renorm:
        return outputs, mean, variance

    moving_inv = math_ops.rsqrt(moving_variance + epsilon)
    batch_std = tf.sqrt(variance + epsilon)
    r = tf.stop_gradient(
        tf.clip_by_value(batch_std * moving_inv, 1/RMAX, RMAX))
    d = tf.stop_gradient(
        tf.clip_by_value((mean - moving_mean) * moving_inv, -DMAX, DMAX))
    return outputs * r + d, mean, variance
def _testLearning(self, use_gradient_checker):
    """Check training-mode fused_batch_norm on NHWC input.

    Args:
      use_gradient_checker: When true, also bound the numeric-vs-analytic
        gradient error of the output w.r.t. the input by 1e-3.
    """
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    data_format = "NHWC"
    epsilon = 0.001

    with self.test_session() as sess, self.test_scope():
        # Placeholders rather than constants, to avoid constant folding.
        t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        y, mean, var = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=None,
            variance=None,
            epsilon=epsilon,
            data_format=data_format,
            is_training=True)

        feed = {t_val: x_val, scale: scale_val, offset: offset_val}

        # Check gradient.
        if use_gradient_checker:
            err = gradient_checker.compute_gradient_error(
                t_val, x_shape, y, x_shape, extra_feed_dict=feed)
            self.assertLess(err, 1e-3)

        y_val, mean_val, var_val = sess.run([y, mean, var], feed)
        y_ref, mean_ref, var_ref = self._reference_training(
            x_val, scale_val, offset_val, epsilon, data_format)
        self.assertAllClose(mean_val, mean_ref, atol=1e-3)
        self.assertAllClose(y_val, y_ref, atol=1e-3)
        self.assertAllClose(var_val, var_ref, atol=1e-3)
def FProp(self, theta, inputs, paddings=None):
    """Apply batch normalization.

    Args:
      theta: A `.NestedMap` object containing weights' values of this layer
        and its children layers.
      inputs: The inputs tensor. Shaped [..., dim].
      paddings: The paddings tensor. Shaped [..., 1], with the same rank as
        the input tensor. Defaults to `self._GetDefaultPaddings(inputs)`.

    Returns:
      Output after applying batch normalization, with the same shape as
      'inputs'.
    """
    p = self.params
    if paddings is None:
        paddings = self._GetDefaultPaddings(inputs)

    with tf.name_scope(p.name):
        norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
            theta, inputs, paddings)

        # The variance must be non-negative and both statistics must match
        # the last dimension of the inputs.
        sanity_checks = [
            py_utils.assert_greater_equal(norm_variance,
                                          tf.zeros_like(norm_variance)),
            py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                        tf.shape(norm_mean)),
            py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                        tf.shape(norm_variance)),
        ]
        with tf.control_dependencies(sanity_checks):
            use_fused = p.use_fused_batch_norm_for_eval and self.do_eval
            if use_fused:
                bn_output, _, _ = nn.fused_batch_norm(
                    inputs,
                    gamma,
                    beta,
                    norm_mean,
                    norm_variance,
                    self._epsilon,
                    is_training=False)
            else:
                bn_output = tf.nn.batch_normalization(
                    inputs, norm_mean, norm_variance, beta, gamma,
                    self._epsilon)

            if p.set_padded_output_to_zero:
                bn_output = bn_output * (1.0 - paddings)

        return bn_output
def testInference(self, data_format):
    """Check inference-mode fused_batch_norm in the given data format.

    Reference values are computed on NHWC data by
    ``self._reference_training``; the input and expected output are converted
    to ``data_format`` before the op is run with ``is_training=False``.
    """
    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    exponential_avg_factor = 1.0
    data_format_src = "NHWC"
    reference = self._reference_training(
        x_val, scale_val, offset_val, None, None, epsilon,
        exponential_avg_factor, data_format_src)
    y_ref, mean_ref, var_ref, _ = reference

    with self.session() as sess, self.test_scope():
        x_val_converted = test_utils.ConvertBetweenDataFormats(
            x_val, data_format_src, data_format)
        y_ref_converted = test_utils.ConvertBetweenDataFormats(
            y_ref, data_format_src, data_format)

        # Placeholders rather than constants, to avoid constant folding.
        t_val = array_ops.placeholder(
            np.float32, shape=x_val_converted.shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        y, mean, variance = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=mean_ref,
            variance=var_ref,
            epsilon=epsilon,
            data_format=data_format,
            is_training=False)

        feed = {t_val: x_val_converted, scale: scale_val, offset: offset_val}
        y_val, _, _ = sess.run([y, mean, variance], feed)
        self.assertAllClose(y_val, y_ref_converted, atol=1e-3)
def testBatchNormalizeFused(self):
    """Run inference-mode fused batch norm on the IPU device with zero input.

    Checks that an all-zero input gives an all-zero output and that none of
    the compute sets extracted from the event trace match the blacklist
    pattern ``*convert*/Cast*``.
    """
    input_shape = [4, 64, 64, 4]
    x = array_ops.placeholder(np.float32, input_shape, name="a")

    with ops.device("/device:IPU:0"):
        with variable_scope.variable_scope("", use_resource=True):
            beta = variable_scope.get_variable(
                "x",
                dtype=np.float32,
                shape=[4],
                initializer=init_ops.constant_initializer(0.0))
            gamma = variable_scope.get_variable(
                "y",
                dtype=np.float32,
                shape=[4],
                initializer=init_ops.constant_initializer(1.0))

            b_mean, b_var = nn.moments(x, [0, 1, 2], name='moments')
            normed = nn.fused_batch_norm(
                x, gamma, beta, b_mean, b_var, is_training=False)

    with ops.device('cpu'):
        report = gen_ipu_ops.ipu_event_trace()

    tu.configure_ipu_system()

    with tu.ipu_session() as sess:
        # First fetch of the trace op; its result is discarded.
        sess.run(report)

        sess.run(variables.global_variables_initializer())
        result, _, _ = sess.run(normed, {x: np.zeros(input_shape)})
        self.assertAllClose(result, np.zeros(input_shape))

        trace = sess.run(report)
        events = tu.extract_all_strings_from_event_trace(trace)
        compute_sets = tu.get_compute_sets_from_report(events)

        blacklist = ['*convert*/Cast*']
        self.assertTrue(
            tu.check_compute_sets_not_in_blacklist(compute_sets, blacklist))
def call(self, inputs, training=None):
    """Apply batch normalization, optionally through the subdivision path.

    Args:
      inputs: Input tensor.
      training: Training flag, resolved through ``_get_training_value``.

    Returns:
      The parent layer's output when ``subdivisions`` is ``None`` or at most
      1. Otherwise the fused inference output when the layer is fused, not
      in renorm mode and ``training`` is exactly ``False``; in all remaining
      cases the result of ``_subdiv_batch_norm``.
    """
    training = self._get_training_value(training)

    # Test for None first: `None <= 1` raises TypeError on Python 3, so the
    # comparison must only run once subdivisions is known to be a number.
    if self.subdivisions is None or self.subdivisions <= 1:
        return super().call(inputs, training=training)

    if self.renorm is False and training is False and self.fused:
        # Inference does not depend on subdivisions; use the fused kernel
        # with the moving statistics directly.
        beta = self.beta if self.center else self._beta_const
        gamma = self.gamma if self.scale else self._gamma_const
        outputs, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            mean=self.moving_mean,
            variance=self.moving_variance,
            epsilon=self.epsilon,
            is_training=False,
            data_format=self._data_format)
        return outputs

    return self._subdiv_batch_norm(inputs, training=training)
def my_graph(a):
    """Build inference-mode fused batch norm of ``a`` with float16 parameters.

    Creates variables "x" (beta, initialized to 0) and "y" (gamma,
    initialized to 1), takes the moments of ``a`` over axes 0-2 and returns
    the full result of ``nn.fused_batch_norm`` with ``is_training=False``.
    """

    def _half_variable(var_name, init_value):
        return variable_scope.get_variable(
            var_name,
            dtype=np.float16,
            shape=[4],
            initializer=init_ops.constant_initializer(init_value))

    beta = _half_variable("x", 0.0)
    gamma = _half_variable("y", 1.0)
    batch_mean, batch_var = nn.moments(a, [0, 1, 2], name='moments')
    return nn.fused_batch_norm(
        a, gamma, beta, batch_mean, batch_var, is_training=False)
def call(self, inputs):
    """Apply layer normalization over ``self.axis``.

    The non-fused path computes moments over ``self.axis`` and calls
    ``nn.batch_normalization``, upcasting float16/bfloat16 inputs to float32
    when the layer dtype is float32. The fused path reshapes the input to
    ``[1, pre_dim, in_dim, 1]``, runs ``nn.fused_batch_norm`` with unit scale
    and zero offset, and applies gamma and beta afterwards.

    Returns:
      The normalized tensor, with its static shape set to the input's.
    """
    input_shape = inputs.shape
    ndims = len(input_shape)

    # Broadcasting only necessary for norm where the axis is not just
    # the last dimension
    broadcast_shape = [1] * ndims
    for dim in self.axis:
        broadcast_shape[dim] = input_shape.dims[dim].value

    def _broadcast(param):
        needs_reshape = (
            param is not None and
            len(param.shape) != ndims and
            self.axis != [ndims - 1])
        if needs_reshape:
            return array_ops.reshape(param, broadcast_shape)
        return param

    if self._fused:
        # Collapse dims before self.axis, and dims in self.axis
        outer_size, inner_size = (1, 1)
        sorted_axes = sorted(self.axis)
        tensor_shape = array_ops.shape(inputs)
        for dim in range(0, ndims):
            extent = tensor_shape[dim]
            if dim < sorted_axes[0]:
                outer_size = outer_size * extent
            else:
                assert dim in sorted_axes
                inner_size = inner_size * extent

        squeezed_shape = [1, outer_size, inner_size, 1]
        # This fused operation requires reshaped inputs to be NCHW.
        data_format = 'NCHW'

        inputs = array_ops.reshape(inputs, squeezed_shape)

        def _filled(value, dtype, shape):
            return array_ops.fill(
                shape, constant_op.constant(value, dtype=dtype))

        # self.gamma and self.beta have the wrong shape for fused_batch_norm,
        # so we cannot pass them as the scale and offset parameters.
        # Therefore, we create two constant tensors in correct shapes for
        # fused_batch_norm and later construct a separate calculation on the
        # scale and offset.
        unit_scale = _filled(1.0, self.dtype, [outer_size])
        zero_offset = _filled(0.0, self.dtype, [outer_size])

        # Compute layer normalization using the fused_batch_norm function.
        outputs, _, _ = nn.fused_batch_norm(
            inputs,
            scale=unit_scale,
            offset=zero_offset,
            epsilon=self.epsilon,
            data_format=data_format)

        outputs = array_ops.reshape(outputs, tensor_shape)

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        if scale is not None:
            outputs = outputs * math_ops.cast(scale, outputs.dtype)
        if offset is not None:
            outputs = outputs + math_ops.cast(offset, outputs.dtype)
    else:
        input_dtype = inputs.dtype
        if input_dtype in ('float16', 'bfloat16') and self.dtype == 'float32':
            # If mixed precision is used, cast inputs to float32 so that this
            # is at least as numerically stable as the fused version.
            inputs = math_ops.cast(inputs, 'float32')

        # Calculate the moments on the last axis (layer activations).
        mean, variance = nn.moments(inputs, self.axis, keep_dims=True)

        scale, offset = _broadcast(self.gamma), _broadcast(self.beta)

        # Compute layer normalization using the batch_normalization function.
        outputs = nn.batch_normalization(
            inputs,
            mean,
            variance,
            offset=offset,
            scale=scale,
            variance_epsilon=self.epsilon)
        outputs = math_ops.cast(outputs, input_dtype)

    # If some components of the shape got lost due to adjustments, fix that.
    outputs.set_shape(input_shape)

    return outputs
def fused_instance_norm(inputs,
                        center=True,
                        scale=True,
                        epsilon=1e-6,
                        activation_fn=None,
                        param_initializers=None,
                        reuse=None,
                        variables_collections=None,
                        outputs_collections=None,
                        trainable=True,
                        data_format=DATA_FORMAT_NHWC,
                        scope=None):
    """Functional interface for the instance normalization layer.

    Reference: https://arxiv.org/abs/1607.08022.

      "Instance Normalization: The Missing Ingredient for Fast Stylization"
      Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky

    The implementation moves batch and channels to the last two axes, folds
    them into a single channel axis of size `batch * channels` and runs
    `nn.fused_batch_norm` on the result, so each (sample, channel) pair is
    normalized over its spatial positions.

    Args:
      inputs: A tensor with 2 or more dimensions, where the first dimension
        has `batch_size`. The channels are the last dimension if
        `data_format` is `NHWC` and the second dimension if `data_format` is
        `NCHW`. Batch size and channel count must be statically known.
        NOTE(review): spatial sizes are read from the static shape too, so
        unknown spatial dimensions are presumably unsupported - confirm.
      center: If True, add offset of `beta` to normalized tensor. If False,
        `beta` is ignored.
      scale: If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`), this can be
        disabled since the scaling can be done by the next layer.
      epsilon: Small float added to variance to avoid dividing by zero.
      activation_fn: Activation function, default set to None to skip it and
        maintain a linear activation.
      param_initializers: Optional dict of initializers; the keys 'beta' and
        'gamma' are looked up.
      reuse: Whether or not the layer and its variables should be reused. To
        be able to reuse the layer scope must be given.
      variables_collections: Optional collections for the variables.
      outputs_collections: Collections to add the outputs.
      trainable: If `True` also add variables to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
      data_format: A string. `NHWC` (default) and `NCHW` are supported.
      scope: Optional scope for `variable_scope`.

    Returns:
      A `Tensor` representing the output of the operation.

    Raises:
      ValueError: If `data_format` is neither `NHWC` nor `NCHW`.
      ValueError: If the rank of `inputs` is undefined.
      ValueError: If the batch or channels dimension of `inputs` is
        undefined.
    """
    inputs = ops.convert_to_tensor(inputs)
    inputs_shape = inputs.shape
    inputs_rank = inputs.shape.ndims

    if inputs_rank is None:
        raise ValueError('Inputs %s has undefined rank.' % inputs.name)
    if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC):
        raise ValueError('data_format has to be either NCHW or NHWC.')

    with variable_scope.variable_scope(
            scope, 'InstanceNorm', [inputs], reuse=reuse) as sc:
        if data_format == DATA_FORMAT_NCHW:
            reduction_axis = 1
        else:
            reduction_axis = inputs_rank - 1

        params_shape = inputs_shape[reduction_axis:reduction_axis + 1]
        if not params_shape.is_fully_defined():
            raise ValueError('Inputs %s has undefined channels dimension %s.' %
                             (inputs.name, params_shape))
        # The batch size is folded into the channel axis below, so it has to
        # be known when the graph is built.
        if inputs_shape[0].value is None:
            raise ValueError('Inputs %s has undefined batch dimension.' %
                             inputs.name)

        # Allocate parameters for the beta and gamma of the normalization.
        # Both stay 1-D ([channels]) so they can be tiled per sample below.
        beta, gamma = None, None
        dtype = inputs.dtype.base_dtype
        if param_initializers is None:
            param_initializers = {}
        if center:
            beta_collections = utils.get_variable_collections(
                variables_collections, 'beta')
            beta_initializer = param_initializers.get(
                'beta', init_ops.zeros_initializer())
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=dtype,
                initializer=beta_initializer,
                collections=beta_collections,
                trainable=trainable)
        if scale:
            gamma_collections = utils.get_variable_collections(
                variables_collections, 'gamma')
            gamma_initializer = param_initializers.get(
                'gamma', init_ops.ones_initializer())
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=dtype,
                initializer=gamma_initializer,
                collections=gamma_collections,
                trainable=trainable)

        # Bring the tensor to [spatial..., batch, channels].
        if data_format == DATA_FORMAT_NHWC:
            to_spatial_first = (
                list(range(1, reduction_axis)) + [0, reduction_axis])
        else:
            to_spatial_first = (
                list(range(2, inputs_rank)) + [0, reduction_axis])
        inputs = array_ops.transpose(inputs, to_spatial_first)

        static_shape = inputs.shape.as_list()
        hw, n, c = static_shape[:-2], static_shape[-2], static_shape[-1]

        # Fold batch and channels into one axis and make the tensor 4-D.
        inputs = array_ops.reshape(inputs, [1] + hw + [n * c])
        if inputs.shape.ndims != 4:
            # Combine all the spatial dimensions into only two,
            # e.g. [D, H, W] -> [D, HW].
            if inputs.shape.ndims > 4:
                inputs_ndims4_shape = [1, hw[0], -1, n * c]
            else:
                inputs_ndims4_shape = [1, 1, -1, n * c]
            inputs = array_ops.reshape(inputs, inputs_ndims4_shape)

        # One offset/scale entry per (sample, channel) pair. A disabled
        # parameter is replaced by its neutral value because the fused op
        # always needs both tensors.
        if beta is None:
            tiled_beta = array_ops.zeros([n * c], dtype=dtype)
        else:
            tiled_beta = array_ops.reshape(
                array_ops.tile(beta[None, :], [n, 1]), [-1])
        if gamma is None:
            tiled_gamma = array_ops.ones([n * c], dtype=dtype)
        else:
            tiled_gamma = array_ops.reshape(
                array_ops.tile(gamma[None, :], [n, 1]), [-1])

        outputs, _, _ = nn.fused_batch_norm(
            inputs,
            tiled_gamma,
            tiled_beta,
            epsilon=epsilon,
            data_format=DATA_FORMAT_NHWC,
            name='instancenorm')

        # Undo the folding and restore the caller's layout.
        outputs = array_ops.reshape(outputs, hw + [n, c])
        if data_format == DATA_FORMAT_NHWC:
            restore_layout = ([inputs_rank - 2] +
                              list(range(inputs_rank - 2)) +
                              [inputs_rank - 1])
        else:
            restore_layout = ([inputs_rank - 2, inputs_rank - 1] +
                              list(range(inputs_rank - 2)))
        outputs = array_ops.transpose(outputs, restore_layout)

        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)
def _testLearning(self, use_gradient_checker, data_format,
                  exponential_avg_factor):
    """Check training-mode `nn.fused_batch_norm` against the reference.

    Random float32 inputs are generated in "NHWC" layout, the expected
    results are obtained from `self._reference_training`, and the same data
    (converted to `data_format`) is run through `nn.fused_batch_norm` with
    `is_training=True`. Output, mean and variance must match the reference
    within an absolute tolerance of 1e-3.

    Args:
      use_gradient_checker: when truthy, additionally require the error
        reported by `gradient_checker.compute_gradient_error` (output with
        respect to the input placeholder) to be below 1e-3.
      data_format: layout the op is run with; the input and the reference
        output are converted from "NHWC" to this layout.
      exponential_avg_factor: forwarded to both the reference and the op.
        When it equals 1.0, no previous mean/variance is passed to the op.
    """

    def _random_f32(shape):
        # Samples from np.random, cast to float32.
        return np.random.random_sample(shape).astype(np.float32)

    num_channels = 3
    param_shape = [num_channels]
    x_np = _random_f32([2, 2, 6, num_channels])
    scale_np = _random_f32(param_shape)
    offset_np = _random_f32(param_shape)
    prev_mean_np = _random_f32(param_shape)
    prev_var_corr_np = _random_f32(param_shape)
    eps = 0.001
    src_format = "NHWC"

    # When in training mode, fused_batchnorm applies an implicit Bessel's
    # correction. So we have to use the corrected variance here, as well.
    expected_y, expected_mean, _, expected_var_corr = (
        self._reference_training(
            x_np, scale_np, offset_np, prev_mean_np, prev_var_corr_np, eps,
            exponential_avg_factor, src_format))

    with self.session() as sess, self.test_scope():
        x_in = test_utils.ConvertBetweenDataFormats(
            x_np, src_format, data_format)
        expected_y_in = test_utils.ConvertBetweenDataFormats(
            expected_y, src_format, data_format)

        # Everything is fed through placeholders to avoid constant folding.
        x_ph = array_ops.placeholder(
            np.float32, shape=x_in.shape, name="x")
        scale_ph = array_ops.placeholder(
            np.float32, shape=param_shape, name="scale")
        offset_ph = array_ops.placeholder(
            np.float32, shape=param_shape, name="offset")
        feed = {x_ph: x_in, scale_ph: scale_np, offset_ph: offset_np}

        if exponential_avg_factor == 1.0:
            # No running statistics are supplied to the op in this case.
            prev_mean_ph = None
            prev_var_ph = None
        else:
            prev_mean_ph = array_ops.placeholder(
                np.float32, shape=param_shape, name="old_mean")
            prev_var_ph = array_ops.placeholder(
                np.float32, shape=param_shape, name="old_var")
            feed[prev_mean_ph] = prev_mean_np
            feed[prev_var_ph] = prev_var_corr_np

        y, mean, var = nn.fused_batch_norm(
            x_ph,
            scale_ph,
            offset_ph,
            mean=prev_mean_ph,
            variance=prev_var_ph,
            epsilon=eps,
            exponential_avg_factor=exponential_avg_factor,
            data_format=data_format,
            is_training=True)

        # Check gradient.
        if use_gradient_checker:
            grad_err = gradient_checker.compute_gradient_error(
                x_ph, x_in.shape, y, x_in.shape, extra_feed_dict=feed)
            self.assertLess(grad_err, 1e-3)

        actual_y, actual_mean, actual_var = sess.run([y, mean, var], feed)
        self.assertAllClose(actual_y, expected_y_in, atol=1e-3)
        self.assertAllClose(actual_mean, expected_mean, atol=1e-3)
        self.assertAllClose(actual_var, expected_var_corr, atol=1e-3)
def _testLearning(self, use_gradient_checker, data_format):
    """Check training-mode `nn.fused_batch_norm` against the reference.

    Random float32 inputs are generated in "NHWC" layout, the expected
    results are obtained from `self._reference_training`, and the same data
    (converted to `data_format`) is run through `nn.fused_batch_norm` with
    `is_training=True` and no incoming mean/variance. Output, mean and
    variance must match the reference within an absolute tolerance of 1e-3.

    Args:
      use_gradient_checker: when truthy, additionally require the error
        reported by `gradient_checker.compute_gradient_error` (output with
        respect to the input placeholder) to be below 1e-3.
      data_format: layout the op is run with; the input and the reference
        output are converted from "NHWC" to this layout.
    """
    # Skip before doing any work: nothing below is needed for a skipped test.
    # TODO(b/110530713): Support data format HWCN on GPU
    if self.device == "XLA_GPU" and data_format == "HWCN":
        self.skipTest("GPU does not support data format HWCN.")

    channel = 3
    x_shape = [2, 2, 6, channel]
    scale_shape = [channel]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    scale_val = np.random.random_sample(scale_shape).astype(np.float32)
    offset_val = np.random.random_sample(scale_shape).astype(np.float32)
    epsilon = 0.001
    data_format_src = "NHWC"
    y_ref, mean_ref, var_ref = self._reference_training(
        x_val, scale_val, offset_val, epsilon, data_format_src)

    with self.test_session() as sess, self.test_scope():
        x_val_converted = test_utils.ConvertBetweenDataFormats(
            x_val, data_format_src, data_format)
        y_ref_converted = test_utils.ConvertBetweenDataFormats(
            y_ref, data_format_src, data_format)

        # Inputs are fed through placeholders to avoid constant folding.
        t_val = array_ops.placeholder(
            np.float32, shape=x_val_converted.shape, name="x")
        scale = array_ops.placeholder(
            np.float32, shape=scale_shape, name="scale")
        offset = array_ops.placeholder(
            np.float32, shape=scale_shape, name="offset")
        # Single feed dict shared by the gradient check and the forward run.
        feed_dict = {
            t_val: x_val_converted,
            scale: scale_val,
            offset: offset_val,
        }

        y, mean, var = nn.fused_batch_norm(
            t_val,
            scale,
            offset,
            mean=None,
            variance=None,
            epsilon=epsilon,
            data_format=data_format,
            is_training=True)

        # Check gradient.
        if use_gradient_checker:
            err = gradient_checker.compute_gradient_error(
                t_val,
                x_val_converted.shape,
                y,
                x_val_converted.shape,
                extra_feed_dict=feed_dict)
            self.assertLess(err, 1e-3)

        # Distinct names for the fetched values so they are not confused
        # with the NumPy inputs generated above.
        y_tf, mean_tf, var_tf = sess.run([y, mean, var], feed_dict)
        self.assertAllClose(mean_tf, mean_ref, atol=1e-3)
        self.assertAllClose(y_tf, y_ref_converted, atol=1e-3)
        self.assertAllClose(var_tf, var_ref, atol=1e-3)
def fused_layer_norm(inputs,
                     center=True,
                     scale=True,
                     activation_fn=None,
                     reuse=None,
                     variables_collections=None,
                     outputs_collections=None,
                     trainable=True,
                     begin_norm_axis=1,
                     begin_params_axis=-1,
                     scope=None,
                     use_fused_batch_norm=False):
    """Layer normalization with an optional `nn.fused_batch_norm` path.

    Args:
      inputs: value convertible to a tensor with a statically known rank.
      center: if truthy, create a zero-initialized `beta` variable and add
        it to the normalized result.
      scale: if truthy, create a one-initialized `gamma` variable and
        multiply the normalized result by it.
      activation_fn: optional callable applied to the result.
      reuse: forwarded to `tf.compat.v1.variable_scope`.
      variables_collections: passed to `utils.get_variable_collections` to
        pick the collections for `beta` and `gamma`.
      outputs_collections: passed to `utils.collect_named_outputs`.
      trainable: forwarded to `variables.model_variable`.
      begin_norm_axis: first axis that is normalized over; a negative value
        is offset by the input rank.
      begin_params_axis: `beta`/`gamma` take the shape
        `inputs.shape[begin_params_axis:]`.
      scope: variable scope name; the default name is 'LayerNorm'.
      use_fused_batch_norm: if truthy, normalize by reshaping to 4-D and
        calling `nn.fused_batch_norm`; otherwise use `nn.moments` followed
        by `nn.batch_normalization`.

    Returns:
      Whatever `utils.collect_named_outputs` returns for the result tensor.

    Raises:
      ValueError: if the input rank is undefined, if either axis argument
        is >= the input rank, or if `inputs.shape[begin_params_axis:]` is
        not fully defined.
    """
    with tf.compat.v1.variable_scope(
            scope, 'LayerNorm', [inputs], reuse=reuse) as sc:
        inputs = ops.convert_to_tensor(inputs)
        static_shape = inputs.shape
        rank = static_shape.ndims
        if rank is None:
            raise ValueError('Inputs %s has undefined rank.' % inputs.name)
        param_dtype = inputs.dtype.base_dtype

        if begin_norm_axis < 0:
            begin_norm_axis = rank + begin_norm_axis
        if begin_params_axis >= rank or begin_norm_axis >= rank:
            raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) '
                             'must be < rank(inputs) (%d)' %
                             (begin_params_axis, begin_norm_axis, rank))

        params_shape = static_shape[begin_params_axis:]
        if not params_shape.is_fully_defined():
            raise ValueError(
                'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' %
                (inputs.name, begin_params_axis, static_shape))

        # Allocate the learnable offset (beta) and scale (gamma).
        beta = None
        gamma = None
        if center:
            beta = variables.model_variable(
                'beta',
                shape=params_shape,
                dtype=param_dtype,
                initializer=init_ops.zeros_initializer(),
                collections=utils.get_variable_collections(
                    variables_collections, 'beta'),
                trainable=trainable)
        if scale:
            gamma = variables.model_variable(
                'gamma',
                shape=params_shape,
                dtype=param_dtype,
                initializer=init_ops.ones_initializer(),
                collections=utils.get_variable_collections(
                    variables_collections, 'gamma'),
                trainable=trainable)

        # Both code paths use the same epsilon.
        norm_epsilon = 1e-4

        if use_fused_batch_norm:
            # Size of the flattened normalized axes: a static number when
            # those dims are fully defined, otherwise a tensor.
            norm_dims = inputs.shape[begin_norm_axis:]
            if norm_dims.is_fully_defined():
                flat_norm_size = numpy.prod(norm_dims.as_list())
            else:
                norm_dims = tf.shape(input=inputs)[begin_norm_axis:]
                flat_norm_size = tf.reduce_prod(input_tensor=norm_dims)
            bn_shape = [1, -1, 1, flat_norm_size]

            # Remember how to restore the original shape afterwards.
            if inputs.get_shape().is_fully_defined():
                restore_shape = inputs.get_shape()
            else:
                restore_shape = tf.shape(input=inputs)

            bn_inputs = array_ops.reshape(inputs, bn_shape)

            # Axis 1 of the reshaped tensor is what "NCHW" treats as the
            # channel axis; the unit scale / zero offset are sized to it.
            # When the static shape is incomplete, that size has to come
            # from a shape tensor instead.
            if bn_inputs.get_shape().is_fully_defined():
                num_channels = bn_inputs.get_shape()[1]
            else:
                num_channels = tf.shape(input=bn_inputs)[1]
            unit_scale = array_ops.ones(num_channels, dtype=dtypes.float32)
            zero_offset = array_ops.zeros(num_channels, dtype=dtypes.float32)

            outputs, _, _ = nn.fused_batch_norm(
                bn_inputs,
                unit_scale,
                zero_offset,
                epsilon=norm_epsilon,
                data_format="NCHW")
            outputs = array_ops.reshape(outputs, restore_shape)

            # gamma/beta are applied here because the fused op was run with
            # a unit scale and zero offset.
            if scale:
                outputs = outputs * gamma
            if center:
                outputs = outputs + beta
        else:
            # Moments over axes [begin_norm_axis, rank), then the generic
            # batch_normalization formula with beta/gamma.
            norm_axes = list(range(begin_norm_axis, rank))
            mean, variance = nn.moments(inputs, norm_axes, keep_dims=True)
            outputs = nn.batch_normalization(
                inputs,
                mean,
                variance,
                offset=beta,
                scale=gamma,
                variance_epsilon=norm_epsilon)

        outputs.set_shape(static_shape)
        if activation_fn is not None:
            outputs = activation_fn(outputs)
        return utils.collect_named_outputs(outputs_collections, sc.name,
                                           outputs)