import prettytensor as pt
import tensorflow as tf
from prettytensor.pretty_tensor_class import Phase
import numpy as np
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import ops


class conv_batch_norm(pt.VarStoreMethod):
    """Code modification of http://stackoverflow.com/a/33950177"""

    def __call__(self, input_layer, epsilon=1e-5, momentum=0.1,
                 name="batch_norm", in_dim=None, phase=Phase.train):
        self.ema = tf.train.ExponentialMovingAverage(decay=0.9)

        shape = input_layer.shape
        shp = in_dim or shape[-1]
        with tf.variable_scope(name) as scope:
            self.gamma = self.variable(
                "gamma", [shp], init=tf.random_normal_initializer(1., 0.02))
            self.beta = self.variable(
                "beta", [shp], init=tf.constant_initializer(0.))

            # Batch statistics over the batch, height and width axes (NHWC).
            self.mean, self.variance = tf.nn.moments(input_layer.tensor, [0, 1, 2])
            # tf.nn.moments loses the static shape; restore it explicitly.
            self.mean.set_shape((shp,))
            self.variance.set_shape((shp,))

            # Track moving averages of the batch statistics for use at test time.
            self.ema_apply_op = self.ema.apply([self.mean, self.variance])

            if phase == Phase.train:
                # Normalize with the batch statistics; the control dependency
                # ensures the moving averages are updated every training step.
                with tf.control_dependencies([self.ema_apply_op]):
                    normalized_x = tf.nn.batch_norm_with_global_normalization(
                        input_layer.tensor, self.mean, self.variance,
                        self.beta, self.gamma, epsilon,
                        scale_after_normalization=True)
            else:
                # At test time, normalize with the accumulated moving averages
                # instead of the per-batch statistics.
                normalized_x = tf.nn.batch_norm_with_global_normalization(
                    input_layer.tensor,
                    self.ema.average(self.mean), self.ema.average(self.variance),
                    self.beta, self.gamma, epsilon,
                    scale_after_normalization=True)
            return input_layer.with_tensor(normalized_x, parameters=self.vars)

pt.Register(assign_defaults=('phase'))(conv_batch_norm)


@pt.Register(assign_defaults=('phase'))
class fc_batch_norm(conv_batch_norm):
    def __call__(self, input_layer, *args, **kwargs):
        # Reshape [batch, units] to [batch, 1, 1, units] so the convolutional
        # batch norm above can be reused for fully connected layers.
        ori_shape = input_layer.shape
        if ori_shape[0] is None:
            ori_shape[0] = -1
        new_shape = [ori_shape[0], 1, 1, ori_shape[1]]
        x = tf.reshape(input_layer.tensor, new_shape)
        # Name the parent class explicitly; super(self.__class__, self) would
        # recurse infinitely if this class were ever subclassed.
        normalized_x = super(fc_batch_norm, self).__call__(
            input_layer.with_tensor(x), *args, **kwargs)
        return normalized_x.reshape(ori_shape)
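# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): once registered
# above, both ops become chainable methods on any pretty tensor, and `phase`
# is picked up from pt.defaults_scope. The placeholder shape and layer sizes
# below are invented for illustration, and this assumes the same TF 0.x /
# prettytensor environment as the rest of the file.
if __name__ == "__main__":
    images = tf.placeholder(tf.float32, [64, 32, 32, 3])
    with pt.defaults_scope(phase=Phase.train):
        logits = (pt.wrap(images)
                  .conv2d(5, 64)
                  .conv_batch_norm()   # normalizes the NHWC conv output
                  .flatten()
                  .fully_connected(10)
                  .fc_batch_norm())    # reshapes to 4-D, reuses the conv path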