Example #1
import theano
import theano.tensor as T
from theano.gof.opt import copy_stack_trace
from theano.tensor import TensorType
from theano.tensor.nnet.bn import AbstractBatchNormTrainGrad


def local_abstract_batch_norm_train_grad(node):
    if not isinstance(node.op, AbstractBatchNormTrainGrad):
        return None

    x, dy, scale, x_mean, x_invstd, epsilon = node.inputs
    axes = node.op.axes
    # axes must be valid dimension indices of x
    if min(axes) < 0 or max(axes) >= x.ndim:
        return None
    if not all(isinstance(v.type, TensorType)
               for v in (x, dy, scale, x_mean, x_invstd, epsilon)):
        return None

    # Rebuild the gradient expressions from the centered input.
    x_diff = x - x_mean
    mean_dy_x_diff = T.mean(dy * x_diff, axis=axes, keepdims=True)
    c = (dy * x_invstd) - x_diff * (mean_dy_x_diff * (x_invstd ** 3))

    # Gradients w.r.t. the inputs, the scale (gamma) and the bias (beta).
    g_wrt_inputs = scale * (c - T.mean(c, axis=axes, keepdims=True))
    g_wrt_scale = T.sum(dy * x_invstd * x_diff, axis=axes, keepdims=True)
    g_wrt_bias = T.sum(dy, axis=axes, keepdims=True)
    results = [g_wrt_inputs, g_wrt_scale, g_wrt_bias]

    # Restore the broadcast pattern of the original outputs so the
    # replacement is type-preserving.
    results = [T.patternbroadcast(r, r_orig.broadcastable)
               for (r, r_orig) in zip(results, node.outputs)]

    # Copy the stack trace of the replaced node to every newly
    # introduced variable, for better error reporting.
    for var in theano.gof.graph.variables(node.inputs, results):
        if var not in node.inputs:
            copy_stack_trace(node.outputs[0], var)
    return results
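
For context, here is a minimal sketch of how this rewrite could be exercised end to end. It assumes Theano's `batch_normalization_train` helper from `theano.tensor.nnet.bn`; the variable names and shapes are illustrative only.

import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train

x = T.tensor4('x')
# scale and bias must broadcast along the normalized axes: shape (1, C, 1, 1)
bcast = (True, False, True, True)
scale = T.TensorType(theano.config.floatX, bcast)('scale')
bias = T.TensorType(theano.config.floatX, bcast)('bias')

# 'spatial' normalizes over axes (0, 2, 3) for a 4-D input
out, mean, invstd = batch_normalization_train(x, scale, bias, axes='spatial')

# Differentiating the forward pass creates an AbstractBatchNormTrainGrad
# node; during compilation the optimizer above replaces it with the
# equivalent plain tensor expression.
grads = T.grad(out.sum(), wrt=[x, scale, bias])
f = theano.function([x, scale, bias], grads)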
Example #2
    def grad(self, inp, grads):
        # Second-order gradient: differentiate the batch-norm training
        # gradient op w.r.t. its own inputs (needed for double backprop).
        x, dy, scale, x_mean, x_invstd, epsilon = inp
        ddinputs, ddscale, ddbias = grads

        x_diff = x - x_mean
        mean_dy_x_diff = T.mean(dy * x_diff, axis=self.axes, keepdims=True)

        # compute gradients given each of the output gradients;
        # the plain int 0 serves as a sentinel for "not connected yet"
        g_wrt_x = 0
        g_wrt_dy = 0
        g_wrt_scale = 0
        g_wrt_x_mean = 0
        g_wrt_x_invstd = 0

        if not isinstance(ddinputs.type, theano.gradient.DisconnectedType):
            # intermediate terms shared by several of the gradients below
            ccc = scale * (ddinputs -
                           T.mean(ddinputs, axis=self.axes, keepdims=True))
            ddd = (x_invstd**3) * (
                ccc * T.mean(dy * x_diff, axis=self.axes, keepdims=True) +
                dy * T.mean(ccc * x_diff, axis=self.axes, keepdims=True))

            g_wrt_x = g_wrt_x - ddd
            g_wrt_dy = g_wrt_dy + ((ccc * x_invstd) - (
                (x_invstd**3) * x_diff *
                T.mean(ccc * x_diff, axis=self.axes, keepdims=True)))

            eee = (dy * x_invstd) - ((x_invstd**3) * x_diff * mean_dy_x_diff)
            g_wrt_scale = g_wrt_scale + T.sum(
                ddinputs * (eee - T.mean(eee, axis=self.axes, keepdims=True)),
                axis=self.axes,
                keepdims=True)

            g_wrt_x_mean = g_wrt_x_mean + T.sum(
                ddd, axis=self.axes, keepdims=True)
            g_wrt_x_invstd = g_wrt_x_invstd + T.sum(
                ccc * (dy - 3 * (x_invstd**2) * x_diff * mean_dy_x_diff),
                axis=self.axes,
                keepdims=True)

        if not isinstance(ddscale.type, theano.gradient.DisconnectedType):
            g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy)
            g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff)
            g_wrt_x_mean = g_wrt_x_mean - (
                x_invstd * ddscale * T.sum(dy, axis=self.axes, keepdims=True))
            g_wrt_x_invstd = g_wrt_x_invstd + (
                ddscale * T.sum(dy * x_diff, axis=self.axes, keepdims=True))

        if not isinstance(ddbias.type, theano.gradient.DisconnectedType):
            g_wrt_dy = g_wrt_dy + T.fill(dy, ddbias)

        # depending on which output gradients are given,
        # some inputs should be disconnected
        results = [
            g_wrt_x, g_wrt_dy, g_wrt_scale, g_wrt_x_mean, g_wrt_x_invstd,
            theano.gradient.DisconnectedType()()
        ]
        # gradients left at the int sentinel were never connected;
        # the original "r is 0" relied on CPython's small-int caching and
        # warns on modern Pythons, so test the type instead
        return [
            theano.gradient.DisconnectedType()() if isinstance(r, int) else r
            for r in results
        ]
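
And a hedged sketch of what actually reaches this `grad` method: it is only invoked when differentiating through a first-order batch-norm gradient (double backpropagation). The setup mirrors the earlier sketch and is illustrative only.

import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train

x = T.tensor4('x')
bcast = (True, False, True, True)
scale = T.TensorType(theano.config.floatX, bcast)('scale')
bias = T.TensorType(theano.config.floatX, bcast)('bias')
out, _, _ = batch_normalization_train(x, scale, bias, axes='spatial')

# First-order gradient introduces an AbstractBatchNormTrainGrad node ...
g_x = T.grad(out.sum(), x)
# ... and differentiating it again invokes the grad() method shown above.
gg_x, gg_scale = T.grad(g_x.sum(), wrt=[x, scale])
f = theano.function([x, scale, bias], [gg_x, gg_scale])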