def _get_bn_params(sess: tf.compat.v1.Session,
                   bn: tf.Operation) -> libpymo.BNParams:
    """
    Helper to populate BN params from the given BN op, required for fold.

    Note: ``runningVar`` of the returned object does NOT hold the raw moving
    variance. It holds ``sqrt(moving_variance + epsilon)`` (the BN standard
    deviation), which is what this helper stores there.

    :param sess: tf.compat.v1.Session type
    :param bn: BatchNorm or a FusedBatchNorm op
    :return: libpymo.BNParams with beta, gamma, runningMean and runningVar
        (sigma, see note above) each flattened to 1-D via reshape(-1)
    """
    # make sure you define the session and graph scope before loading vars from graph.
    with sess.graph.as_default():
        # create BNParams type and populate
        bn_params = libpymo.BNParams()
        bn_params.beta = BNUtils.get_beta_as_numpy_data(sess, bn).reshape(-1)
        bn_params.gamma = BNUtils.get_gamma_as_numpy_data(sess, bn).reshape(-1)
        bn_params.runningMean = BNUtils.get_moving_mean_as_numpy_data(
            sess, bn).reshape(-1)

        # Read the moving variance a single time; sigma is derived from it.
        var = BNUtils.get_moving_variance_as_numpy_data(sess, bn).reshape(-1)
        epsilon = BNUtils.get_epsilon(bn)
        bn_params.runningVar = np.sqrt(var + epsilon)

    return bn_params
Example #2
0
    def test_with_slim_bn_op(self):
        """
        Test with Tf Slim BN op.

        Builds conv2d + slim.batch_norm and initializes variables. It then
        checks that the BN params read through BNUtils match the expected
        initial values: beta=0, gamma=1, moving_mean=0, moving_variance=1.
        :return:
        """
        tf.compat.v1.reset_default_graph()
        sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph())
        # Release the session even if one of the assertions below fails.
        self.addCleanup(sess.close)

        inp = tf.compat.v1.placeholder(tf.float32, [1, 32, 32, 3])
        net = slim.conv2d(inp, 32, [3, 3])
        _ = slim.batch_norm(net, decay=.7, epsilon=.65, is_training=True)

        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)
        with sess.graph.as_default():
            bn_op = sess.graph.get_operation_by_name('BatchNorm/FusedBatchNormV3')
            moving_mean = BNUtils.get_moving_mean_as_numpy_data(sess, bn_op)
            moving_var = BNUtils.get_moving_variance_as_numpy_data(sess, bn_op)
            beta = BNUtils.get_beta_as_numpy_data(sess, bn_op)
            gamma = BNUtils.get_gamma_as_numpy_data(sess, bn_op)

        # check the values read are equal to init values
        expected_beta = np.zeros_like(beta)
        expected_gamma = np.ones_like(gamma)
        expected_mean = np.zeros_like(moving_mean)
        expected_variance = np.ones_like(moving_var)

        self.assertTrue(np.allclose(expected_beta, beta))
        self.assertTrue(np.allclose(expected_gamma, gamma))
        self.assertTrue(np.allclose(expected_mean, moving_mean))
        self.assertTrue(np.allclose(expected_variance, moving_var))
Example #3
0
def get_bn_params_aimet_api(sess, bn_op):
    """
    Read the BN layer's parameters through the AIMET BNUtils API.

    :param sess: session used by BNUtils to read the variable values
    :param bn_op: BN layer
    :return: list [beta, gamma, moving_mean, moving_variance], as returned
        by the corresponding BNUtils getters (in that order)
    """
    # Ordered so the result list matches [beta, gamma, mean, variance];
    # the getters are also invoked in this order.
    readers = (
        BNUtils.get_beta_as_numpy_data,
        BNUtils.get_gamma_as_numpy_data,
        BNUtils.get_moving_mean_as_numpy_data,
        BNUtils.get_moving_variance_as_numpy_data,
    )
    return [read(sess, bn_op) for read in readers]