def broyden_find_root(func, z1ss, uss, z0, eps, *args):
    bsz, d_model, seq_len = z1ss.size()
    z1ss_est = z1ss.clone().detach()
    threshold = args[-2]    # Can also set this to be different, based on training/inference
    train_step = args[-1]

    g = lambda x: DEQFunc.g(func, x, uss, z0, *args)
    result_info = broyden(g, z1ss_est, threshold=threshold, eps=eps, name="forward")
    z1ss_est = result_info['result']
    nstep = result_info['nstep']

    # Record forward-solver diagnostics on the DEQFunc class for later inspection
    DEQFunc.forward_step = result_info['nstep']
    DEQFunc.forward_reduced_ratio = result_info['reduced_ratio']
    DEQFunc.forward_lowest = result_info['diff']
    DEQFunc.forward_init = result_info['init']
    DEQFunc.forward_obj_trace = result_info['trace']

    if threshold > 100:
        torch.cuda.empty_cache()
    return z1ss_est.clone().detach()
def backward(ctx, grad):
    torch.cuda.empty_cache()

    # grad should have dimension (bsz x d_model x seq_len)
    bsz, d_model, seq_len = grad.size()
    grad = grad.clone()
    z1ss, uss, z0 = ctx.saved_tensors
    args = ctx.args
    threshold = args[-2]
    train_step = args[-1]

    func = ctx.func
    z1ss_temp = z1ss.clone().detach().requires_grad_()
    uss_temp = uss.clone().detach()
    z0_temp = z0.clone().detach()
    args_temp = copy.deepcopy(args)

    with torch.enable_grad():
        y = DEQFunc.g(func, z1ss_temp, uss_temp, z0_temp, *args_temp)

    # Each call evaluates J^T x + grad via a vector-Jacobian product; Broyden drives this to zero
    def g(x):
        y.backward(x, retain_graph=True)   # Retain for future calls to g
        JTx = z1ss_temp.grad.clone().detach()
        z1ss_temp.grad.zero_()
        return JTx + grad

    eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
    dl_df_est = torch.zeros_like(grad)

    result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
    dl_df_est = result_info['result']
    nstep = result_info['nstep']

    DummyDEQFunc.backward_step = result_info['nstep']
    DummyDEQFunc.backward_reduced_ratio = result_info['reduced_ratio']
    DummyDEQFunc.backward_lowest = result_info['diff']
    DummyDEQFunc.backward_init = result_info['init']
    DummyDEQFunc.backward_obj_trace = result_info['trace']

    y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

    grad_args = [None for _ in range(len(args))]
    return (None, dl_df_est, None, None, *grad_args)
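# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this repo): the J^T x trick used by `g`
# above, reproduced in isolation. Calling y.backward(x, retain_graph=True)
# accumulates the vector-Jacobian product J^T x into the input's .grad, which
# is exactly the quantity the backward Broyden solve needs. The toy function
# `toy_f` and the shapes below are hypothetical.
# ---------------------------------------------------------------------------
import torch

def _vjp_sketch():
    torch.manual_seed(0)
    W = torch.randn(4, 4)
    toy_f = lambda z: torch.tanh(z @ W)

    z = torch.randn(3, 4, requires_grad=True)
    with torch.enable_grad():
        y = toy_f(z)

    x = torch.randn_like(y)
    y.backward(x, retain_graph=True)          # accumulates J^T x into z.grad
    jtx_via_backward = z.grad.clone()
    z.grad.zero_()

    # Cross-check against torch.autograd.grad, which returns the same VJP directly
    jtx_via_grad, = torch.autograd.grad(y, z, grad_outputs=x)
    assert torch.allclose(jtx_via_backward, jtx_via_grad, atol=1e-6)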
def broyden_find_root(func, z1ss, uss, z0, eps, *args):
    bsz, d_model, seq_len = z1ss.size()
    z1ss_est = z1ss.clone().detach()
    threshold = args[-2]    # Can also set this to be different, based on training/inference
    train_step = args[-1]

    g = lambda x: RootFind.g(func, x, uss, z0, *args)
    result_info = broyden(g, z1ss_est, threshold=threshold, eps=eps, name="forward")
    z1ss_est = result_info['result']

    if threshold > 100:
        torch.cuda.empty_cache()
    return z1ss_est.clone().detach()
def backward(ctx, grad):
    torch.cuda.empty_cache()

    # grad should have dimension (bsz x d_model x seq_len)
    bsz, d_model, seq_len = grad.size()
    grad = grad.clone()
    z1ss, uss, z0 = ctx.saved_tensors
    args = ctx.args
    threshold, train_step = args[-2:]

    func = ctx.func
    z1ss = z1ss.clone().detach().requires_grad_()
    uss = uss.clone().detach()
    z0 = z0.clone().detach()

    with torch.enable_grad():
        y = RootFind.g(func, z1ss, uss, z0, *args)

    def g(x):
        y.backward(x, retain_graph=True)   # Retain for future calls to g
        JTx = z1ss.grad.clone().detach()
        z1ss.grad.zero_()
        return JTx + grad

    eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
    dl_df_est = torch.zeros_like(grad)

    result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
    dl_df_est = result_info['result']
    nstep = result_info['nstep']

    if dl_df_est.get_device() == 0 and np.random.uniform(0, 1) < 1e-4:
        msg = f"{nstep} steps in Broyden backward: diff={result_info['diff']}; eps={eps}; bsz={bsz}"
        print(colored(msg, "yellow"))

    y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

    grad_args = [None for _ in range(len(args))]
    return (None, dl_df_est, None, None, *grad_args)
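# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this repo): the linear system the backward
# pass solves. With y = g(z) = f(z) - z evaluated at the equilibrium, the root
# of J_g^T x + grad = 0 satisfies x = J_f^T x + grad, i.e. x = (I - J_f^T)^{-1} grad.
# The toy below checks this on a contractive linear f(z) = A z + b, using plain
# fixed-point iteration in place of Broyden's method.
# ---------------------------------------------------------------------------
import torch

def _backward_system_sketch():
    torch.manual_seed(0)
    n = 6
    A = 0.3 * torch.randn(n, n) / n ** 0.5    # keep the spectral radius well below 1
    grad = torch.randn(n)

    x = torch.zeros(n)
    for _ in range(200):
        x = A.t() @ x + grad                  # the fixed point the repo finds with Broyden

    x_exact = torch.linalg.solve(torch.eye(n) - A.t(), grad)
    assert torch.allclose(x, x_exact, atol=1e-5)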
def broyden_find_root(func, z1, u, eps, *args):
    bsz = z1[0].size(0)
    z1_est = DEQFunc2d.list2vec(z1)
    cutoffs = [(elem.size(1), elem.size(2), elem.size(3)) for elem in z1]
    threshold, train_step, writer = args[-3:]

    g = lambda x: DEQFunc2d.g(func, x, u, cutoffs, *args)
    result_info = broyden(g, z1_est, threshold=threshold, eps=eps, name="forward")
    z1_est = result_info['result']
    nstep = result_info['nstep']
    lowest_step = result_info['lowest_step']
    diff = result_info['diff']
    r_diff = min(result_info['new_trace'][1:])

    if z1_est.get_device() == 0:
        if writer is not None:
            writer.add_scalar('forward/diff', result_info['diff'], train_step)
            writer.add_scalar('forward/nstep', result_info['nstep'], train_step)
            writer.add_scalar('forward/lowest_step', result_info['lowest_step'], train_step)
            writer.add_scalar('forward/final_trace', result_info['new_trace'][lowest_step], train_step)

    status = analyze_broyden(result_info, judge=True)
    if status:
        err = {"z1": z1}
        analyze_broyden(result_info, err=err, judge=False, name="forward", save_err=False)

    if threshold > 30:
        torch.cuda.empty_cache()

    return DEQFunc2d.vec2list(z1_est.clone().detach(), cutoffs)
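# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this repo): what DEQFunc2d.list2vec /
# vec2list are assumed to do in the 2D variant above, i.e. flatten a list of
# (bsz, C_i, H_i, W_i) feature maps into one vector for the solver and reshape
# back afterwards using the recorded cutoffs. The implementation below is
# hypothetical; only the round-trip behaviour is assumed.
# ---------------------------------------------------------------------------
import torch

def _list2vec(z_list):
    bsz = z_list[0].size(0)
    return torch.cat([elem.reshape(bsz, -1, 1) for elem in z_list], dim=1)

def _vec2list(z_vec, cutoffs):
    bsz = z_vec.size(0)
    out, start = [], 0
    for (c, h, w) in cutoffs:
        n = c * h * w
        out.append(z_vec[:, start:start + n].reshape(bsz, c, h, w))
        start += n
    return out

def _roundtrip_sketch():
    z1 = [torch.randn(2, 8, 16, 16), torch.randn(2, 16, 8, 8)]
    cutoffs = [(e.size(1), e.size(2), e.size(3)) for e in z1]
    back = _vec2list(_list2vec(z1), cutoffs)
    assert all(torch.equal(a, b) for a, b in zip(z1, back))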
def backward(ctx, grad):
    torch.cuda.empty_cache()

    # grad should have dimension (bsz x d_model x seq_len)
    bsz, d_model, seq_len = grad.size()
    grad = grad.clone()
    z1ss, uss, z0 = ctx.saved_tensors
    args = ctx.args
    threshold, train_step = args[-2:]

    func = ctx.func
    z1ss = z1ss.clone().detach().requires_grad_()
    uss = uss.clone().detach()
    z0 = z0.clone().detach()

    with torch.enable_grad():
        y = RootFind.g(func, z1ss, uss, z0, *args)

    def g(x):
        y.backward(x, retain_graph=True)   # Retain for future calls to g
        JTx = z1ss.grad.clone().detach()
        z1ss.grad.zero_()
        return JTx + grad

    eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
    dl_df_est = torch.zeros_like(grad)

    result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
    dl_df_est = result_info['result']

    y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

    grad_args = [None for _ in range(len(args))]
    return (None, dl_df_est, None, None, *grad_args)
def broyden_find_root(func, z1ss, uss, z0, eps, *args):
    bsz, d_model, seq_len = z1ss.size()
    z1ss_est = z1ss.clone().detach()
    threshold = args[-2]    # Can also set this to be different, based on training/inference
    train_step = args[-1]

    g = lambda x: RootFind.g(func, x, uss, z0, *args)
    result_info = broyden(g, z1ss_est, threshold=threshold, eps=eps, name="forward")
    z1ss_est = result_info['result']
    nstep = result_info['nstep']

    if z1ss_est.get_device() == 0 and np.random.uniform(0, 1) < 1e-4:
        msg = f"{nstep} steps in Broyden forward: diff={result_info['diff']}; eps={eps}; bsz={bsz}"
        print(colored(msg, "cyan"))

    if threshold > 100:
        torch.cuda.empty_cache()
    return z1ss_est.clone().detach()
def broyden_find_root(func, z1ss, uss, z0, eps, *args):
    bsz, d_model, seq_len = z1ss.size()
    z1ss_est = z1ss.clone().detach()
    threshold = args[-2]    # Can also set this to be different, based on training/inference
    train_step = args[-1]

    g = lambda x: DEQFunc.g(func, x, uss, z0, *args)
    result_info = broyden(g, z1ss_est, threshold=threshold, eps=eps)
    g_f_x = torch.zeros_like(z1ss_est)
    z1ss_est = result_info['result']
    nstep = result_info['nstep']

    # \nabla calc =================================================
    # z1ss_est_temp = z1ss.clone().detach().requires_grad_()
    # func_copy = func
    # #func_copy = copy.deepcopy(func_copy)
    # with torch.enable_grad():
    #     y = DEQFunc.f(func_copy, z1ss_est_temp, uss, z0, *args)
    # def grad_f_x(x):
    #     y.backward(x, retain_graph=True)   # Retain for future calls to g
    #     JTx = z1ss_est_temp.grad.clone().detach()
    #     z1ss_est_temp.grad.zero_()
    #     return JTx
    # g_f_x = grad_f_x(z1ss_est)
    # =============================================================

    if threshold > 100:
        torch.cuda.empty_cache()
    return z1ss_est.clone().detach(), g_f_x
def backward(ctx, grad):
    # grad should have dimension (bsz x d_model x seq_len)
    bsz, d_model, seq_len = grad.size()
    grad = grad.clone()
    z1, = ctx.saved_tensors
    u = ctx.u
    factor = sum(ue.nelement() for ue in u) // z1.nelement()
    cutoffs = [(elem.size(1) // factor, elem.size(2), elem.size(3)) for elem in u]
    args = ctx.args
    threshold, train_step, writer = args[-3:]

    func = ctx.func
    z1_temp = z1.clone().detach().requires_grad_()
    u_temp = [elem.clone().detach() for elem in u]
    args_temp = args[:-1]

    # '''
    # Calculate dF/dx
    # '''
    # with torch.enable_grad():
    #     f_x = DEQFunc2d.f_x(func, z1_temp, u_temp, cutoffs, *args)
    # def f(x):
    #     f_x.backward(x, retain_graph=True)
    #     df_dx = z1_temp.grad.clone()
    #     z1_temp.grad.zero_()
    #     return df_dx
    # # Here is your grad_f_x
    # df_dx = f(z1_temp)
    # print('df_dx_norm: {}'.format(torch.norm(df_dx)))

    '''
    Calculate dL/df_est
    '''
    with torch.enable_grad():
        y = DEQFunc2d.g(func, z1_temp, u_temp, cutoffs, *args_temp)

    def g(x):
        y.backward(x, retain_graph=True)   # Retain for future calls to g
        # print(torch.norm(x))
        # print(torch.norm(z1_temp.grad))
        res = z1_temp.grad + grad
        z1_temp.grad.zero_()
        return res

    eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
    dl_df_est = torch.zeros_like(grad)

    # NOTE: threshold is hard-coded to 1 here; the standard call using the threshold from args is kept above for reference
    # result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
    result_info = broyden(g, dl_df_est, threshold=1, eps=eps, name="backward")
    dl_df_est = result_info['result']
    nstep = result_info['nstep']
    lowest_step = result_info['lowest_step']

    if dl_df_est.get_device() == 0:
        if writer is not None:
            writer.add_scalar('backward/diff', result_info['diff'], train_step)
            # writer.add_scalar('backward/torch.norm(df_dx)', torch.norm(df_dx), train_step)
            writer.add_scalar('backward/nstep', result_info['nstep'], train_step)
            writer.add_scalar('backward/lowest_step', result_info['lowest_step'], train_step)
            writer.add_scalar('backward/final_trace', result_info['new_trace'][lowest_step], train_step)

    status = analyze_broyden(result_info, judge=True)
    if status:
        err = {"z1": z1}
        analyze_broyden(result_info, err=err, judge=False, name="backward", save_err=False)

    if threshold > 30:
        torch.cuda.empty_cache()

    # Delete graph
    y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

    grad_args = [None for _ in range(len(args))]
    return (None, dl_df_est, None, *grad_args)
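# ---------------------------------------------------------------------------
# Illustrative note (not part of this repo): the `writer` threaded through
# `args` above is assumed to be a torch.utils.tensorboard.SummaryWriter. A
# minimal standalone usage with the same tag scheme (path and values are
# hypothetical):
# ---------------------------------------------------------------------------
from torch.utils.tensorboard import SummaryWriter

def _writer_sketch():
    writer = SummaryWriter(log_dir="runs/deq_debug")      # hypothetical log directory
    writer.add_scalar('backward/diff', 1.3e-4, 100)       # residual norm at step 100
    writer.add_scalar('backward/nstep', 22, 100)          # Broyden iterations used
    writer.close()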