예제 #1
0
def _get_bn_params(sess: tf.compat.v1.Session,
                   bn: tf.Operation) -> libpymo.BNParams:
    """
    helper to populate BN params from given BN op, required for fold

    NOTE: despite its name, the ``runningVar`` field of the returned object is filled with
    sigma = sqrt(moving_variance + epsilon), not with the raw moving variance.

    :param sess: tf.compat.v1.Session type
    :param bn: BatchNorm or a FusedBatch Norm op
    :return: bn_params, a libpymo.BNParams whose beta, gamma, runningMean and runningVar
             fields are set from the given op (each value flattened with reshape(-1))
    """
    # make sure you define the session and graph scope before loading vars from graph.
    with sess.graph.as_default():
        # create BNParams type and populate
        bn_params = libpymo.BNParams()
        bn_params.beta = BNUtils.get_beta_as_numpy_data(sess, bn).reshape(-1)
        bn_params.gamma = BNUtils.get_gamma_as_numpy_data(sess, bn).reshape(-1)
        bn_params.runningMean = BNUtils.get_moving_mean_as_numpy_data(
            sess, bn).reshape(-1)

        # Read the moving variance once and store sigma directly; assigning the raw variance
        # first would only be overwritten and would cost a second read from the session.
        epsilon = BNUtils.get_epsilon(bn)
        moving_variance = BNUtils.get_moving_variance_as_numpy_data(
            sess, bn).reshape(-1)
        bn_params.runningVar = np.sqrt(moving_variance + epsilon)

    return bn_params
예제 #2
0
def reduce_batchnorm(sess: tf.compat.v1.Session,
                     op_tensor_tuple: Tuple[Op, List[tf.Tensor]], op_mask) -> Tuple[str, tf.Operation, tf.Operation]:
    """
    Fused and non fused batchnorm module reducer
    :param sess: current tf.compat.v1.Session
    :param op_tensor_tuple: tuple containing the op to reduce, and a list of input tensors to the op
    :param op_mask: Mask containing information on input and output channels to winnow
    :return: tuple of (name given to the reduced batchnorm layer, op producing the new layer's output tensor,
             op producing the first input of that output op)
    """

    def _reduced_initializer(product):
        """ Build a constant initializer holding the reduced values of the given param product """
        reduced_values = _get_reduced_params(sess=sess,
                                             product=product,
                                             mask=op_mask,
                                             input_dim=0,
                                             output_dim=None)
        return tf.constant_initializer(reduced_values, verify_shape=True)

    bn_op = op_tensor_tuple[0]

    # beta and gamma are optional: when the op has no such param product, the corresponding
    # layer feature is disabled and a default initializer name is passed instead.
    beta_product = bn_op.get_param_product('beta')
    use_beta = bool(beta_product)
    reduced_beta_init = _reduced_initializer(beta_product) if use_beta else 'zeros'

    gamma_product = bn_op.get_param_product('gamma')
    use_gamma = bool(gamma_product)
    reduced_gamma_init = _reduced_initializer(gamma_product) if use_gamma else 'ones'

    # moving mean and moving variance are always reduced
    reduced_mov_mean_init = _reduced_initializer(bn_op.get_param_product('moving_mean'))
    reduced_mov_variance_init = _reduced_initializer(bn_op.get_param_product('moving_variance'))

    name = "reduced_" + bn_op.dotted_name
    # Get training attribute
    # This will either be True, False, or a string representing a training_placeholder the original BN was using
    training = BNUtils.get_training(bn_op.get_module())
    assert training is not None
    is_fused = bn_op.type == 'FusedBatchNormV3'
    epsilon = BNUtils.get_epsilon(bn_op.get_module())
    momentum = BNUtils.get_momentum(bn_op.get_module())

    bn_kwargs = dict(center=use_beta,
                     scale=use_gamma,
                     epsilon=epsilon,
                     beta_initializer=reduced_beta_init,
                     gamma_initializer=reduced_gamma_init,
                     moving_mean_initializer=reduced_mov_mean_init,
                     moving_variance_initializer=reduced_mov_variance_init,
                     fused=is_fused,
                     name=name)
    # Only forward momentum when the original op has one; otherwise the layer's own default is used
    if momentum is not None:
        bn_kwargs['momentum'] = momentum

    new_tensor = tf.keras.layers.BatchNormalization(**bn_kwargs)(op_tensor_tuple[1][0], training=training)
    # The returned module is the op feeding the first input of the new output tensor's op
    module = new_tensor.op.inputs[0].op

    return name, new_tensor.op, module