Example #1
    def test_training_batchnorm(self):
        """ Test BNUtils get_training() with both fused and non fused batchnorms, with all three training modes """

        tf.compat.v1.reset_default_graph()

        # Model with fused batchnorms
        _ = keras_model_functional()
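        # Case 1: BN created with training=True; get_training() should return the Python bool True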
        fused_bn_training_true_op = tf.compat.v1.get_default_graph().get_operation_by_name('batch_normalization/FusedBatchNormV3')
        self.assertTrue(BNUtils.get_training(fused_bn_training_true_op))
        self.assertTrue(isinstance(BNUtils.get_training(fused_bn_training_true_op), bool))

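        # Case 2: BN driven by a training placeholder tensor; get_training() should return that tensor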
        fused_bn_training_tensor_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_1/cond/FusedBatchNormV3_1')
        training_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name('is_training:0')
        self.assertEqual(BNUtils.get_training(fused_bn_training_tensor_op), training_tensor)

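        # Case 3: BN created with training=False; get_training() should return the Python bool False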
        fused_bn_training_false_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_2/FusedBatchNormV3')
        self.assertFalse(BNUtils.get_training(fused_bn_training_false_op))

        tf.compat.v1.reset_default_graph()

        # Model with non fused batchnorms
        _ = keras_model_functional_with_non_fused_batchnorms()
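        # Same three cases, exercised against the non-fused 'batchnorm/mul_1' op pattern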
        bn_training_true_op = tf.compat.v1.get_default_graph().get_operation_by_name('batch_normalization/batchnorm/mul_1')
        self.assertTrue(BNUtils.get_training(bn_training_true_op))
        self.assertTrue(isinstance(BNUtils.get_training(bn_training_true_op), bool))

        bn_training_tensor_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_1/batchnorm/mul_1')
        training_tensor = tf.compat.v1.get_default_graph().get_tensor_by_name('is_training:0')
        self.assertEqual(BNUtils.get_training(bn_training_tensor_op), training_tensor)

        bn_training_false_op = tf.compat.v1.get_default_graph().get_operation_by_name(
            'scope_1/batch_normalization_2/batchnorm/mul_1')
        self.assertFalse(BNUtils.get_training(bn_training_false_op))

        tf.compat.v1.reset_default_graph()
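
For reference, here is a minimal sketch of how a batchnorm ends up in the placeholder-driven training mode that Case 2 above resolves by name. It is illustrative only, assuming TF 1.x-style graph mode; the model and placeholder are stand-ins, and only BNUtils.get_training() comes from the snippet above.

import tensorflow as tf

tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()
inputs = tf.keras.Input(shape=(32, 32, 3))
# A scalar placeholder named 'is_training' yields the tensor name 'is_training:0'
is_training = tf.compat.v1.placeholder_with_default(tf.constant(True), shape=(), name='is_training')
# Passing a tensor (rather than a Python bool) as `training` makes Keras wrap the
# fused kernel in a cond, producing ops like '.../cond/FusedBatchNormV3_1';
# BNUtils.get_training() recovers the placeholder from that pattern.
_ = tf.keras.layers.BatchNormalization(fused=True)(inputs, training=is_training)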
Example #2
# Imports assumed by this snippet; Op, BNUtils and _get_reduced_params come from
# the surrounding (AIMET-style) winnowing module and are not shown here.
from typing import List, Tuple

import tensorflow as tf


def reduce_batchnorm(sess: tf.compat.v1.Session,
                     op_tensor_tuple: Tuple[Op, List[tf.Tensor]], op_mask) -> Tuple[str, tf.Operation, tf.Operation]:
    """
    Fused and non-fused batchnorm module reducer
    :param sess: current tf.compat.v1.Session
    :param op_tensor_tuple: tuple containing the op to reduce, and a list of input tensors to the op
    :param op_mask: mask containing information on input and output channels to winnow
    :return: tuple of (name of the reduced op, the op producing the reduced output, and the new batchnorm module op)
    """

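    # Reduce beta (shift) only when the original BN actually had one; otherwise keep the default 'zeros' initializer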
    beta_product = op_tensor_tuple[0].get_param_product('beta')
    if beta_product:
        use_beta = True
        # verify_shape exists only on the tf.compat.v1 initializer, so use it explicitly
        reduced_beta_init = tf.compat.v1.constant_initializer(
            _get_reduced_params(sess=sess, product=beta_product, mask=op_mask, input_dim=0, output_dim=None),
            verify_shape=True)
    else:
        use_beta = False
        reduced_beta_init = 'zeros'

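    # Same treatment for gamma (scale); the default initializer is 'ones' when it is absent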
    gamma_product = op_tensor_tuple[0].get_param_product('gamma')
    if gamma_product:
        use_gamma = True
        reduced_gamma_init = tf.compat.v1.constant_initializer(
            _get_reduced_params(sess=sess, product=gamma_product, mask=op_mask, input_dim=0, output_dim=None),
            verify_shape=True)
    else:
        use_gamma = False
        reduced_gamma_init = 'ones'

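    # moving_mean and moving_variance always exist on a batchnorm, so reduce them unconditionally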
    moving_mean_product = op_tensor_tuple[0].get_param_product('moving_mean')
    reduced_mov_mean_init = tf.compat.v1.constant_initializer(
        _get_reduced_params(sess=sess, product=moving_mean_product, mask=op_mask, input_dim=0, output_dim=None),
        verify_shape=True)
    moving_variance_product = op_tensor_tuple[0].get_param_product('moving_variance')
    reduced_mov_variance_init = tf.compat.v1.constant_initializer(
        _get_reduced_params(sess=sess, product=moving_variance_product, mask=op_mask, input_dim=0, output_dim=None),
        verify_shape=True)

    name = "reduced_" + op_tensor_tuple[0].dotted_name
    # Get training attribute
    # This will either be True, False, or a tensor referencing the training placeholder the original BN was using
    training = BNUtils.get_training(op_tensor_tuple[0].get_module())
    assert training is not None
    is_fused = op_tensor_tuple[0].type == 'FusedBatchNormV3'
    epsilon = BNUtils.get_epsilon(op_tensor_tuple[0].get_module())
    momentum = BNUtils.get_momentum(op_tensor_tuple[0].get_module())
    # Build the constructor arguments once; momentum is only passed through when the
    # original op defined it, so Keras' default momentum applies otherwise
    bn_kwargs = dict(center=use_beta,
                     scale=use_gamma,
                     epsilon=epsilon,
                     beta_initializer=reduced_beta_init,
                     gamma_initializer=reduced_gamma_init,
                     moving_mean_initializer=reduced_mov_mean_init,
                     moving_variance_initializer=reduced_mov_variance_init,
                     fused=is_fused,
                     name=name)
    if momentum is not None:
        bn_kwargs['momentum'] = momentum
    new_tensor = tf.keras.layers.BatchNormalization(**bn_kwargs)(op_tensor_tuple[1][0], training=training)
    module = new_tensor.op.inputs[0].op

    return name, new_tensor.op, module
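
For intuition, the per-parameter reduction that a helper like _get_reduced_params is used for here boils down to indexing each parameter with a boolean keep-mask along the channel axis. The helper below is a hypothetical stand-in for illustration, not the actual implementation:

import numpy as np

def reduce_channels(param: np.ndarray, keep_mask) -> np.ndarray:
    # Keep only the channels whose mask entry is True (hypothetical helper)
    return param[np.asarray(keep_mask, dtype=bool)]

gamma = np.array([1.0, 0.5, 2.0, 0.8])
print(reduce_channels(gamma, [True, False, True, True]))  # -> [1.  2.  0.8]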