def local_abstract_batch_norm_inference(fgraph, node):
    """Replace an `AbstractBatchNormInference` node with elementwise tensor ops.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not used by this rewrite).
    node
        The candidate apply node.

    Returns
    -------
    list or None
        A one-element list holding the replacement output, or ``None`` when
        the node's op is not an `AbstractBatchNormInference` or when any of
        its inputs does not have a `TensorType` type.
    """
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs

    operands = (x, scale, bias, estimated_mean, estimated_variance, epsilon)
    if not all(isinstance(operand.type, TensorType) for operand in operands):
        return None

    # A float64 epsilon must not upcast an otherwise float32 computation.
    variance_is_float32 = estimated_variance.dtype == "float32"
    if variance_is_float32 and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")

    centered = x - estimated_mean
    scaled_invstd = scale / sqrt(estimated_variance + epsilon)
    normalized = centered * scaled_invstd + bias

    # Give the replacement the same broadcast pattern as the original output.
    original_out = node.outputs[0]
    normalized = aet.patternbroadcast(normalized, original_out.broadcastable)

    # Propagate the original stack trace to every newly created variable.
    for new_var in aesara.graph.basic.vars_between(node.inputs, [normalized]):
        if new_var in node.inputs:
            continue
        copy_stack_trace(original_out, new_var)

    return [normalized]
def local_abstract_batch_norm_train_grad(fgraph, node):
    """Replace an `AbstractBatchNormTrainGrad` node with explicit tensor ops.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not used by this rewrite).
    node
        The candidate apply node. Its inputs are unpacked as
        ``(x, dy, scale, x_mean, x_invstd, epsilon)``; ``epsilon`` is only
        type-checked and does not appear in the generated expressions.

    Returns
    -------
    list or None
        ``[g_wrt_inputs, g_wrt_scale, g_wrt_bias]``, each given the broadcast
        pattern of the corresponding original output, or ``None`` when the
        node's op is not an `AbstractBatchNormTrainGrad`, when an axis is out
        of range, or when any input does not have a `TensorType` type.
    """
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    # Valid axes are 0 .. x.ndim - 1, so an axis equal to x.ndim is already
    # out of range (the previous `>` comparison let it through).
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None

    if (
        not isinstance(x.type, TensorType)
        or not isinstance(dy.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(x_mean.type, TensorType)
        or not isinstance(x_invstd.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
    ):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd**3))

    g_wrt_inputs = scale * (c - mean(c, axis=axes, keepdims=True))
    g_wrt_scale = aet_sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = aet_sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    # Give each replacement the broadcast pattern of the output it replaces.
    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    # Propagate the original stack trace to every newly created variable.
    for new_var in aesara.graph.basic.vars_between(node.inputs, results):
        if new_var not in node.inputs:
            copy_stack_trace(node.outputs[0], new_var)

    return results
def local_abstract_batch_norm_train(fgraph, node):
    """Replace an `AbstractBatchNormTrain` node with explicit tensor ops.

    Parameters
    ----------
    fgraph
        The function graph being rewritten (not used by this rewrite).
    node
        The candidate apply node. Its first five inputs are unpacked as
        ``(x, scale, bias, epsilon, running_average_factor)``; optional
        sixth and seventh inputs are the running mean and running variance.

    Returns
    -------
    list or None
        ``[out, mean, invstd]``, followed by the updated running mean and
        running variance when the corresponding inputs are present. Each
        entry is given the broadcast pattern of the output it replaces.
        ``None`` is returned when the node's op is not an
        `AbstractBatchNormTrain`, when an axis is out of range, or when any
        input does not have a `TensorType` type.
    """
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]
    axes = node.op.axes
    # Valid axes are 0 .. x.ndim - 1, so an axis equal to x.ndim is already
    # out of range (the previous `>` comparison let it through).
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None

    if (
        not isinstance(x.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(bias.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
        or not isinstance(running_average_factor.type, TensorType)
    ):
        return None

    # optional running_mean and running_var
    if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType):
        return None
    if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType):
        return None

    # Locals are deliberately not called `mean`/`var`, so they cannot be
    # confused with same-named helpers or with the loop variable below.
    batch_mean = x.mean(axes, keepdims=True)
    batch_var = x.var(axes, keepdims=True)

    # The epsilon should not upcast the dtype.
    if batch_var.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")

    invstd = inv(sqrt(batch_var + epsilon))
    out = (x - batch_mean) * (scale * invstd) + bias
    results = [out, batch_mean, invstd]

    if len(node.inputs) > 5:
        running_mean = node.inputs[5]
        running_mean = (
            running_mean * (1.0 - running_average_factor)
            + batch_mean * running_average_factor
        )
        results.append(running_mean)

    if len(node.inputs) > 6:
        # m is the ratio of the number of elements in x to the number of
        # elements in scale; (m / (m - 1)) rescales the batch variance.
        m = aet.cast(prod(x.shape) / prod(scale.shape), config.floatX)
        running_var = node.inputs[6]
        running_var = (
            running_var * (1.0 - running_average_factor)
            + (m / (m - 1)) * batch_var * running_average_factor
        )
        results.append(running_var)

    # Give each replacement the broadcast pattern of the output it replaces.
    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    # Propagate the original stack trace to every newly created variable.
    for new_var in aesara.graph.basic.vars_between(node.inputs, results):
        if new_var not in node.inputs:
            copy_stack_trace(node.outputs[0], new_var)

    return results
def batch_normalization_train(
    inputs,
    gamma,
    beta,
    axes="per-activation",
    epsilon=1e-4,
    running_average_factor=0.1,
    running_mean=None,
    running_var=None,
):
    """
    Performs batch normalization of the given inputs, using the mean and
    variance of the inputs.

    Parameters
    ----------
    inputs : tensor
        The values to normalize. Converted with `as_tensor_variable`.
    gamma : tensor
        Learnable scale factors. The shape must match the shape of `inputs`,
        except for the axes in `axes`. These axes should be set to 1 or be
        skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
    beta : tensor
        Learnable biases. Must match the tensor layout of `gamma`.
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0, 2, 3)``.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN). The check is made after the value
        has been cast to ``config.floatX``.
    running_average_factor : float
        Factor for updating the values of `running_mean` and `running_var`.
        If the factor is close to one, the running averages will update quickly,
        if the factor is close to zero it will update slowly.
    running_mean : tensor or None
        Previous value of the running mean. If this is given, the new value
        ``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
        will be returned as one of the outputs of this function.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.
    running_var : tensor or None
        Previous value of the running variance. If this is given, the new value
        ``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
        will be returned as one of the outputs of this function,
        where `m` is the product of lengths of the averaged-over dimensions.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.

    Returns
    -------
    out : tensor
        Batch-normalized inputs.
    mean : tensor
        Means of `inputs` across the normalization axes.
    invstd : tensor
        Inverse standard deviations of `inputs` across the normalization axes.
    new_running_mean : tensor
        New value of the running mean (only if both `running_mean` and
        `running_var` were given).
    new_running_var : tensor
        New value of the running variance (only if both `running_var` and
        `running_mean` were given).

    Raises
    ------
    ValueError
        If `gamma` or `beta` do not have the expected number of dimensions,
        if only one of `running_mean` and `running_var` is given, if either
        of them has a different number of dimensions than `gamma` and `beta`,
        or if `epsilon` (after casting to ``config.floatX``) is below 1e-5.

    Notes
    -----
    If per-activation or spatial normalization is selected, this operation
    will use the cuDNN implementation. (This requires cuDNN 5 or newer.)

    The returned values are equivalent to:

    .. code-block:: python

        # for per-activation normalization
        axes = (0,)
        # for spatial normalization
        axes = (0,) + tuple(range(2, inputs.ndim))
        mean = inputs.mean(axes, keepdims=True)
        var = inputs.var(axes, keepdims=True)
        invstd = T.inv(T.sqrt(var + epsilon))
        out = (inputs - mean) * gamma * invstd + beta

        m = T.cast(T.prod(inputs.shape) / T.prod(mean.shape), 'float32')
        running_mean = running_mean * (1 - running_average_factor) + \\
                       mean * running_average_factor
        running_var = running_var * (1 - running_average_factor) + \\
                      (m / (m - 1)) * var * running_average_factor
    """
    # Convert up front so that `.ndim` below is also available for
    # array-like arguments that are not tensor variables yet.
    inputs = as_tensor_variable(inputs)
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)

    ndim = inputs.ndim
    axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)

    # have the parameter tensors been broadcasted yet?
    if gamma.ndim == ndim:
        params_ndim = ndim
    else:
        params_ndim = len(non_bc_axes)
        # Pattern that re-inserts the normalized axes as broadcastable ones.
        params_dimshuffle_pattern = ["x"] * ndim
        for i, axis in enumerate(non_bc_axes):
            params_dimshuffle_pattern[axis] = i

    if gamma.ndim != params_ndim or beta.ndim != params_ndim:
        raise ValueError(
            "gamma and beta dimensionality must match the "
            "number of non-normalized axes, or have the "
            "same number of dimensions as the inputs; "
            f"got {int(gamma.ndim)} and {int(beta.ndim)} instead of {int(params_ndim)}"
        )

    if (running_mean is None) != (running_var is None):
        raise ValueError(
            "running_mean and running_var must either both be "
            "given or both be None"
        )

    if running_mean is not None:
        running_mean = as_tensor_variable(running_mean)
        if running_mean.ndim != params_ndim:
            raise ValueError(
                "running_mean must be of the same dimensionality "
                f"as gamma and beta; got {int(running_mean.ndim)} instead of {int(params_ndim)}"
            )

    if running_var is not None:
        running_var = as_tensor_variable(running_var)
        if running_var.ndim != params_ndim:
            raise ValueError(
                "running_var must be of the same dimensionality "
                f"as gamma and beta; got {int(running_var.ndim)} instead of {int(params_ndim)}"
            )

    # epsilon will be converted to floatX later. we need to check
    # for rounding errors now, since numpy.float32(1e-5) < 1e-5.
    # `np.asarray(..., dtype=...)` is the replacement for `np.cast[...]`,
    # which no longer exists in NumPy 2.0.
    epsilon = np.asarray(epsilon, dtype=config.floatX)
    if epsilon < 1e-5:
        raise ValueError(f"epsilon must be at least 1e-5, got {epsilon}")

    if params_ndim != ndim:
        gamma = gamma.dimshuffle(params_dimshuffle_pattern)
        beta = beta.dimshuffle(params_dimshuffle_pattern)
    else:
        gamma = aet.addbroadcast(gamma, *axes)
        beta = aet.addbroadcast(beta, *axes)

    batchnorm_op = AbstractBatchNormTrain(axes=axes)

    if running_mean is not None and running_var is not None:
        if params_ndim != ndim:
            running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
            running_var = running_var.dimshuffle(params_dimshuffle_pattern)
        else:
            running_mean = aet.addbroadcast(running_mean, *axes)
            running_var = aet.addbroadcast(running_var, *axes)

        out, batch_mean, invstd, new_running_mean, new_running_var = batchnorm_op(
            inputs,
            gamma,
            beta,
            epsilon=epsilon,
            running_average_factor=running_average_factor,
            running_mean=running_mean,
            running_var=running_var,
        )

        # Keep the updated statistics' broadcast patterns identical to those
        # of the values they are meant to replace.
        if new_running_mean.broadcastable != running_mean.broadcastable:
            new_running_mean = aet.patternbroadcast(
                new_running_mean, running_mean.broadcastable
            )
        if new_running_var.broadcastable != running_var.broadcastable:
            new_running_var = aet.patternbroadcast(
                new_running_var, running_var.broadcastable
            )
        results = (out, batch_mean, invstd, new_running_mean, new_running_var)
    else:
        results = batchnorm_op(inputs, gamma, beta, epsilon=epsilon)

    if params_ndim != ndim:
        # remove the broadcasted dimensions (except from the output)
        results = [results[0]] + [r.dimshuffle(non_bc_axes) for r in results[1:]]

    return tuple(results)