import numpy as np
import torch
from torch import set_grad_enabled


def backward(cty, grad_output):
    # retrieve weight references
    Fm, Gm = cty.Fm, cty.Gm

    # retrieve input and output references
    y, output = cty.saved_tensors
    x1, x2 = torch.chunk(output, 2, dim=1)
    x1, x2 = x1.contiguous(), x2.contiguous()

    # partition output gradient also on channels
    assert grad_output.shape[1] % 2 == 0

    with set_grad_enabled(False):
        # recompute y
        z1_stop = x2
        z1_stop.requires_grad = True

        FWeights = [p for p in Fm.parameters()]
        fmr1, fmr2 = Fm.forward(z1_stop)

        y1 = (fmr1 * x1) + fmr2
        gmr1, gmr2 = Gm.forward(y1)
        y2 = (gmr1 * x2) + gmr2

    with set_grad_enabled(True):
        # compute outputs building a sub-graph
        y2.requires_grad = True
        y1.requires_grad = True

        gmr1, gmr2 = Gm.forward(y1)
        x2 = (y2 - gmr2) / gmr1
        fmr1, fmr2 = Fm.forward(x2)
        x1 = (y1 - fmr2) / fmr1
        x = torch.cat([x1, x2], dim=1)

        # perform full backward pass on graph...
        dd = torch.autograd.grad(x, (y1, y2) + tuple(Fm.parameters()) + tuple(Gm.parameters()), grad_output)

        FWgrads = dd[2:2 + len(FWeights)]
        GWgrads = dd[2 + len(FWeights):]
        grad_input = torch.cat([dd[0], dd[1]], dim=1)

        # cleanup sub-graph
        x1.detach_()
        x2.detach_()
        del x1, x2

    # restore input
    yout = torch.cat([y1, y2], dim=1).contiguous()
    y.storage().resize_(int(np.prod(yout.shape)))
    y.set_(yout)

    return (grad_input, None, None) + FWgrads + GWgrads

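# The affine backward above calls Fm.forward / Gm.forward and unpacks the result into a
# multiplicative and an additive term (fmr1/fmr2, gmr1/gmr2). Below is a minimal sketch of
# a coupling module satisfying that contract; the conv-based ToyAffineCoupling is purely
# illustrative and not the original Fm/Gm implementation. The exp() keeps the multiplicative
# term non-zero, so the divisions in the inverse are well-defined.
import torch
import torch.nn as nn


class ToyAffineCoupling(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.scale_net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.shift_net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        # Return (multiplicative term, additive term), i.e. the (fmr1, fmr2) / (gmr1, gmr2) pairs.
        return torch.exp(self.scale_net(x)), self.shift_net(x)
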
def backward(ctx, grad_output):
    # retrieve weight references
    Fm, Gm = ctx.Fm, ctx.Gm

    # retrieve input and output references
    x, output = ctx.saved_tensors
    y1, y2 = torch.chunk(output, 2, dim=1)
    y1, y2 = y1.contiguous(), y2.contiguous()

    # partition output gradient also on channels
    assert grad_output.shape[1] % 2 == 0

    with set_grad_enabled(False):
        # recompute x
        z1_stop = y1
        z1_stop.requires_grad = True

        GWeights = [p for p in Gm.parameters()]
        gmr1, gmr2 = Gm.forward(z1_stop)

        x2 = (y2 - gmr2) / gmr1
        fmr1, fmr2 = Fm.forward(x2)
        x1 = (y1 - fmr2) / fmr1

    with set_grad_enabled(True):
        # compute outputs building a sub-graph
        x1.requires_grad = True
        x2.requires_grad = True

        fmr1, fmr2 = Fm.forward(x2)
        y1 = x1 * fmr1 + fmr2
        gmr1, gmr2 = Gm.forward(y1)
        y2 = x2 * gmr1 + gmr2
        y = torch.cat([y1, y2], dim=1)

        # perform full backward pass on graph...
        dd = torch.autograd.grad(y, (x1, x2) + tuple(Gm.parameters()) + tuple(Fm.parameters()), grad_output)

        GWgrads = dd[2:2 + len(GWeights)]
        FWgrads = dd[2 + len(GWeights):]
        grad_input = torch.cat([dd[0], dd[1]], dim=1)

        # cleanup sub-graph
        y1.detach_()
        y2.detach_()
        del y1, y2

    # restore input
    xout = torch.cat([x1, x2], dim=1).contiguous()
    x.storage().resize_(int(np.prod(xout.shape)))
    x.set_(xout)

    return (grad_input, None, None) + FWgrads + GWgrads

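# Hedged sketch of the forward pass that the affine backward above inverts. The
# (grad_input, None, None) + FWgrads + GWgrads return value implies a signature of roughly
# forward(ctx, x, Fm, Gm, *weights); the storage-freeing detail below is an assumption for
# illustration (this is where the memory saving would come from), not the original code.
# In practice both this forward and the backward above would be staticmethods of a
# torch.autograd.Function subclass.
import torch
from torch import set_grad_enabled


def forward(ctx, x, Fm, Gm, *weights):
    # weights are passed through apply() only so that the gradients returned by backward
    # (FWgrads + GWgrads) are routed back to the corresponding parameters
    ctx.Fm = Fm
    ctx.Gm = Gm

    with set_grad_enabled(False):
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1, x2 = x1.contiguous(), x2.contiguous()

        fmr1, fmr2 = Fm.forward(x2)
        y1 = x1 * fmr1 + fmr2
        gmr1, gmr2 = Gm.forward(y1)
        y2 = x2 * gmr1 + gmr2
        output = torch.cat([y1, y2], dim=1)

        # free the input activations; backward recomputes x and restores this storage
        x.storage().resize_(0)

    ctx.save_for_backward(x, output)
    return output
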
def backward(cty, grad_output):
    Fm, Gm = cty.Fm, cty.Gm  # are all variable objects now
    y, output = cty.saved_tensors

    with torch.no_grad():
        x1, x2 = torch.chunk(output, 2, dim=1)
        x1, x2 = x1.contiguous(), x2.contiguous()

        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0

        x1_grad, x2_grad = torch.chunk(grad_output, 2, dim=1)
        x1_grad, x2_grad = x1_grad.contiguous(), x2_grad.contiguous()

    # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
    # z1_stop, y1_stop, GW, FW
    # Also recompute inputs (y1, y2) from outputs (x1, x2)
    with set_grad_enabled(True):
        z1_stop = x2.detach()
        z1_stop.requires_grad = True

        F_z1 = Fm.forward(z1_stop)
        y1 = x1 + F_z1
        y1_stop = y1.detach()
        y1_stop.requires_grad = True

        G_y1 = Gm.forward(y1_stop)
        y2 = x2 + G_y1
        y2_stop = y2.detach()
        y2_stop.requires_grad = True

        # restore input
        yout = torch.cat([y1, y2], dim=1).contiguous()
        y.storage().resize_(int(np.prod(yout.shape)))
        y.set_(yout).detach()  # NOTE .detach() is very important here.

        # compute outputs building a sub-graph
        z1 = y2_stop - G_y1
        x1 = y1_stop - F_z1
        x2 = z1

        # calculate the final gradients for the weights and inputs
        dd = torch.autograd.grad(x1, (z1_stop,) + tuple(Fm.parameters()), x1_grad)

        z1_grad = dd[0] + x2_grad  # add the direct contribution of x2_grad, since x2 == z1
        FWgrads = dd[1:]

        dd = torch.autograd.grad(x2, (y2_stop, y1_stop) + tuple(Gm.parameters()), z1_grad, retain_graph=False)

        GWgrads = dd[2:]
        y1_grad = dd[1] + x1_grad  # add the direct contribution of x1_grad, since x1 = y1 - F(z1)
        y2_grad = dd[0]

        grad_input = torch.cat([y1_grad, y2_grad], dim=1)

        x1.detach_()
        x2.detach_()
        del x1, x2

    return (grad_input, None, None) + FWgrads + GWgrads

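# Small self-contained check of the gradient accumulation used above. With toy linear F and
# G (an illustrative assumption, not the original modules), the two torch.autograd.grad
# calls plus the "+ x2_grad" / "+ x1_grad" accumulations reproduce the gradients of a single
# autograd pass through the additive inverse x2 = y2 - G(y1), x1 = y1 - F(x2).
import torch
import torch.nn as nn

torch.manual_seed(0)
F, G = nn.Linear(4, 4), nn.Linear(4, 4)
y1 = torch.randn(2, 4, requires_grad=True)
y2 = torch.randn(2, 4, requires_grad=True)
x1_grad, x2_grad = torch.randn(2, 4), torch.randn(2, 4)

# reference: one autograd pass over the whole inverse
x2_ref = y2 - G(y1)
x1_ref = y1 - F(x2_ref)
ref_y1_grad, ref_y2_grad = torch.autograd.grad((x1_ref, x2_ref), (y1, y2), (x1_grad, x2_grad))

# two-stage variant mirroring the backward above, with detached "stop" leaves
z1_stop = (y2 - G(y1)).detach().requires_grad_()  # plays the role of x2 / z1
y1_stop = y1.detach().requires_grad_()
y2_stop = y2.detach().requires_grad_()

x1 = y1_stop - F(z1_stop)
x2 = y2_stop - G(y1_stop)

dd = torch.autograd.grad(x1, (z1_stop,), x1_grad)
z1_grad = dd[0] + x2_grad          # direct contribution of x2_grad enters here
dd = torch.autograd.grad(x2, (y2_stop, y1_stop), z1_grad)
y2_grad = dd[0]
y1_grad = dd[1] + x1_grad          # direct contribution of x1_grad enters here

assert torch.allclose(y1_grad, ref_y1_grad, atol=1e-5)
assert torch.allclose(y2_grad, ref_y2_grad, atol=1e-5)
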
def backward(ctx, grad_output):
    Fm, Gm = ctx.Fm, ctx.Gm  # are all variable objects now
    x, output = ctx.saved_tensors

    with set_grad_enabled(False):
        y1, y2 = torch.chunk(output, 2, dim=1)
        y1, y2 = y1.contiguous(), y2.contiguous()

        # partition output gradient also on channels
        assert grad_output.shape[1] % 2 == 0

        y1_grad, y2_grad = torch.chunk(grad_output, 2, dim=1)
        y1_grad, y2_grad = y1_grad.contiguous(), y2_grad.contiguous()

    # Recreate computation graphs for functions Gm and Fm with gradient collecting leaf nodes:
    # z1_stop, x2_stop, GW, FW
    # Also recompute inputs (x1, x2) from outputs (y1, y2)
    with set_grad_enabled(True):
        z1_stop = y1
        z1_stop.requires_grad = True

        G_z11, G_z12 = Gm.forward(z1_stop)
        x2 = (y2 - G_z12) / G_z11
        x2_stop = x2.detach()
        x2_stop.requires_grad = True

        F_x21, F_x22 = Fm.forward(x2_stop)
        x1 = (y1 - F_x22) / F_x21
        x1_stop = x1.detach()
        x1_stop.requires_grad = True

        # restore input
        xout = torch.cat([x1, x2], dim=1).contiguous()
        x.storage().resize_(int(np.prod(xout.shape)))
        x.set_(xout).detach()  # NOTE .detach() is very important here.

        # compute outputs building a sub-graph
        z1 = x1_stop * F_x21 + F_x22
        y2_ = x2_stop * G_z11 + G_z12
        y1_ = z1

        # calculate the final gradients for the weights and inputs
        dd = torch.autograd.grad(y2_, (z1_stop,) + tuple(Gm.parameters()), y2_grad)

        z1_grad = dd[0] + y1_grad
        GWgrads = dd[1:]

        dd = torch.autograd.grad(y1_, (x1_stop, x2_stop) + tuple(Fm.parameters()), z1_grad, retain_graph=False)

        FWgrads = dd[2:]
        x2_grad = dd[1] + y2_grad
        x1_grad = dd[0]

        grad_input = torch.cat([x1_grad, x2_grad], dim=1)

        y1_.detach_()
        y2_.detach_()
        del y1_, y2_

    return (grad_input, None, None) + FWgrads + GWgrads

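# Hedged usage sketch (illustrative names, not the original API): autograd functions with
# the interface implied above are typically driven from a thin nn.Module wrapper that hands
# the coupling parameters to apply() explicitly, so that the FWgrads / GWgrads returned by
# backward are matched positionally to those parameters.
import torch.nn as nn


class CouplingBlockWrapper(nn.Module):
    def __init__(self, Fm, Gm, coupling_fn):
        super().__init__()
        self.Fm = Fm
        self.Gm = Gm
        self.coupling_fn = coupling_fn  # a torch.autograd.Function subclass holding the code above

    def forward(self, x):
        # Parameter order matches the FWgrads + GWgrads order returned by backward.
        weights = tuple(self.Fm.parameters()) + tuple(self.Gm.parameters())
        return self.coupling_fn.apply(x, self.Fm, self.Gm, *weights)
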