Example #1
def local_abstract_batch_norm_inference(fgraph, node):
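    """Replace an AbstractBatchNormInference node with the equivalent elementwise graph."""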
    if not isinstance(node.op, AbstractBatchNormInference):
        return None

    x, scale, bias, estimated_mean, estimated_variance, epsilon = node.inputs

    if (
        not isinstance(x.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(bias.type, TensorType)
        or not isinstance(estimated_mean.type, TensorType)
        or not isinstance(estimated_variance.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
    ):
        return None

    # The epsilon should not upcast the dtype.
    if estimated_variance.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")

    result = (x - estimated_mean) * (scale / sqrt(estimated_variance + epsilon)) + bias
    result = aet.patternbroadcast(result, node.outputs[0].broadcastable)

    for var in aesara.graph.basic.vars_between(node.inputs, [result]):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return [result]
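For reference, the expression this rewrite builds reduces to the standard inference-mode formula (x - mean) * (scale / sqrt(var + eps)) + bias. The following is a minimal NumPy sketch of that same computation, not part of Aesara; the array names and shapes are illustrative assumptions.

import numpy as np

def batch_norm_inference_ref(x, scale, bias, est_mean, est_var, epsilon=1e-4):
    # Same expression as the rewrite above:
    # (x - mean) * (scale / sqrt(var + eps)) + bias
    return (x - est_mean) * (scale / np.sqrt(est_var + epsilon)) + bias

# Illustrative per-activation case: parameters broadcast over the batch axis.
x = np.random.rand(8, 3, 5, 5).astype("float32")
scale = np.random.rand(1, 3, 5, 5).astype("float32")
bias = np.random.rand(1, 3, 5, 5).astype("float32")
est_mean = x.mean(axis=0, keepdims=True)
est_var = x.var(axis=0, keepdims=True)
out = batch_norm_inference_ref(x, scale, bias, est_mean, est_var)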
Example #2
def local_abstract_batch_norm_train_grad(fgraph, node):
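    """Replace an AbstractBatchNormTrainGrad node with the equivalent elementwise graph."""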
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if (
        not isinstance(x.type, TensorType)
        or not isinstance(dy.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(x_mean.type, TensorType)
        or not isinstance(x_invstd.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
    ):
        return None

    x_diff = x - x_mean
    mean_dy_x_diff = mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd**3))

    g_wrt_inputs = scale * (c - mean(c, axis=axes, keepdims=True))
    g_wrt_scale = aet_sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = aet_sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    for var in aesara.graph.basic.vars_between(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
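The gradient expressions above can be checked against a plain NumPy transcription. The sketch below mirrors x_diff, c, and the three gradients exactly; it is an illustrative reference under assumed names, not Aesara code.

import numpy as np

def batch_norm_train_grad_ref(x, dy, scale, x_mean, x_invstd, axes=(0,)):
    # Transcription of the gradient formulas built by the rewrite above.
    x_diff = x - x_mean
    mean_dy_x_diff = np.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))
    g_wrt_inputs = scale * (c - np.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = np.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = np.sum(dy, axis=axes, keepdims=True)
    return g_wrt_inputs, g_wrt_scale, g_wrt_bias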
Example #3
def local_abstract_batch_norm_train(fgraph, node):
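    """Replace an AbstractBatchNormTrain node with the equivalent elementwise graph."""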
    if not isinstance(node.op, AbstractBatchNormTrain):
        return None

    x, scale, bias, epsilon, running_average_factor = node.inputs[:5]
    axes = node.op.axes
    if min(axes) < 0 or max(axes) > x.ndim:
        return None
    if (
        not isinstance(x.type, TensorType)
        or not isinstance(scale.type, TensorType)
        or not isinstance(bias.type, TensorType)
        or not isinstance(epsilon.type, TensorType)
        or not isinstance(running_average_factor.type, TensorType)
    ):
        return None
    # optional running_mean and running_var
    if len(node.inputs) > 5 and not isinstance(node.inputs[5].type, TensorType):
        return None
    if len(node.inputs) > 6 and not isinstance(node.inputs[6].type, TensorType):
        return None

    mean = x.mean(axes, keepdims=True)
    var = x.var(axes, keepdims=True)
    # The epsilon should not upcast the dtype.
    if var.dtype == "float32" and epsilon.dtype == "float64":
        epsilon = epsilon.astype("float32")
    invstd = inv(sqrt(var + epsilon))
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]

    if len(node.inputs) > 5:
        running_mean = node.inputs[5]
        running_mean = (
            running_mean * (1.0 - running_average_factor)
            + mean * running_average_factor
        )
        results.append(running_mean)
    if len(node.inputs) > 6:
        m = aet.cast(prod(x.shape) / prod(scale.shape), config.floatX)
        running_var = node.inputs[6]
        running_var = (
            running_var * (1.0 - running_average_factor)
            + (m / (m - 1)) * var * running_average_factor
        )
        results.append(running_var)

    results = [
        aet.patternbroadcast(r, r_orig.broadcastable)
        for (r, r_orig) in zip(results, node.outputs)
    ]

    for var in aesara.graph.basic.vars_between(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
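For comparison, here is a hedged NumPy sketch of the same training-mode forward pass and running-average updates; the function name, default arguments, and single normalization axis are assumptions for illustration.

import numpy as np

def batch_norm_train_ref(x, scale, bias, epsilon=1e-4,
                         running_average_factor=0.1,
                         running_mean=None, running_var=None, axes=(0,)):
    # Mirrors the forward expressions built by the rewrite above.
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    invstd = 1.0 / np.sqrt(var + epsilon)
    out = (x - mean) * (scale * invstd) + bias
    results = [out, mean, invstd]
    if running_mean is not None:
        results.append(running_mean * (1.0 - running_average_factor)
                       + mean * running_average_factor)
    if running_var is not None:
        # m / (m - 1) is the unbiased-variance correction, where m is the
        # number of elements averaged over per statistic.
        m = np.prod(x.shape) / np.prod(scale.shape)
        results.append(running_var * (1.0 - running_average_factor)
                       + (m / (m - 1)) * var * running_average_factor)
    return results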
Example #4
def batch_normalization_train(
    inputs,
    gamma,
    beta,
    axes="per-activation",
    epsilon=1e-4,
    running_average_factor=0.1,
    running_mean=None,
    running_var=None,
):
    """
    Performs batch normalization of the given inputs, using the mean and
    variance of the inputs.

    Parameters
    ----------
    axes : 'per-activation', 'spatial' or a tuple of ints
        The axes along which the input should be normalized. ``'per-activation'``
        normalizes per activation and is equal to ``axes=(0,)``.
        ``'spatial'`` shares normalization factors across spatial dimensions
        (i.e., all dimensions past the second), which for 4D inputs would be
        equal to ``axes=(0, 2, 3)``.
    gamma : tensor
        Learnable scale factors. The shape must match the shape of `inputs`,
        except for the axes in `axes`. These axes should be set to 1 or be
        skipped altogether (such that `gamma.ndim == inputs.ndim - len(axes)`).
    beta : tensor
        Learnable biases. Must match the tensor layout of `gamma`.
    epsilon : float
        Epsilon value used in the batch normalization formula. Minimum allowed
        value is 1e-5 (imposed by cuDNN).
    running_average_factor : float
        Factor for updating the values of `running_mean` and `running_var`.
        If the factor is close to one, the running averages will update quickly;
        if the factor is close to zero, they will update slowly.
    running_mean : tensor or None
        Previous value of the running mean. If this is given, the new value
        ``running_mean * (1 - r_a_factor) + batch mean * r_a_factor``
        will be returned as one of the outputs of this function.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.
    running_var : tensor or None
        Previous value of the running variance. If this is given, the new value
        ``running_var * (1 - r_a_factor) + (m / (m - 1)) * batch var * r_a_factor``
        will be returned as one of the outputs of this function,
        where `m` is the product of lengths of the averaged-over dimensions.
        `running_mean` and `running_var` should either both be given or
        both be None. The shape should match that of `gamma` and `beta`.

    Returns
    -------
    out : tensor
        Batch-normalized inputs.
    mean : tensor
        Means of `inputs` across the normalization axes.
    invstd : tensor
        Inverse standard deviations of `inputs` across the normalization axes.
    new_running_mean : tensor
        New value of the running mean (only if both `running_mean` and
        `running_var` were given).
    new_running_var : tensor
        New value of the running variance (only if both `running_var` and
        `running_mean` were given).

    Notes
    -----
    If per-activation or spatial normalization is selected, this operation
    will use the cuDNN implementation. (This requires cuDNN 5 or newer.)

    The returned values are equivalent to:

    .. code-block:: python

        # for per-activation normalization
        axes = (0,)
        # for spatial normalization
        axes = (0,) + tuple(range(2, inputs.ndim))
        mean = inputs.mean(axes, keepdims=True)
        var = inputs.var(axes, keepdims=True)
        invstd = T.inv(T.sqrt(var + epsilon))
        out = (inputs - mean) * gamma * invstd + beta

        m = T.cast(T.prod(inputs.shape) / T.prod(mean.shape), 'float32')
        running_mean = running_mean * (1 - running_average_factor) + \\
                       mean * running_average_factor
        running_var = running_var * (1 - running_average_factor) + \\
                      (m / (m - 1)) * var * running_average_factor
    """
    ndim = inputs.ndim
    axes, non_bc_axes = _prepare_batch_normalization_axes(axes, ndim)

    # have the parameter tensors been broadcasted yet?
    if gamma.ndim == ndim:
        params_ndim = ndim
    else:
        params_ndim = len(non_bc_axes)
        params_dimshuffle_pattern = ["x"] * ndim
        for i, axis in enumerate(non_bc_axes):
            params_dimshuffle_pattern[axis] = i

    if gamma.ndim != params_ndim or beta.ndim != params_ndim:
        raise ValueError(
            "gamma and beta dimensionality must match the "
            "number of non-normalized axes, or have the "
            "same number of dimensions as the inputs; "
            f"got {int(gamma.ndim)} and {int(beta.ndim)} instead of {int(params_ndim)}"
        )
    if (running_mean is None) != (running_var is None):
        raise ValueError(
            "running_mean and running_var must either both be " "given or both be None"
        )
    if running_mean is not None and running_mean.ndim != params_ndim:
        raise ValueError(
            "running_mean must be of the same dimensionality "
            f"as gamma and beta; got {int(running_mean.ndim)} instead of {int(params_ndim)}"
        )
    if running_var is not None and running_var.ndim != params_ndim:
        raise ValueError(
            "running_var must be of the same dimensionality "
            f"as gamma and beta; got {int(running_var.ndim)} instead of {int(params_ndim)}"
        )

    # epsilon will be converted to floatX later. we need to check
    # for rounding errors now, since numpy.float32(1e-5) < 1e-5.
    epsilon = np.cast[config.floatX](epsilon)
    if epsilon < 1e-5:
        raise ValueError(f"epsilon must be at least 1e-5, got {epsilon}")

    inputs = as_tensor_variable(inputs)
    gamma = as_tensor_variable(gamma)
    beta = as_tensor_variable(beta)

    if params_ndim != ndim:
        gamma = gamma.dimshuffle(params_dimshuffle_pattern)
        beta = beta.dimshuffle(params_dimshuffle_pattern)
    else:
        gamma = aet.addbroadcast(gamma, *axes)
        beta = aet.addbroadcast(beta, *axes)

    batchnorm_op = AbstractBatchNormTrain(axes=axes)

    if running_mean is not None and running_var is not None:
        running_mean = as_tensor_variable(running_mean)
        running_var = as_tensor_variable(running_var)
        if params_ndim != ndim:
            running_mean = running_mean.dimshuffle(params_dimshuffle_pattern)
            running_var = running_var.dimshuffle(params_dimshuffle_pattern)
        else:
            running_mean = aet.addbroadcast(running_mean, *axes)
            running_var = aet.addbroadcast(running_var, *axes)
        out, mean, invstd, new_running_mean, new_running_var = batchnorm_op(
            inputs,
            gamma,
            beta,
            epsilon=epsilon,
            running_average_factor=running_average_factor,
            running_mean=running_mean,
            running_var=running_var,
        )
        if new_running_mean.broadcastable != running_mean.broadcastable:
            new_running_mean = aet.patternbroadcast(
                new_running_mean, running_mean.broadcastable
            )
        if new_running_var.broadcastable != running_var.broadcastable:
            new_running_var = aet.patternbroadcast(
                new_running_var, running_var.broadcastable
            )
        results = (out, mean, invstd, new_running_mean, new_running_var)
    else:
        results = batchnorm_op(inputs, gamma, beta, epsilon=epsilon)

    if params_ndim != ndim:
        # remove the broadcasted dimensions (except from the output)
        results = [results[0]] + [r.dimshuffle(non_bc_axes) for r in results[1:]]
    return tuple(results)
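A hedged usage sketch of this function for spatial normalization of a 4D input. It assumes batch_normalization_train is importable from aesara.tensor.nnet.batchnorm (the module path may differ between versions) and uses the standard aesara.tensor variable constructors.

import aesara
import aesara.tensor as aet
# Assumed import path; it may differ across Aesara versions.
from aesara.tensor.nnet.batchnorm import batch_normalization_train

x = aet.tensor4("x")         # (batch, channels, rows, cols)
gamma = aet.vector("gamma")  # one scale factor per channel
beta = aet.vector("beta")    # one bias per channel

# 'spatial' normalizes over axes (0, 2, 3), sharing statistics across space.
out, mean, invstd = batch_normalization_train(
    x, gamma, beta, axes="spatial", epsilon=1e-4
)
f = aesara.function([x, gamma, beta], [out, mean, invstd])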