Example #1
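This grad method appears to come from Theano's batch-normalization training-gradient Op (AbstractBatchNormTrainGrad in theano.tensor.nnet.bn): it computes the second-order gradient, differentiating the Op's three outputs (the gradients with respect to the inputs, the scale and the bias) with respect to its six inputs. The excerpt is a method body and assumes the module-level import theano and import theano.tensor as T used throughout that file.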
    def grad(self, inp, grads):
        # Second-order gradient: differentiate this Op's outputs (the
        # gradients w.r.t. the inputs, scale and bias) w.r.t. its own inputs.
        x, dy, scale, x_mean, x_invstd, epsilon = inp  # epsilon is not differentiable
        ddinputs, ddscale, ddbias = grads

        x_diff = x - x_mean
        mean_dy_x_diff = T.mean(dy * x_diff, axis=self.axes, keepdims=True)

        # compute gradients given each of the output gradients;
        # the integer 0 is a placeholder meaning "no contribution yet"
        g_wrt_x = 0
        g_wrt_dy = 0
        g_wrt_scale = 0
        g_wrt_x_mean = 0
        g_wrt_x_invstd = 0

        if not isinstance(ddinputs.type, theano.gradient.DisconnectedType):
            ccc = scale * (ddinputs -
                           T.mean(ddinputs, axis=self.axes, keepdims=True))
            mean_ccc_x_diff = T.mean(ccc * x_diff, axis=self.axes,
                                     keepdims=True)
            ddd = (x_invstd**3) * (ccc * mean_dy_x_diff +
                                   dy * mean_ccc_x_diff)

            g_wrt_x = g_wrt_x - ddd
            g_wrt_dy = g_wrt_dy + ((ccc * x_invstd) -
                                   ((x_invstd**3) * x_diff * mean_ccc_x_diff))

            eee = (dy * x_invstd) - ((x_invstd**3) * x_diff * mean_dy_x_diff)
            g_wrt_scale = g_wrt_scale + T.sum(
                ddinputs * (eee - T.mean(eee, axis=self.axes, keepdims=True)),
                axis=self.axes,
                keepdims=True)

            g_wrt_x_mean = g_wrt_x_mean + T.sum(
                ddd, axis=self.axes, keepdims=True)
            g_wrt_x_invstd = g_wrt_x_invstd + T.sum(
                ccc * (dy - 3 * (x_invstd**2) * x_diff * mean_dy_x_diff),
                axis=self.axes,
                keepdims=True)

        if not isinstance(ddscale.type, theano.gradient.DisconnectedType):
            g_wrt_x = g_wrt_x + (x_invstd * ddscale * dy)
            g_wrt_dy = g_wrt_dy + (x_invstd * ddscale * x_diff)
            g_wrt_x_mean = g_wrt_x_mean - (
                x_invstd * ddscale * T.sum(dy, axis=self.axes, keepdims=True))
            g_wrt_x_invstd = g_wrt_x_invstd + (
                ddscale * T.sum(dy * x_diff, axis=self.axes, keepdims=True))

        if not isinstance(ddbias.type, theano.gradient.DisconnectedType):
            g_wrt_dy = g_wrt_dy + T.fill(dy, ddbias)

        # depending on which output gradients are given, some inputs should
        # be disconnected; the last entry (for epsilon) always is
        results = [
            g_wrt_x, g_wrt_dy, g_wrt_scale, g_wrt_x_mean, g_wrt_x_invstd,
            theano.gradient.DisconnectedType()()
        ]
        # a result that is still the integer placeholder was never assigned a
        # contribution; the original test `r is 0` relied on CPython small-int
        # caching and warns on Python >= 3.8, so use an explicit type check
        return [
            theano.gradient.DisconnectedType()() if isinstance(r, int) else r
            for r in results
        ]
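A minimal usage sketch (assuming Theano >= 0.9, where batch_normalization_train is available): taking a second derivative through batch normalization is what invokes the grad method above, because the first theano.grad call introduces the training-gradient Op into the graph. The shapes and variable names below are illustrative.

import numpy as np
import theano
import theano.tensor as T
from theano.tensor.nnet.bn import batch_normalization_train

x = T.matrix('x')
gamma = T.row('gamma')  # same ndim as x, broadcastable over the batch axis
beta = T.row('beta')

# Normalize per activation; also returns the batch mean and inverse std.
out, mean, invstd = batch_normalization_train(x, gamma, beta,
                                              axes='per-activation')
g_x = theano.grad(out.sum(), x)   # first-order gradient
gg_x = theano.grad(g_x.sum(), x)  # second order: exercises grad() above

f = theano.function([x, gamma, beta], gg_x)
floatX = theano.config.floatX
print(f(np.random.randn(4, 3).astype(floatX),
        np.ones((1, 3), dtype=floatX),
        np.zeros((1, 3), dtype=floatX)).shape)  # (4, 3)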
Example #2
import theano.tensor as T

def _fill_chain(new_out, orig_inputs):
    # T.fill(i, v) keeps v's values but broadcasts to the shape of i,
    # so the result takes the broadcast shape of all the original inputs.
    for i in orig_inputs:
        new_out = T.fill(i, new_out)
    return [new_out]
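_fill_chain looks like the kind of helper Theano's elemwise gradient code uses to broadcast a gradient expression to the shape of an Op's inputs: each T.fill(i, new_out) (an elementwise "second") keeps the values of new_out while broadcasting them against one more input. A hedged sketch with illustrative shapes:

import numpy as np
import theano
import theano.tensor as T

a = T.vector('a')
b = T.matrix('b')
floatX = theano.config.floatX
g = T.constant(1.0, dtype=floatX)  # a scalar gradient expression

(g_broadcast,) = _fill_chain(g, [a, b])
f = theano.function([a, b], g_broadcast)
out = f(np.zeros(3, dtype=floatX), np.zeros((2, 3), dtype=floatX))
print(out.shape)  # (2, 3): the broadcast shape of all the inputs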