Example 1
    def grad(self, inputs, output_gradients):
        V, W, b, d = inputs
        dCdH, = output_gradients
        # TODO: make all of these ops support broadcasting of scalar b to
        # vector b and replace the zeros_like in all their grads

        # Make sure the broadcasting pattern of the gradient is the same
        # as the initial variable
        dCdV = theano.tensor.nnet.convTransp3D(W, T.zeros_like(V[0, 0, 0,
                                                                 0, :]), d,
                                               dCdH, V.shape[1:4])
        dCdV = T.patternbroadcast(dCdV, V.broadcastable)
        WShape = W.shape
        dCdW = theano.tensor.nnet.convGrad3D(V, d, WShape, dCdH)
        dCdW = T.patternbroadcast(dCdW, W.broadcastable)
        dCdb = T.sum(dCdH, axis=(0, 1, 2, 3))
        dCdb = T.patternbroadcast(dCdb, b.broadcastable)
        dCdd = grad_undefined(
            self, 3, inputs[3],
            "The gradient of Conv3D with respect to the convolution"
            " stride is undefined because Conv3D is only defined for"
            " integer strides.")

        if 'name' in dir(dCdH) and dCdH.name is not None:
            dCdH_name = dCdH.name
        else:
            dCdH_name = 'anon_dCdH'

        if 'name' in dir(V) and V.name is not None:
            V_name = V.name
        else:
            V_name = 'anon_V'

        if 'name' in dir(W) and W.name is not None:
            W_name = W.name
        else:
            W_name = 'anon_W'

        if 'name' in dir(b) and b.name is not None:
            b_name = b.name
        else:
            b_name = 'anon_b'

        dCdV.name = 'Conv3D_dCdV(dCdH=' + dCdH_name + ',V=' + V_name + ')'
        dCdW.name = ('Conv3D_dCdW(dCdH=' + dCdH_name + ',V=' + V_name + ',W=' +
                     W_name + ')')
        dCdb.name = ('Conv3D_dCdb(dCdH=' + dCdH_name + ',V=' + V_name + ',W=' +
                     W_name + ',b=' + b_name + ')')

        return [dCdV, dCdW, dCdb, dCdd]
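The bias gradient line is worth a note: `b` is broadcast-added over the batch and the three spatial axes, so by the chain rule its gradient is simply the output gradient summed over those axes, which is exactly what `T.sum(dCdH, axis=(0, 1, 2, 3))` computes. A minimal NumPy sketch of that identity, with hypothetical shapes not tied to the Op's real signature:

import numpy as np

# Hypothetical output-gradient tensor: (batch, rows, cols, time, out channels),
# with the bias holding one value per output channel.
dCdH = np.random.randn(2, 3, 4, 5, 6)

# If H = conv(V, W) + b with b broadcast over axes (0, 1, 2, 3), then
# dC/db[k] is the sum of dC/dH[..., k] over every position in the batch.
dCdb = dCdH.sum(axis=(0, 1, 2, 3))
assert dCdb.shape == (6,)  # one gradient entry per output channel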
Example 2
    def grad(self, inputs, output_gradients):
        V, W, b, d = inputs
        dCdH, = output_gradients
        # TODO: make all of these ops support broadcasting of scalar b to
        # vector b and replace the zeros_like in all their grads

        # Make sure the broadcasting pattern of the gradient is the same
        # as the initial variable
        dCdV = theano.tensor.nnet.convTransp3D(
            W, T.zeros_like(V[0, 0, 0, 0, :]), d, dCdH, V.shape[1:4])
        dCdV = T.patternbroadcast(dCdV, V.broadcastable)
        WShape = W.shape
        dCdW = theano.tensor.nnet.convGrad3D(V, d, WShape, dCdH)
        dCdW = T.patternbroadcast(dCdW, W.broadcastable)
        dCdb = T.sum(dCdH, axis=(0, 1, 2, 3))
        dCdb = T.patternbroadcast(dCdb, b.broadcastable)
        dCdd = grad_undefined(
            self, 3, inputs[3],
            "The gradient of Conv3D with respect to the convolution"
            " stride is undefined because Conv3D is only defined for"
            " integer strides.")

        if 'name' in dir(dCdH) and dCdH.name is not None:
            dCdH_name = dCdH.name
        else:
            dCdH_name = 'anon_dCdH'

        if 'name' in dir(V) and V.name is not None:
            V_name = V.name
        else:
            V_name = 'anon_V'

        if 'name' in dir(W) and W.name is not None:
            W_name = W.name
        else:
            W_name = 'anon_W'

        if 'name' in dir(b) and b.name is not None:
            b_name = b.name
        else:
            b_name = 'anon_b'

        dCdV.name = 'Conv3D_dCdV(dCdH=' + dCdH_name + ',V=' + V_name + ')'
        dCdW.name = ('Conv3D_dCdW(dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name + ')')
        dCdb.name = ('Conv3D_dCdb(dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name + ',b=' + b_name + ')')

        return [dCdV, dCdW, dCdb, dCdd]
Example 3
def local_abstract_batch_norm_inference(node):
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs

    if not isinstance(x.type, TensorType) or \
       not isinstance(scale.type, TensorType) or \
       not isinstance(bias.type, TensorType) or \
       not isinstance(estimated_mean.type, TensorType) or \
       not isinstance(estimated_variance.type, TensorType) or \
       not isinstance(epsilon.type, TensorType):
        return None

    # The epsilon should not upcast the dtype.
    if estimated_variance.dtype == 'float32' and epsilon.dtype == 'float64':
        epsilon = epsilon.astype('float32')

    result = (x - estimated_mean) * (scale / T.sqrt(estimated_variance + epsilon)) + bias
    result = T.patternbroadcast(result, node.outputs[0].broadcastable)

    for var in theano.gof.graph.variables(node.inputs, [result]):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return [result]
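For reference, the expression this rewrite substitutes is just the standard inference-time normalization. A plain NumPy restatement of the same algebra, with illustrative shapes only:

import numpy as np

def batch_norm_inference(x, scale, bias, est_mean, est_var, eps=1e-4):
    # Same algebra as the graph built above: normalize with the stored
    # statistics, then apply the learned affine transform.
    return (x - est_mean) * (scale / np.sqrt(est_var + eps)) + bias

x = np.random.randn(8, 3, 32, 32).astype("float32")
shape = (1, 3, 1, 1)  # per-channel parameters, broadcast over batch and space
out = batch_norm_inference(
    x,
    scale=np.ones(shape, "float32"),
    bias=np.zeros(shape, "float32"),
    est_mean=np.zeros(shape, "float32"),
    est_var=np.ones(shape, "float32"),
)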
Example 4
def local_abstract_batch_norm_inference(fgraph, node):
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs

    if (
        not isinstance(x.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(bias.type, TensorType)
        or not isinstance(estimated_mean.type, TensorType)
        or not isinstance(estimated_variance.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
    ):
        return None

    # The epsilon should not upcast the dtype.
    if estimated_variance.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")

    result = (x - estimated_mean) * (
        scale / tt.sqrt(estimated_variance + epsilon)
    ) + bias
    result = tt.patternbroadcast(result, node.outputs[0].broadcastable)

    for var in theano.gof.graph.variables(node.inputs, [result]):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return [result]
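A small NumPy illustration of the upcasting behaviour that the explicit `epsilon` cast guards against (illustrative dtypes only):

import numpy as np

# Combining float64 epsilon values with float32 statistics promotes the whole
# expression to float64, which would change the dtype of the rewritten graph.
est_var = np.zeros((3,), dtype="float32")
eps = np.full((1,), 1e-4, dtype="float64")

print((est_var + eps).dtype)                    # float64 -- unwanted upcast
print((est_var + eps.astype("float32")).dtype)  # float32 -- dtype preserved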
Example 5
def local_abstract_batch_norm_train_grad(node):
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if not isinstance(x.type, TensorType) or \
       not isinstance(dy.type, TensorType) or \
       not isinstance(scale.type, TensorType) or \
       not isinstance(x_mean.type, TensorType) or \
       not isinstance(x_invstd.type, TensorType) or \
       not isinstance(epsilon.type, TensorType):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = T.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))

    g_wrt_inputs = scale * (c - T.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = T.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = T.sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    results = [T.patternbroadcast(r, r_orig.broadcastable)
               for (r, r_orig) in zip(results, node.outputs)]

    for var in theano.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
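The three gradient expressions translate directly into NumPy. A minimal restatement with hypothetical 4D inputs normalized over the batch and spatial axes:

import numpy as np

def batch_norm_train_grad(x, dy, scale, x_mean, x_invstd, axes=(0, 2, 3)):
    # Direct NumPy restatement of the symbolic expressions built above.
    x_diff = x - x_mean
    mean_dy_x_diff = (dy * x_diff).mean(axis=axes, keepdims=True)
    c = dy * x_invstd - x_diff * (mean_dy_x_diff * x_invstd ** 3)

    g_wrt_inputs = scale * (c - c.mean(axis=axes, keepdims=True))
    g_wrt_scale = (dy * x_invstd * x_diff).sum(axis=axes, keepdims=True)
    g_wrt_bias = dy.sum(axis=axes, keepdims=True)
    return g_wrt_inputs, g_wrt_scale, g_wrt_bias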
Example 6
    def grad(self, inputs, output_gradients):
        V, W, b, d = inputs
        dCdH, = output_gradients
        # TODO: make all of these ops support broadcasting of scalar b to
        # vector b and replace the zeros_like in all their grads

        # Make sure the broadcasting pattern of the gradient is the same
        # as the initial variable
        dCdV = ConvTransp3D.convTransp3D(W, T.zeros_like(V[0, 0, 0, 0, :]),
                                         d, dCdH, V.shape[1:4])
        dCdV = T.patternbroadcast(dCdV, V.broadcastable)
        WShape = W.shape
        dCdW = ConvGrad3D.convGrad3D(V, d, WShape, dCdH)
        dCdW = T.patternbroadcast(dCdW, W.broadcastable)
        dCdb = T.sum(dCdH, axis=(0, 1, 2, 3))
        dCdb = T.patternbroadcast(dCdb, b.broadcastable)
        dCdd = None  # not differentiable, since d is not continuous

        if 'name' in dir(dCdH) and dCdH.name is not None:
            dCdH_name = dCdH.name
        else:
            dCdH_name = 'anon'

        if 'name' in dir(V) and V.name is not None:
            V_name = V.name
        else:
            V_name = 'anon'

        if 'name' in dir(W) and W.name is not None:
            W_name = W.name
        else:
            W_name = 'anon'

        if 'name' in dir(b) and b.name is not None:
            b_name = b.name
        else:
            b_name = 'anon'

        dCdV.name = 'Conv3D_dCdV.dCdH=' + dCdH_name + ',V=' + V_name
        dCdW.name = ('Conv3D_dCdW.dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name)
        dCdb.name = ('Conv3D_dCdb.dCdH=' + dCdH_name + ',V=' + V_name +
                     ',W=' + W_name + ',b=' + b_name)

        return [dCdV, dCdW, dCdb, dCdd]
Example 7
def local_abstract_batch_norm_train(fgraph, node):
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if (
        not isinstance(x.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(bias.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
        or not isinstance(running_average_factor.type, TensorType)
    ):
        return None
    # optional running_mean and running_var
    if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType):
        return None
    if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType):
        return None

    mean = x.mean(axes, keepdims=True)
    var = x.var(axes, keepdims=True)
    # The epsilon should not upcast the dtype.
    if var.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")
    invstd = tt.inv(tt.sqrt(var + epsilon))
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]

    if len(node.inputs) > 5:
        running_mean = node.inputs[5]
        running_mean = (
            running_mean * (1.0 - running_average_factor)
            + mean * running_average_factor
        )
        results.append(running_mean)
    if len(node.inputs) > 6:
        m = tt.cast(tt.prod(x.shape) / tt.prod(scale.shape), theano.config.floatX)
        running_var = node.inputs[6]
        running_var = (
            running_var * (1.0 - running_average_factor)
            + (m / (m - 1)) * var * running_average_factor
        )
        results.append(running_var)

    results = [
        tt.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    for var in theano.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
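A compact NumPy restatement of the forward pass and running-average updates this rewrite builds (illustrative only; the symbolic graph above is what the optimizer actually emits):

import numpy as np

def batch_norm_train(x, scale, bias, eps, factor, run_mean, run_var,
                     axes=(0, 2, 3)):
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    invstd = 1.0 / np.sqrt(var + eps)
    out = (x - mean) * (scale * invstd) + bias

    # m is the number of elements averaged over per statistic; m / (m - 1)
    # applies the Bessel correction so the running variance is unbiased.
    m = x.size / scale.size
    new_mean = run_mean * (1.0 - factor) + mean * factor
    new_var = run_var * (1.0 - factor) + (m / (m - 1)) * var * factor
    return out, mean, invstd, new_mean, new_var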
Example 8
def local_abstract_batch_norm_train(node):
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if not isinstance(x.type, TensorType) or \
       not isinstance(scale.type, TensorType) or \
       not isinstance(bias.type, TensorType) or \
       not isinstance(epsilon.type, TensorType) or \
       not isinstance(running_average_factor.type, TensorType):
        return None
    # optional running_mean and running_var
    if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType):
        return None
    if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType):
        return None

    mean = x.mean(axes, keepdims=True)
    var = x.var(axes, keepdims=True)
    # The epsilon should not upcast the dtype.
    if var.dtype == 'float32' and epsilon.dtype == 'float64':
        epsilon = epsilon.astype('float32')
    invstd = T.inv(T.sqrt(var + epsilon))
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]

    if len(node.inputs) > 5:
        running_mean = node.inputs[5]
        running_mean = running_mean * (1.0 - running_average_factor) + \
            mean * running_average_factor
        results.append(running_mean)
    if len(node.inputs) > 6:
        m = T.cast(T.prod(x.shape) / T.prod(scale.shape), theano.config.floatX)
        running_var = node.inputs[6]
        running_var = running_var * (1.0 - running_average_factor) + \
            (m / (m - 1)) * var * running_average_factor
        results.append(running_var)

    results = [T.patternbroadcast(r, r_orig.broadcastable)
               for (r, r_orig) in zip(results, node.outputs)]

    for var in theano.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
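The running statistics follow an exponential moving average, so with `running_average_factor=0.1` they converge towards a stable batch statistic over a few dozen updates. A tiny numeric sketch of that behaviour:

factor = 0.1
running_mean = 0.0
batch_mean = 5.0
for _ in range(100):
    running_mean = running_mean * (1.0 - factor) + batch_mean * factor
print(running_mean)  # ~5.0 after enough updates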
Example 9
def local_abstract_batch_norm_inference(node):
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs

    if not isinstance(x.type, TensorType) or \
       not isinstance(scale.type, TensorType) or \
       not isinstance(bias.type, TensorType) or \
       not isinstance(estimated_mean.type, TensorType) or \
       not isinstance(estimated_variance.type, TensorType) or \
       not isinstance(epsilon.type, TensorType):
        return None

    result = (x - estimated_mean) * (
        scale / T.sqrt(estimated_variance + epsilon)) + bias
    result = T.patternbroadcast(result, node.outputs[0].broadcastable)

    for var in theano.gof.graph.variables(node.inputs, [result]):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return [result]
Example 10
def batch_normalization_train(
    inputs,
    gamma,
    beta,
    axes="per-activation",
    epsilon=1e-4,
    running_average_factor=0.1,
    running_mean=None,
    running_var=None,
):
    """
    Performs batch normalization of the given inputs, using the mean and
    variance of the inputs.

    Parameters
    ----------
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0, 2, 3)``.
    gamma : tensor
        Learnable scale factors. The shape must match the shape of `inputs`,
        except for the axes in `axes`. These axes should be set to 1 or be
        skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
    beta : tensor
        Learnable biases. Must match the tensor layout of `gamma`.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).
    running_average_factor : float
        Factor for updating the values of `running_mean` and `running_var`.
        If the factor is close to one, the running averages will update quickly;
        if it is close to zero, they will update slowly.
    running_mean : tensor or None
        Previous value of the running mean. If this is given, the new value
        ``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
        will be returned as one of the outputs of this function.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.
    running_var : tensor or None
        Previous value of the running variance. If this is given, the new value
        ``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
        will be returned as one of the outputs of this function,
        where `m` is the product of lengths of the averaged-over dimensions.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.

    Returns
    -------
    out : tensor
        Batch-normalized inputs.
    mean : tensor
        Means of `inputs` across the normalization axes.
    invstd : tensor
        Inverse standard deviations of `inputs` across the normalization axes.
    new_running_mean : tensor
        New value of the running mean (only if both `running_mean` and
        `running_var` were given).
    new_running_var : tensor
        New value of the running variance (only if both `running_var` and
        `running_mean` were given).

    Notes
    -----
    If per-activation or spatial normalization is selected, this operation
    will use the cuDNN implementation. (This requires cuDNN 5 or newer.)

    The returned values are equivalent to:

    .. code-block:: python

        # for per-activation normalization
        axes = (0,)
        # for spatial normalization
        axes = (0,) + tuple(range(2, inputs.ndim))
        mean = inputs.mean(axes, keepdims=True)
        var = inputs.var(axes, keepdims=True)
        invstd = T.inv(T.sqrt(var + epsilon))
        out = (inputs - mean) * gamma * invstd + beta

        m = T.cast(T.prod(inputs.shape) / T.prod(mean.shape), 'float32')
        running_mean = running_mean * (1 - running_average_factor) + \\
                       mean * running_average_factor
        running_var = running_var * (1 - running_average_factor) + \\
                      (m / (m - 1)) * var * running_average_factor
    """
    ndim = inputs.ndim
    axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)

    # have the parameter tensors been broadcasted yet?
    if gamma.ndim == ndim:
        params_ndim = ndim
    else:
        params_ndim = len(non_bc_axes)
        params_dimshuffle_pattern = ["x"] * ndim
        for i, axis in enumerate(non_bc_axes):
            params_dimshuffle_pattern[axis] = i

    if gamma.ndim != params_ndim or beta.ndim != params_ndim:
        raise ValueError(
            "gamma and beta dimensionality must match the "
            "number of non-normalized axes, or have the "
            "same number of dimensions as the inputs; "
            f"got {int(gamma.ndim)} and {int(beta.ndim)} instead of {int(params_ndim)}"
        )
    if (running_mean is None) != (running_var is None):
        raise ValueError(
            "running_mean and running_var must either both be " "given or both be None"
        )
    if running_mean is not None and running_mean.ndim != params_ndim:
        raise ValueError(
            "running_mean must be of the same dimensionality "
            f"as gamma and beta; got {int(running_mean.ndim)} instead of {int(params_ndim)}"
        )
    if running_var is not None and running_var.ndim != params_ndim:
        raise ValueError(
            "running_var must be of the same dimensionality "
            f"as gamma and beta; got {int(running_var.ndim)} instead of {int(params_ndim)}"
        )

    # epsilon will be converted to floatX later. we need to check
    # for rounding errors now, since numpy.float32(1e-5) < 1e-5.
    epsilon = np.cast[theano.config.floatX](epsilon)
    if epsilon < 1e-5:
        raise ValueError(f"epsilon must be at least 1e-5, got {epsilon}")

    inputs = as_tensor_variable(inputs)
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)

    if params_ndim != ndim:
        gamma = gamma.dimshuffle(params_dimshuffle_pattern)
        beta = beta.dimshuffle(params_dimshuffle_pattern)
    else:
        gamma = tt.addbroadcast(gamma, *axes)
        beta = tt.addbroadcast(beta, *axes)

    batchnorm_op = AbstractBatchNormTrain(axes=axes)

    if running_mean is not None and running_var is not None:
        running_mean = as_tensor_variable(running_mean)
        running_var = as_tensor_variable(running_var)
        if params_ndim != ndim:
            running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
            running_var = running_var.dimshuffle(params_dimshuffle_pattern)
        else:
            running_mean = tt.addbroadcast(running_mean, *axes)
            running_var = tt.addbroadcast(running_var, *axes)
        out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
            inputs,
            gamma,
            beta,
            epsilon=epsilon,
            running_average_factor=running_average_factor,
            running_mean=running_mean,
            running_var=running_var,
        )
        if new_running_mean.broadcastable != running_mean.broadcastable:
            new_running_mean = tt.patternbroadcast(
                new_running_mean, running_mean.broadcastable
            )
        if new_running_var.broadcastable != running_var.broadcastable:
            new_running_var = tt.patternbroadcast(
                new_running_var, running_var.broadcastable
            )
        results = (out, mean, invstd, new_running_mean, new_running_var)
    else:
        results = batchnorm_op(inputs, gamma, beta, epsilon=epsilon)

    if params_ndim != ndim:
        # remove the broadcasted dimensions (except from the output)
        results = [results[0]] + [r.dimshuffle(non_bc_axes) for r in results[1:]]
    return tuple(results)
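A minimal usage sketch, assuming the function is importable as `theano.tensor.nnet.bn.batch_normalization_train` (the exact module path and the availability of the CPU fallback optimizer depend on the Theano version in use):

import numpy as np
import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train  # assumed path

x = T.tensor4("x")
gamma = T.vector("gamma")  # one entry per channel; axes (0, 2, 3) are skipped
beta = T.vector("beta")

out, mean, invstd = batch_normalization_train(x, gamma, beta, axes="spatial")
f = theano.function([x, gamma, beta], [out, mean, invstd])

data = np.random.randn(8, 3, 16, 16).astype(theano.config.floatX)
normalized, batch_mean, batch_invstd = f(
    data,
    np.ones(3, dtype=theano.config.floatX),
    np.zeros(3, dtype=theano.config.floatX),
)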
Example 11
def batch_normalization_train(inputs, gamma, beta, axes='per-activation',
                              epsilon=1e-4, running_average_factor=0.1,
                              running_mean=None, running_var=None):
    """
    Performs batch normalization of the given inputs, using the mean and
    variance of the inputs.

    Parameters
    ----------
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0, 2, 3)``.
    gamma : tensor
        Learnable scale factors. The shape must match the shape of `inputs`,
        except for the axes in `axes`. These axes should be set to 1 or be
        skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
    beta : tensor
        Learnable biases. Must match the tensor layout of `gamma`.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).
    running_average_factor : float
        Factor for updating the values of `running_mean` and `running_var`.
        If the factor is close to one, the running averages will update quickly;
        if it is close to zero, they will update slowly.
    running_mean : tensor or None
        Previous value of the running mean. If this is given, the new value
        ``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
        will be returned as one of the outputs of this function.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.
    running_var : tensor or None
        Previous value of the running variance. If this is given, the new value
        ``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
        will be returned as one of the outputs of this function,
        where `m` is the product of lengths of the averaged-over dimensions.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.

    Returns
    -------
    out : tensor
        Batch-normalized inputs.
    mean : tensor
        Means of `inputs` across the normalization axes.
    invstd : tensor
        Inverse standard deviations of `inputs` across the normalization axes.
    new_running_mean : tensor
        New value of the running mean (only if both `running_mean` and
        `running_var` were given).
    new_running_var : tensor
        New value of the running variance (only if both `running_var` and
        `running_mean` were given).

    Notes
    -----
    If per-activation or spatial normalization is selected, this operation
    will use the cuDNN implementation. (This requires cuDNN 5 or newer.)

    The returned values are equivalent to:

    .. code-block:: python

        # for per-activation normalization
        axes = (0,)
        # for spatial normalization
        axes = (0,) + tuple(range(2, inputs.ndim))
        mean = inputs.mean(axes, keepdims=True)
        var = inputs.var(axes, keepdims=True)
        invstd = T.inv(T.sqrt(var + epsilon))
        out = (inputs - mean) * gamma * invstd + beta

        m = T.cast(T.prod(inputs.shape) / T.prod(mean.shape), 'float32')
        running_mean = running_mean * (1 - running_average_factor) + \\
                       mean * running_average_factor
        running_var = running_var * (1 - running_average_factor) + \\
                      (m / (m - 1)) * var * running_average_factor
    """
    ndim = inputs.ndim
    axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)

    # have the parameter tensors been broadcasted yet?
    if gamma.ndim == ndim:
        params_ndim = ndim
    else:
        params_ndim = len(non_bc_axes)
        params_dimshuffle_pattern = ['x'] * ndim
        for i, axis in enumerate(non_bc_axes):
            params_dimshuffle_pattern[axis] = i

    if gamma.ndim != params_ndim or beta.ndim != params_ndim:
        raise ValueError("gamma and beta dimensionality must match the "
                         "number of non-normalized axes, or have the "
                         "same number of dimensions as the inputs; "
                         "got %d and %d instead of %d" %
                         (gamma.ndim, beta.ndim, params_ndim))
    if (running_mean is None) != (running_var is None):
        raise ValueError("running_mean and running_var must either both be "
                         "given or both be None")
    if running_mean is not None and running_mean.ndim != params_ndim:
        raise ValueError("running_mean must be of the same dimensionality "
                         "as gamma and beta; got %d instead of %d" %
                         (running_mean.ndim, params_ndim))
    if running_var is not None and running_var.ndim != params_ndim:
        raise ValueError("running_var must be of the same dimensionality "
                         "as gamma and beta; got %d instead of %d" %
                         (running_var.ndim, params_ndim))

    # epsilon will be converted to floatX later. we need to check
    # for rounding errors now, since numpy.float32(1e-5) < 1e-5.
    epsilon = np.cast[theano.config.floatX](epsilon)
    if epsilon < 1e-5:
        raise ValueError("epsilon must be at least 1e-5, got %s" % str(epsilon))

    inputs = as_tensor_variable(inputs)
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)

    if params_ndim != ndim:
        gamma = gamma.dimshuffle(params_dimshuffle_pattern)
        beta = beta.dimshuffle(params_dimshuffle_pattern)
    else:
        gamma = T.addbroadcast(gamma, *axes)
        beta = T.addbroadcast(beta, *axes)

    batchnorm_op = AbstractBatchNormTrain(axes=axes)

    if running_mean is not None and running_var is not None:
        running_mean = as_tensor_variable(running_mean)
        running_var = as_tensor_variable(running_var)
        if params_ndim != ndim:
            running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
            running_var = running_var.dimshuffle(params_dimshuffle_pattern)
        else:
            running_mean = T.addbroadcast(running_mean, *axes)
            running_var = T.addbroadcast(running_var, *axes)
        out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
            inputs, gamma, beta, epsilon=epsilon,
            running_average_factor=running_average_factor,
            running_mean=running_mean, running_var=running_var)
        if new_running_mean.broadcastable != running_mean.broadcastable:
            new_running_mean = T.patternbroadcast(new_running_mean, running_mean.broadcastable)
        if new_running_var.broadcastable != running_var.broadcastable:
            new_running_var = T.patternbroadcast(new_running_var, running_var.broadcastable)
        results = (out, mean, invstd, new_running_mean, new_running_var)
    else:
        results = batchnorm_op(inputs, gamma, beta, epsilon=epsilon)

    if params_ndim != ndim:
        # remove the broadcasted dimensions (except from the output)
        results = ([results[0]] +
                   [r.dimshuffle(non_bc_axes) for r in results[1:]])
    return tuple(results)
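And a sketch of the running-average variant (same assumed import path as above): when both `running_mean` and `running_var` are supplied, two extra outputs carry their updated values.

import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train  # assumed path

x = T.tensor4("x")
gamma, beta = T.vector("gamma"), T.vector("beta")
run_mean, run_var = T.vector("run_mean"), T.vector("run_var")

out, mean, invstd, new_run_mean, new_run_var = batch_normalization_train(
    x, gamma, beta, axes="spatial",
    running_average_factor=0.1,
    running_mean=run_mean, running_var=run_var,
)
f = theano.function([x, gamma, beta, run_mean, run_var],
                    [out, new_run_mean, new_run_var])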