def _get_bn_params(sess: tf.compat.v1.Session, bn: tf.Operation) -> libpymo.BNParams:
    """
    Helper to populate BN params from a given BN op, required for fold
    :param sess: tf.compat.v1.Session type
    :param bn: BatchNorm or FusedBatchNorm op
    :return: bn_params
    """
    # Make sure the session and graph scope are set before loading vars from the graph.
    with sess.graph.as_default():
        # Create a BNParams object and populate it
        bn_params = libpymo.BNParams()
        bn_params.beta = BNUtils.get_beta_as_numpy_data(sess, bn).reshape(-1)
        bn_params.gamma = BNUtils.get_gamma_as_numpy_data(sess, bn).reshape(-1)
        bn_params.runningMean = BNUtils.get_moving_mean_as_numpy_data(sess, bn).reshape(-1)

        # The fold routine expects sigma = sqrt(var + epsilon) in runningVar,
        # not the raw moving variance.
        epsilon = BNUtils.get_epsilon(bn)
        var = BNUtils.get_moving_variance_as_numpy_data(sess, bn).reshape(-1)
        bn_params.runningVar = np.sqrt(var + epsilon)

    return bn_params
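
# Because _get_bn_params stores sigma = sqrt(var + epsilon), batchnorm reduces to a
# per-channel affine transform that can be folded into the preceding layer's weights.
# The sketch below is a standalone NumPy check of that equivalence; the values are
# hypothetical, not read from any graph, and the function is illustrative only.
def _bn_fold_equivalence_sketch():
    """Check gamma * (x - mean) / sqrt(var + eps) + beta == scale * x + shift."""
    gamma = np.array([0.5, 1.2])
    beta = np.array([0.1, -0.3])
    mean = np.array([0.0, 2.0])
    var = np.array([1.0, 4.0])
    eps = 1e-5

    sigma = np.sqrt(var + eps)      # what _get_bn_params stores in runningVar
    scale = gamma / sigma           # folded multiplicative term
    shift = beta - mean * scale     # folded additive term

    x = np.array([1.5, -0.7])
    bn_out = gamma * (x - mean) / sigma + beta
    assert np.allclose(bn_out, scale * x + shift)
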
def reduce_batchnorm(sess: tf.compat.v1.Session, op_tensor_tuple: Tuple[Op, List[tf.Tensor]],
                     op_mask) -> Tuple[str, tf.Operation, tf.Operation]:
    """
    Fused and non-fused batchnorm module reducer
    :param sess: current tf.compat.v1.Session
    :param op_tensor_tuple: tuple containing the op to reduce, and a list of input tensors to the op
    :param op_mask: Mask containing information on input and output channels to winnow
    :return: tuple of the reduced op's name, the new output tensor's op, and the reduced module
    """
    beta_product = op_tensor_tuple[0].get_param_product('beta')
    if beta_product:
        use_beta = True
        reduced_beta_init = tf.constant_initializer(
            _get_reduced_params(sess=sess, product=beta_product, mask=op_mask,
                                input_dim=0, output_dim=None),
            verify_shape=True)
    else:
        use_beta = False
        reduced_beta_init = 'zeros'

    gamma_product = op_tensor_tuple[0].get_param_product('gamma')
    if gamma_product:
        use_gamma = True
        reduced_gamma_init = tf.constant_initializer(
            _get_reduced_params(sess=sess, product=gamma_product, mask=op_mask,
                                input_dim=0, output_dim=None),
            verify_shape=True)
    else:
        use_gamma = False
        reduced_gamma_init = 'ones'

    moving_mean_product = op_tensor_tuple[0].get_param_product('moving_mean')
    reduced_mov_mean_init = tf.constant_initializer(
        _get_reduced_params(sess=sess, product=moving_mean_product, mask=op_mask,
                            input_dim=0, output_dim=None),
        verify_shape=True)

    moving_variance_product = op_tensor_tuple[0].get_param_product('moving_variance')
    reduced_mov_variance_init = tf.constant_initializer(
        _get_reduced_params(sess=sess, product=moving_variance_product, mask=op_mask,
                            input_dim=0, output_dim=None),
        verify_shape=True)

    name = "reduced_" + op_tensor_tuple[0].dotted_name

    # Get the training attribute. This is either True, False, or a string naming the
    # training placeholder the original BN was using.
    training = BNUtils.get_training(op_tensor_tuple[0].get_module())
    assert training is not None

    is_fused = op_tensor_tuple[0].type == 'FusedBatchNormV3'
    epsilon = BNUtils.get_epsilon(op_tensor_tuple[0].get_module())
    momentum = BNUtils.get_momentum(op_tensor_tuple[0].get_module())

    # Build the reduced BN layer. Pass momentum only when the original op defines one,
    # so the Keras default applies otherwise.
    kwargs = dict(center=use_beta,
                  scale=use_gamma,
                  epsilon=epsilon,
                  beta_initializer=reduced_beta_init,
                  gamma_initializer=reduced_gamma_init,
                  moving_mean_initializer=reduced_mov_mean_init,
                  moving_variance_initializer=reduced_mov_variance_init,
                  fused=is_fused,
                  name=name)
    if momentum is not None:
        kwargs['momentum'] = momentum
    new_tensor = tf.keras.layers.BatchNormalization(**kwargs)(op_tensor_tuple[1][0],
                                                              training=training)
    module = new_tensor.op.inputs[0].op

    return name, new_tensor.op, module
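
# _get_reduced_params is defined elsewhere in this module. For a BN layer, all four
# parameters are 1-D over output channels, so winnowing amounts to a boolean slice per
# parameter. The sketch below is a hypothetical, self-contained illustration of that
# slicing; keep_channels stands in for the information op_mask carries and is not the
# actual mask API.
def _bn_winnow_sketch():
    """Slice per-channel BN parameters down to the kept channels."""
    keep_channels = np.array([True, False, True, True])
    params = {
        'gamma': np.array([1.0, 2.0, 3.0, 4.0]),
        'beta': np.array([0.1, 0.2, 0.3, 0.4]),
        'moving_mean': np.array([0.0, 1.0, 2.0, 3.0]),
        'moving_variance': np.array([1.0, 1.0, 2.0, 2.0]),
    }
    # The sliced arrays are what feed tf.constant_initializer in reduce_batchnorm.
    reduced = {key: arr[keep_channels] for key, arr in params.items()}
    assert list(reduced['gamma']) == [1.0, 3.0, 4.0]
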