def test_training_batchnorm(self):
    """
    Verify BNUtils.get_training() on both fused and non fused batchnorms,
    covering all three training configurations: training=True (python bool),
    training driven by a placeholder tensor, and training=False.
    """
    tf.compat.v1.reset_default_graph()

    # ---- Model with fused batchnorms ----
    _ = keras_model_functional()
    graph = tf.compat.v1.get_default_graph()

    # training=True: get_training() must return an actual python bool True
    fused_true_op = graph.get_operation_by_name('batch_normalization/FusedBatchNormV3')
    self.assertTrue(BNUtils.get_training(fused_true_op))
    self.assertTrue(isinstance(BNUtils.get_training(fused_true_op), bool))

    # training fed by a placeholder: get_training() must return that tensor
    fused_tensor_op = graph.get_operation_by_name('scope_1/batch_normalization_1/cond/'
                                                  'FusedBatchNormV3_1')
    training_tensor = graph.get_tensor_by_name('is_training:0')
    self.assertEqual(BNUtils.get_training(fused_tensor_op), training_tensor)

    # training=False: get_training() must return False
    fused_false_op = graph.get_operation_by_name('scope_1/batch_normalization_2/'
                                                 'FusedBatchNormV3')
    self.assertFalse(BNUtils.get_training(fused_false_op))

    tf.compat.v1.reset_default_graph()

    # ---- Model with non fused batchnorms ----
    _ = keras_model_functional_with_non_fused_batchnorms()
    graph = tf.compat.v1.get_default_graph()

    # training=True case
    bn_true_op = graph.get_operation_by_name('batch_normalization/batchnorm/mul_1')
    self.assertTrue(BNUtils.get_training(bn_true_op))
    self.assertTrue(isinstance(BNUtils.get_training(bn_true_op), bool))

    # training-as-tensor case
    bn_tensor_op = graph.get_operation_by_name('scope_1/batch_normalization_1/batchnorm/'
                                               'mul_1')
    training_tensor = graph.get_tensor_by_name('is_training:0')
    self.assertEqual(BNUtils.get_training(bn_tensor_op), training_tensor)

    # training=False case
    bn_false_op = graph.get_operation_by_name('scope_1/batch_normalization_2/batchnorm/'
                                              'mul_1')
    self.assertFalse(BNUtils.get_training(bn_false_op))

    tf.compat.v1.reset_default_graph()
def reduce_batchnorm(sess: tf.compat.v1.Session, op_tensor_tuple: Tuple[Op, List[tf.Tensor]],
                     op_mask) -> (str, tf.Operation, tf.Operation):
    """
    Fused and non fused batchnorm module reducer.

    :param sess: current tf.compat.v1.Session
    :param op_tensor_tuple: tuple containing the op to reduce, and a list of input tensors to the op
    :param op_mask: Mask containing information on input and output channels to winnow
    :return: tuple of (name of the reduced op, output op, module op of the reduced batchnorm)
    """
    bn_op = op_tensor_tuple[0]

    def _reduced_init(product):
        # Constant initializer holding the winnowed values of one BN parameter.
        return tf.constant_initializer(_get_reduced_params(sess=sess, product=product,
                                                           mask=op_mask, input_dim=0,
                                                           output_dim=None),
                                       verify_shape=True)

    # beta/gamma are optional; fall back to the keras defaults when absent.
    beta_product = bn_op.get_param_product('beta')
    use_beta = bool(beta_product)
    reduced_beta_init = _reduced_init(beta_product) if beta_product else 'zeros'

    gamma_product = bn_op.get_param_product('gamma')
    use_gamma = bool(gamma_product)
    reduced_gamma_init = _reduced_init(gamma_product) if gamma_product else 'ones'

    # moving mean/variance are always present on a batchnorm.
    reduced_mov_mean_init = _reduced_init(bn_op.get_param_product('moving_mean'))
    reduced_mov_variance_init = _reduced_init(bn_op.get_param_product('moving_variance'))

    name = "reduced_" + bn_op.dotted_name
    old_module = bn_op.get_module()

    # Get training attribute
    # This will either be True, False, or a string representing a training_placeholder
    # the original BN was using
    training = BNUtils.get_training(old_module)
    assert training is not None

    # Assemble the new layer's configuration; momentum is only forwarded when
    # the original BN exposed one, otherwise the keras default applies.
    bn_kwargs = dict(center=use_beta,
                     scale=use_gamma,
                     epsilon=BNUtils.get_epsilon(old_module),
                     beta_initializer=reduced_beta_init,
                     gamma_initializer=reduced_gamma_init,
                     moving_mean_initializer=reduced_mov_mean_init,
                     moving_variance_initializer=reduced_mov_variance_init,
                     fused=bn_op.type == 'FusedBatchNormV3',
                     name=name)
    momentum = BNUtils.get_momentum(old_module)
    if momentum is not None:
        bn_kwargs['momentum'] = momentum

    new_tensor = tf.keras.layers.BatchNormalization(**bn_kwargs)(op_tensor_tuple[1][0],
                                                                 training=training)
    module = new_tensor.op.inputs[0].op
    return name, new_tensor.op, module