Example #1
    def backward(cty, grad_output):
        # retrieve weight references
        Fm, Gm = cty.Fm, cty.Gm

        # retrieve input and output references
        y, output = cty.saved_tensors
        x1, x2 = torch.chunk(output, 2, dim=1)
        x1, x2 = x1.contiguous(), x2.contiguous()

        # partition output gradient also on channels
        assert (grad_output.shape[1] % 2 == 0)

        with set_grad_enabled(False):
            # recompute y
            z1_stop = x2
            z1_stop.requires_grad = True
            FWeights = [p for p in Fm.parameters()]
            fmr1, fmr2 = Fm.forward(z1_stop)
            y1 = (fmr1 * x1) + fmr2
            gmr1, gmr2 = Gm.forward(y1)
            y2 = (gmr1 * x2) + gmr2

        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            y2.requires_grad = True
            y1.requires_grad = True

            gmr1, gmr2 = Gm.forward(y1)
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1
            x = torch.cat([x1, x2], dim=1)

            # perform full backward pass on graph...
            dd = torch.autograd.grad(x, (y1, y2) + tuple(Fm.parameters()) +
                                     tuple(Gm.parameters()), grad_output)

            FWgrads = dd[2:2 + len(FWeights)]
            GWgrads = dd[2 + len(FWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)

            # cleanup sub-graph
            x1.detach_()
            x2.detach_()
            del x1, x2

        # restore input
        yout = torch.cat([y1, y2], dim=1).contiguous()
        y.storage().resize_(int(np.prod(yout.shape)))
        y.set_(yout)

        return (grad_input, None, None) + FWgrads + GWgrads
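
The backward above ends by resizing y's storage and repopulating it with the recomputed yout, which only makes sense if the matching forward freed the input buffer after saving a reference to it. A minimal standalone sketch of that free-and-restore trick in plain PyTorch (the tensor name t is illustrative):

import torch

t = torch.randn(2, 4)
saved = t.clone()

# free the underlying buffer, as the paired forward would after computing its output
t.storage().resize_(0)
assert t.storage().size() == 0

# restore it, mirroring y.storage().resize_(...) and y.set_(yout) in the backward above
t.storage().resize_(saved.numel())
t.set_(saved)
assert torch.equal(t, saved)
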
Example #2
    def backward(ctx, grad_output):
        # retrieve weight references
        Fm, Gm = ctx.Fm, ctx.Gm

        # retrieve input and output references
        x, output = ctx.saved_tensors
        y1, y2 = torch.chunk(output, 2, dim=1)
        y1, y2 = y1.contiguous(), y2.contiguous()

        # partition output gradient also on channels
        assert (grad_output.shape[1] % 2 == 0)

        with set_grad_enabled(False):
            # recompute x
            z1_stop = y1
            z1_stop.requires_grad = True
            GWeights = [p for p in Gm.parameters()]
            gmr1, gmr2 = Gm.forward(z1_stop)
            x2 = (y2 - gmr2) / gmr1
            fmr1, fmr2 = Fm.forward(x2)
            x1 = (y1 - fmr2) / fmr1

        with set_grad_enabled(True):
            # compute outputs building a sub-graph
            x1.requires_grad = True
            x2.requires_grad = True

            fmr1, fmr2 = Fm.forward(x2)
            y1 = x1 * fmr1 + fmr2
            gmr1, gmr2 = Gm.forward(y1)
            y2 = x2 * gmr1 + gmr2
            y = torch.cat([y1, y2], dim=1)

            # perform full backward pass on graph...
            dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) +
                                     tuple(Fm.parameters()), grad_output)

            GWgrads = dd[2:2 + len(GWeights)]
            FWgrads = dd[2 + len(GWeights):]
            grad_input = torch.cat([dd[0], dd[1]], dim=1)

            # cleanup sub-graph
            y1.detach_()
            y2.detach_()
            del y1, y2

        # restore input
        xout = torch.cat([x1, x2], dim=1).contiguous()
        x.storage().resize_(int(np.prod(xout.shape)))
        x.set_(xout)

        return (grad_input, None, None) + FWgrads + GWgrads
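
Example #2 is the backward of an affine coupling whose forward computes y1 = x1 * fmr1 + fmr2 with (fmr1, fmr2) = Fm(x2) and y2 = x2 * gmr1 + gmr2 with (gmr1, gmr2) = Gm(y1), saves (x, output), and then frees x's buffer so that only the output occupies memory. A sketch of what that matching forward could look like, assuming Fm, Gm and their parameters are passed as extra arguments (which the two Nones plus FWgrads + GWgrads in the return tuple suggest); the buffer-freeing call is one plausible counterpart to the storage restore in the backward:

    @staticmethod
    def forward(ctx, x, Fm, Gm, *weights):
        # weights are passed only so autograd routes the FWgrads / GWgrads
        # returned by backward() to the parameters of Fm and Gm (assumption)
        ctx.Fm, ctx.Gm = Fm, Gm
        with torch.no_grad():
            x1, x2 = torch.chunk(x, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()

            fmr1, fmr2 = Fm.forward(x2)
            y1 = x1 * fmr1 + fmr2
            gmr1, gmr2 = Gm.forward(y1)
            y2 = x2 * gmr1 + gmr2
            output = torch.cat([y1, y2], dim=1)

            # free the input buffer; backward() restores it via
            # x.storage().resize_(...) and x.set_(xout)
            x.storage().resize_(0)

        ctx.save_for_backward(x, output)
        return output
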
Example #3
    def backward(cty, grad_output):

        Fm, Gm = cty.Fm, cty.Gm
        # retrieve input and output references; these are plain tensor objects now
        y, output = cty.saved_tensors

        with torch.no_grad():
            x1, x2 = torch.chunk(output, 2, dim=1)
            x1, x2 = x1.contiguous(), x2.contiguous()

            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)
            x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
            x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()

        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, y1_stop, GW, FW
        # Also recompute inputs (y1, y2) from outputs (x1, x2)
        with set_grad_enabled(True):
            z1_stop = x2.detach()
            z1_stop.requires_grad = True

            F_z1 = Fm.forward(z1_stop)
            y1 = x1 + F_z1
            y1_stop = y1.detach()
            y1_stop.requires_grad = True

            G_y1 = Gm.forward(y1_stop)
            y2 = x2 + G_y1
            y2_stop = y2.detach()
            y2_stop.requires_grad = True

            # restore input
            yout = torch.cat([y1, y2], dim=1).contiguous()
            y.storage().resize_(int(np.prod(yout.shape)))
            y.set_(yout).detach()  # NOTE .detach() is very important here.

            # compute outputs building a sub-graph
            z1 = y2_stop - G_y1
            x1 = y1_stop - F_z1
            x2 = z1

            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(x1, (z1_stop, ) + tuple(Fm.parameters()),
                                     x1_grad)
            z1_grad = dd[0] + x2_grad  # + or - ?
            FWgrads = dd[1:]

            dd = torch.autograd.grad(x2, (y2_stop, y1_stop) +
                                     tuple(Gm.parameters()),
                                     z1_grad,
                                     retain_graph=False)

            GWgrads = dd[2:]
            y1_grad = dd[1] + x1_grad  # + or - ?
            y2_grad = dd[0]

            grad_input = torch.cat([y1_grad, y2_grad], dim=1)

            x1.detach_()
            x2.detach_()
            del x1, x2

        return (grad_input, None, None) + FWgrads + GWgrads
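
Example #3 handles the inverse direction of an additive coupling (y1 = x1 + Fm(x2), y2 = x2 + Gm(y1)): the Function's input is y, its output is x, and the input buffer is freed after forward and restored in the backward above. A sketch of the matching forward under the same assumptions about the argument list (Fm, Gm, then their parameters):

    @staticmethod
    def forward(cty, y, Fm, Gm, *weights):
        # weights are passed only so the FWgrads / GWgrads returned by
        # backward() line up with the parameters of Fm and Gm (assumption)
        cty.Fm, cty.Gm = Fm, Gm
        with torch.no_grad():
            y1, y2 = torch.chunk(y, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()

            # invert the additive coupling y1 = x1 + Fm(x2), y2 = x2 + Gm(y1)
            x2 = y2 - Gm.forward(y1)
            x1 = y1 - Fm.forward(x2)
            output = torch.cat([x1, x2], dim=1)

            # free the input buffer; backward() restores it via
            # y.storage().resize_(...) and y.set_(yout)
            y.storage().resize_(0)

        cty.save_for_backward(y, output)
        return output
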
Example #4
    def backward(ctx, grad_output):
        Fm, Gm = ctx.Fm, ctx.Gm
        # retrieve input and output references; these are plain tensor objects now
        x, output = ctx.saved_tensors

        with set_grad_enabled(False):
            y1, y2 = torch.chunk(output, 2, dim=1)
            y1, y2 = y1.contiguous(), y2.contiguous()

            # partition output gradient also on channels
            assert (grad_output.shape[1] % 2 == 0)
            y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
            y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()

        # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
        # z1_stop, x2_stop, GW, FW
        # Also recompute inputs (x1, x2) from outputs (y1, y2)
        with set_grad_enabled(True):
            z1_stop = y1
            z1_stop.requires_grad = True

            G_z11, G_z12 = Gm.forward(z1_stop)
            x2 = (y2 - G_z12) / G_z11
            x2_stop = x2.detach()
            x2_stop.requires_grad = True

            F_x21, F_x22 = Fm.forward(x2_stop)
            x1 = (y1 - F_x22) / F_x21
            x1_stop = x1.detach()
            x1_stop.requires_grad = True

            # restore input
            xout = torch.cat([x1, x2], dim=1).contiguous()
            x.storage().resize_(int(np.prod(xout.shape)))
            x.set_(xout).detach()  # NOTE .detach() is very important here.

            # compute outputs building a sub-graph
            z1 = x1_stop * F_x21 + F_x22
            y2_ = x2_stop * G_z11 + G_z12
            y1_ = z1

            # calculate the final gradients for the weights and inputs
            dd = torch.autograd.grad(y2_, (z1_stop, ) + tuple(Gm.parameters()),
                                     y2_grad)
            z1_grad = dd[0] + y1_grad
            GWgrads = dd[1:]

            dd = torch.autograd.grad(y1_, (x1_stop, x2_stop) +
                                     tuple(Fm.parameters()),
                                     z1_grad,
                                     retain_graph=False)

            FWgrads = dd[2:]
            x2_grad = dd[1] + y2_grad
            x1_grad = dd[0]
            grad_input = torch.cat([x1_grad, x2_grad], dim=1)

            y1_.detach_()
            y2_.detach_()
            del y1_, y2_

        return (grad_input, None, None) + FWgrads + GWgrads
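
All four backwards return (grad_input, None, None) + FWgrads + GWgrads: a gradient for the input tensor, None for the two coupling modules, then one gradient per Fm parameter followed by one per Gm parameter. The forward must therefore be called with the input, the two modules, and the parameters in that same order. A hypothetical wrapper module sketching that call (ReversibleCouplingBlock and AffineBlockFunction are illustrative placeholder names, not a confirmed library API):

class ReversibleCouplingBlock(torch.nn.Module):
    def __init__(self, Fm, Gm):
        super().__init__()
        self.Fm = Fm
        self.Gm = Gm

    def forward(self, x):
        # pass Fm's parameters first, then Gm's, so the FWgrads + GWgrads
        # order of the backward return tuple lines up with these arguments
        weights = tuple(self.Fm.parameters()) + tuple(self.Gm.parameters())
        return AffineBlockFunction.apply(x, self.Fm, self.Gm, *weights)
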