def compute_output(self, network, in_vw):
    deterministic = network.find_hyperparameter(["deterministic"])
    moving_var_type = network.find_hyperparameter(["moving_var_type"],
                                                  DEFAULT_MOVING_VAR_TYPE)
    epsilon = network.find_hyperparameter(["epsilon"], 1e-8)

    if moving_var_type == "log_var":
        moving_var_init_value = 1.0

        def transform_var(v):
            return T.log(v + epsilon)

        def untransform_var(v):
            return T.exp(v)
    elif moving_var_type == "var":
        moving_var_init_value = 0.0

        def transform_var(v):
            return v

        def untransform_var(v):
            return v
    elif moving_var_type == "inv_std":
        moving_var_init_value = 0.0

        def transform_var(v):
            return T.inv(T.sqrt(v) + epsilon)

        def untransform_var(v):
            return T.sqr(T.inv(v))
    else:
        raise ValueError("Unknown moving_var_type: %s" % moving_var_type)

    # -----------------------------------------------
    # calculate axes to have parameters/normalization
    # -----------------------------------------------

    # axes over which there are parameters for each element
    # ie. parameter_axes == [1, 2] means shape[1] * shape[2] total
    # parameters - one for each combination of shape[1] and shape[2]
    parameter_axes = treeano.utils.find_axes(
        network,
        in_vw.ndim,
        positive_keys=["parameter_axes"],
        negative_keys=["non_parameter_axes"])
    parameter_broadcastable = tuple(
        [idx not in parameter_axes for idx in range(in_vw.ndim)])
    parameter_shape = tuple(
        [1 if b else s
         for b, s in zip(parameter_broadcastable, in_vw.shape)])
    # axes to normalize over - ie. subtract the mean across these axes
    normalization_axes = treeano.utils.find_axes(
        network,
        in_vw.ndim,
        positive_keys=["normalization_axes"],
        negative_keys=["non_normalization_axes"])
    stats_shape = tuple(
        [1 if idx in normalization_axes else s
         for idx, s in enumerate(in_vw.shape)])
    stats_broadcastable = tuple(
        [idx in normalization_axes for idx in range(in_vw.ndim)])
    assert all([s is not None for s in stats_shape])

    # -----------------------
    # initialize shared state
    # -----------------------
    _gamma = network.create_vw(
        name="gamma",
        is_shared=True,
        shape=parameter_shape,
        tags={"parameter", "weight"},
        # TODO try uniform init between -0.05 and 0.05
        default_inits=[],
        default_inits_hyperparameters=["gamma_inits", "inits"],
    )
    _beta = network.create_vw(
        name="beta",
        is_shared=True,
        shape=parameter_shape,
        tags={"parameter", "bias"},
        default_inits=[],
        default_inits_hyperparameters=["beta_inits", "inits"],
    )
    gamma = T.patternbroadcast(_gamma.variable, parameter_broadcastable)
    beta = T.patternbroadcast(_beta.variable, parameter_broadcastable)
    moving_mean = network.create_vw(
        name="mean",
        is_shared=True,
        shape=stats_shape,
        tags={"state"},
        default_inits=[],
    )
    moving_var = network.create_vw(
        name="var",
        is_shared=True,
        shape=stats_shape,
        tags={"state"},
        default_inits=[treeano.inits.ConstantInit(moving_var_init_value)],
    )

    # ------------------------
    # calculate input mean/var
    # ------------------------
    in_mean = T.mean(in_vw.variable,
                     axis=normalization_axes,
                     keepdims=True)
    biased_in_var = T.var(in_vw.variable,
                          axis=normalization_axes,
                          keepdims=True)
    batch_axis = network.find_hyperparameter(["batch_axis"])
    if batch_axis is None:
        in_var = biased_in_var
    else:
        batch_size = in_vw.shape[batch_axis]
        if batch_size is None:
            batch_size = in_vw.variable.shape[batch_axis]
        else:
            batch_size = np.array(batch_size)
        batch_size = batch_size.astype(fX)
        unbias_factor = treeano.utils.as_fX(batch_size / (batch_size - 1))
        in_var = unbias_factor * biased_in_var

    # save the mean/var for updating and debugging
    network.create_vw(
        name="in_mean",
        variable=in_mean,
        tags={},
        shape=stats_shape,
    )
    network.create_vw(
        name="in_var",
        variable=in_var,
        tags={},
        shape=stats_shape,
    )

    # ----------------
    # calculate output
    # ----------------
    bn_use_moving_stats = network.find_hyperparameter(
        ["bn_use_moving_stats"], False)
    if bn_use_moving_stats:
        effective_mean = T.patternbroadcast(moving_mean.variable,
                                            stats_broadcastable)
        effective_var = T.patternbroadcast(
            untransform_var(moving_var.variable), stats_broadcastable)
    else:
        if deterministic:
            msg = ("Batch normalization does not use the `deterministic`"
                   " flag to control whether or not moving stats are used"
                   " for computation. In this case `bn_use_moving_stats`"
                   " is False, thus per-minibatch stats will be used and"
                   " may be stochastic (depending on how minibatches are"
                   " created), and not only a function of the input"
                   " observation.")
            warnings.warn(msg)
        effective_mean = in_mean
        effective_var = in_var

    if network.find_hyperparameter(["consider_mean_constant"], False):
        effective_mean = T.consider_constant(effective_mean)
    if network.find_hyperparameter(["consider_var_constant"], False):
        effective_var = T.consider_constant(effective_var)

    denom = T.sqrt(effective_var + epsilon)
    scaled = (in_vw.variable - effective_mean) / denom
    output = (1 + gamma) * scaled + beta
    network.create_vw(
        name="default",
        variable=output,
        shape=in_vw.shape,
        tags={"output"},
    )
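
# Illustrative sketch: a minimal numpy reference for the per-minibatch
# transform the node above constructs, assuming `gamma` and `beta` are
# zero-initialized arrays that broadcast against `x`. The function name
# `reference_batch_norm` is hypothetical and only for exposition.
import numpy as np


def reference_batch_norm(x, gamma, beta, normalization_axes, epsilon=1e-8):
    mean = x.mean(axis=normalization_axes, keepdims=True)
    var = x.var(axis=normalization_axes, keepdims=True)
    scaled = (x - mean) / np.sqrt(var + epsilon)
    # the (1 + gamma) scaling means a zero-initialized gamma gives an
    # identity scale, matching `output = (1 + gamma) * scaled + beta` above
    return (1 + gamma) * scaled + beta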
def compute_output(self, network, in_vw):
    alpha = network.find_hyperparameter(["bttf_alpha", "alpha"], 0.95)
    epsilon = network.find_hyperparameter(["epsilon"], 1e-4)
    normalization_axes = network.find_hyperparameter(["normalization_axes"],
                                                     (1,))
    # HACK: using "deterministic" to mean test time
    deterministic = network.find_hyperparameter(["deterministic"], False)
    update_averages = network.find_hyperparameter(["update_averages"],
                                                  not deterministic)
    alpha = treeano.utils.as_fX(alpha)

    if update_averages:
        backprop_to_the_future_mean = (
            bttf_mean.backprop_to_the_future_mean_with_updates)
    else:
        backprop_to_the_future_mean = (
            bttf_mean.backprop_to_the_future_mean_no_updates)

    state_shape = tuple([in_vw.shape[axis] for axis in normalization_axes])
    state_pattern = ["x"] * in_vw.ndim
    for idx, axis in enumerate(normalization_axes):
        state_pattern[axis] = idx

    def make_state(name, tags, default_inits=None):
        if default_inits is None:
            default_inits = []
        return network.create_vw(
            name=name,
            is_shared=True,
            shape=state_shape,
            tags=tags,
            default_inits=default_inits,
        ).variable

    gamma = make_state("gamma", {"parameter", "weight"})
    beta = make_state("beta", {"parameter", "bias"})
    # mean of input
    mean = make_state("mean", {"state"})
    # gradient of mean of input
    mean_grad = make_state("mean_grad", {"state"})
    # mean of input^2
    squared_mean = make_state(
        "squared_mean",
        {"state"},
        # initializing to 1, so that std = 1
        default_inits=[treeano.inits.ConstantInit(1.)])
    # gradient of mean of input^2
    squared_mean_grad = make_state("squared_mean_grad", {"state"})

    in_var = in_vw.variable
    mean_axes = tuple([axis for axis in range(in_var.ndim)
                       if axis not in normalization_axes])
    batch_mean = in_var.mean(axis=mean_axes)
    squared_batch_mean = T.sqr(in_var).mean(axis=mean_axes)

    # expectation of input (x)
    E_x = backprop_to_the_future_mean(batch_mean, mean, mean_grad, alpha)
    # TODO try mixing batch mean with E_x
    # expectation of input squared
    E_x_squared = backprop_to_the_future_mean(squared_batch_mean,
                                              squared_mean,
                                              squared_mean_grad,
                                              alpha)

    # HACK mixing batch and rolling means
    # E_x = 0.5 * E_x + 0.5 * batch_mean
    # E_x_squared = 0.5 * E_x_squared + 0.5 * squared_batch_mean

    if 1:
        mu = E_x
        sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

        mu = mu.dimshuffle(state_pattern)
        sigma = sigma.dimshuffle(state_pattern)
        gamma = gamma.dimshuffle(state_pattern)
        beta = beta.dimshuffle(state_pattern)
    else:
        # HACK mixing current value
        E_x = E_x.dimshuffle(state_pattern)
        E_x_squared = E_x_squared.dimshuffle(state_pattern)
        gamma = gamma.dimshuffle(state_pattern)
        beta = beta.dimshuffle(state_pattern)

        E_x = 0.1 * in_var + 0.9 * E_x
        E_x_squared = 0.1 * T.sqr(in_var) + 0.9 * E_x_squared

        mu = E_x
        sigma = T.sqrt(E_x_squared - T.sqr(E_x) + epsilon)

    if 0:
        # HACK don't backprop through sigma
        sigma = T.consider_constant(sigma)
    if 1:
        # HACK using batch mean
        mu = batch_mean
        mu = mu.dimshuffle(state_pattern)
    if 0:
        # HACK using batch variance
        sigma = T.sqrt(in_var.var(axis=mean_axes) + epsilon)
        sigma = sigma.dimshuffle(state_pattern)

    out_var = (in_var - mu) * (T.exp(gamma) / sigma) + beta

    network.create_vw(
        name="default",
        variable=out_var,
        shape=in_vw.shape,
        tags={"output"},
    )

    if 1:
        # HACK monitoring state
        network.create_vw(
            name="mu_mean",
            variable=mu.mean(),
            shape=(),
            tags={"monitor"},
        )
        network.create_vw(
            name="sigma_mean",
            variable=sigma.mean(),
            shape=(),
            tags={"monitor"},
        )
        network.create_vw(
            name="gamma_mean",
            variable=gamma.mean(),
            shape=(),
            tags={"monitor"},
        )
        network.create_vw(
            name="beta_mean",
            variable=beta.mean(),
            shape=(),
            tags={"monitor"},
        )
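
# Illustrative sketch: the shared `mean` / `squared_mean` state above is a
# rolling statistic with decay `alpha` (0.95 by default). The sketch below
# only shows the plain exponential-moving-average flavor of such an update;
# the actual gradient handling ("backprop to the future") lives in
# `bttf_mean` and is not reproduced here. The helper name `ema_update` is
# hypothetical.
def ema_update(rolling_value, batch_value, alpha=0.95):
    # new rolling statistic after observing one minibatch statistic
    return alpha * rolling_value + (1 - alpha) * batch_value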
def compute_output(self, network, in_vw):
    deterministic = network.find_hyperparameter(["deterministic"])
    use_log_moving_var = network.find_hyperparameter(
        ["use_log_moving_var"], DEFAULT_USE_LOG_MOVING_VAR)

    if use_log_moving_var:
        def transform_var(v):
            epsilon = network.find_hyperparameter(["epsilon"], 1e-8)
            return T.log(v + epsilon)

        def untransform_var(v):
            return T.exp(v)
    else:
        def transform_var(v):
            return v

        def untransform_var(v):
            return v

    # -----------------------------------------------
    # calculate axes to have parameters/normalization
    # -----------------------------------------------

    # axes over which there are parameters for each element
    # ie. parameter_axes == [1, 2] means shape[1] * shape[2] total
    # parameters - one for each combination of shape[1] and shape[2]
    parameter_axes = treeano.utils.find_axes(
        network,
        in_vw.ndim,
        positive_keys=["parameter_axes"],
        negative_keys=["non_parameter_axes"])
    parameter_broadcastable = tuple([idx not in parameter_axes
                                     for idx in range(in_vw.ndim)])
    parameter_shape = tuple([1 if b else s
                             for b, s in zip(parameter_broadcastable,
                                             in_vw.shape)])
    # axes to normalize over - ie. subtract the mean across these axes
    normalization_axes = treeano.utils.find_axes(
        network,
        in_vw.ndim,
        positive_keys=["normalization_axes"],
        negative_keys=["non_normalization_axes"])
    stats_shape = tuple([1 if idx in normalization_axes else s
                         for idx, s in enumerate(in_vw.shape)])
    stats_broadcastable = tuple([idx in normalization_axes
                                 for idx in range(in_vw.ndim)])
    assert all([s is not None for s in stats_shape])

    # -----------------------
    # initialize shared state
    # -----------------------
    gamma_inits = list(toolz.concat(network.find_hyperparameters(
        ["gamma_inits", "inits"],
        # TODO uniform init between 0.95 and 1.05
        [treeano.inits.ConstantInit(1.0)])))
    beta_inits = list(toolz.concat(network.find_hyperparameters(
        ["beta_inits", "inits"],
        [treeano.inits.ConstantInit(0.0)])))
    mean_inits = list(toolz.concat(network.find_hyperparameters(
        ["inits"], [])))
    var_inits = list(toolz.concat(network.find_hyperparameters(
        ["inits"],
        [treeano.inits.ConstantInit(1.0 if use_log_moving_var else 0.0)])))
    _gamma = network.create_vw(
        name="gamma",
        is_shared=True,
        shape=parameter_shape,
        tags={"parameter"},
        inits=gamma_inits,
    )
    _beta = network.create_vw(
        name="beta",
        is_shared=True,
        shape=parameter_shape,
        tags={"parameter"},
        inits=beta_inits,
    )
    gamma = T.patternbroadcast(_gamma.variable, parameter_broadcastable)
    beta = T.patternbroadcast(_beta.variable, parameter_broadcastable)
    moving_mean = network.create_vw(
        name="mean",
        is_shared=True,
        shape=stats_shape,
        tags={"state"},
        inits=mean_inits,
    )
    moving_var = network.create_vw(
        name="var",
        is_shared=True,
        shape=stats_shape,
        tags={"state"},
        inits=var_inits,
    )

    # ------------------------
    # calculate input mean/var
    # ------------------------
    in_mean = T.mean(in_vw.variable,
                     axis=normalization_axes,
                     keepdims=True)
    biased_in_var = T.var(in_vw.variable,
                          axis=normalization_axes,
                          keepdims=True)
    batch_axis = network.find_hyperparameter(["batch_axis"])
    if batch_axis is None:
        in_var = biased_in_var
    else:
        batch_size = in_vw.shape[batch_axis]
        if batch_size is None:
            batch_size = in_vw.variable.shape[batch_axis]
        else:
            batch_size = np.array(batch_size)
        batch_size = batch_size.astype(fX)
        unbias_factor = treeano.utils.as_fX(batch_size / (batch_size - 1))
        in_var = unbias_factor * biased_in_var
    assert in_mean.broadcastable == stats_broadcastable
    assert in_var.broadcastable == stats_broadcastable

    # save the mean/var for updating and debugging
    network.create_vw(
        name="in_mean",
        variable=in_mean,
        tags={},
        shape=stats_shape,
    )
    network.create_vw(
        name="in_var",
        variable=in_var,
        tags={},
        shape=stats_shape,
    )

    # ----------------
    # calculate output
    # ----------------
    bn_use_moving_stats = network.find_hyperparameter(
        ["bn_use_moving_stats"], False)
    if bn_use_moving_stats:
        effective_mean = T.patternbroadcast(moving_mean.variable,
                                            stats_broadcastable)
        effective_var = T.patternbroadcast(
            untransform_var(moving_var.variable), stats_broadcastable)
    else:
        if deterministic:
            msg = ("Batch normalization does not use the `deterministic`"
                   " flag to control whether or not moving stats are used"
                   " for computation. In this case `bn_use_moving_stats`"
                   " is False, thus per-minibatch stats will be used and"
                   " may be stochastic (depending on how minibatches are"
                   " created), and not only a function of the input"
                   " observation.")
            warnings.warn(msg)
        effective_mean = in_mean
        effective_var = in_var

    if network.find_hyperparameter(["consider_mean_constant"], False):
        effective_mean = T.consider_constant(effective_mean)
    if network.find_hyperparameter(["consider_var_constant"], False):
        effective_var = T.consider_constant(effective_var)

    epsilon = network.find_hyperparameter(["epsilon"], 1e-8)
    denom = T.sqrt(effective_var + epsilon)
    scaled = (in_vw.variable - effective_mean) / denom
    output = gamma * scaled + beta
    network.create_vw(
        name="default",
        variable=output,
        shape=in_vw.shape,
        tags={"output"},
    )
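
# Illustrative sketch: `unbias_factor` above applies Bessel's correction,
# turning the biased minibatch variance (divide by n) into an unbiased
# estimate (divide by n - 1). A minimal numpy check with hypothetical names:
import numpy as np


def unbiased_from_biased_var(biased_var, batch_size):
    # n / (n - 1) rescaling of the divide-by-n variance
    return (batch_size / (batch_size - 1.0)) * biased_var


# e.g. x = np.random.randn(32, 16); along the batch axis (axis 0):
# np.allclose(unbiased_from_biased_var(x.var(axis=0), 32.0),
#             x.var(axis=0, ddof=1))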