Example 1
File: deq.py Project: ryoungj/deq
    def broyden_find_root(func, z1ss, uss, z0, eps, *args):
        bsz, d_model, seq_len = z1ss.size()
        z1ss_est = z1ss.clone().detach()
        threshold = args[-2]  # Can also set this to be different, based on training/inference
        train_step = args[-1]

        g = lambda x: DEQFunc.g(func, x, uss, z0, *args)
        result_info = broyden(g,
                              z1ss_est,
                              threshold=threshold,
                              eps=eps,
                              name="forward")
        z1ss_est = result_info['result']
        nstep = result_info['nstep']

        # Stash solver diagnostics on the class so callers can log them
        DEQFunc.forward_step = result_info['nstep']
        DEQFunc.forward_reduced_ratio = result_info['reduced_ratio']
        DEQFunc.forward_lowest = result_info['diff']
        DEQFunc.forward_init = result_info['init']
        DEQFunc.forward_obj_trace = result_info['trace']

        if threshold > 100:
            torch.cuda.empty_cache()
        return z1ss_est.clone().detach()
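
The `broyden` call above is external to this snippet. As a rough guide to its contract, below is a minimal dense sketch of Broyden's method that returns the same core keys the snippets read ('result', 'nstep', 'diff'). The repo's actual solver is matrix-free, keeps a low-rank inverse-Jacobian estimate, and returns a richer result_info ('trace', 'lowest_step', 'reduced_ratio', ...); this toy version materializes an n-by-n matrix and has no safeguards, so treat it as illustration only.

import torch

def broyden_sketch(g, x0, threshold=50, eps=1e-6):
    # Minimal dense "good Broyden" root-finder: seek x with ||g(x)|| <= eps.
    x = x0.clone().detach()
    gx = g(x)
    n = x.numel()
    H = torch.eye(n, device=x.device)      # inverse-Jacobian estimate
    lowest, best_x, nstep = gx.norm().item(), x.clone(), 0
    for nstep in range(1, threshold + 1):  # `threshold` caps the iterations
        dx = -(H @ gx.reshape(n, 1)).reshape_as(x)
        x = x + dx
        gx_new = g(x)
        if gx_new.norm().item() < lowest:  # track the best iterate seen
            lowest, best_x = gx_new.norm().item(), x.clone()
        if lowest <= eps:
            break
        s, dg = dx.reshape(n, 1), (gx_new - gx).reshape(n, 1)
        # Sherman-Morrison rank-1 update keeping the secant condition H @ dg = s
        H = H + (s - H @ dg) @ (s.T @ H) / (s.T @ H @ dg)
        gx = gx_new
    return {'result': best_x, 'nstep': nstep, 'diff': lowest}
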
Example 2
File: deq.py Project: ryoungj/deq
    def backward(ctx, grad):
        torch.cuda.empty_cache()

        # grad should have dimension (bsz x d_model x seq_len)
        bsz, d_model, seq_len = grad.size()
        grad = grad.clone()
        z1ss, uss, z0 = ctx.saved_tensors
        args = ctx.args
        threshold = args[-2]
        train_step = args[-1]

        func = ctx.func
        z1ss_temp = z1ss.clone().detach().requires_grad_()
        uss_temp = uss.clone().detach()
        z0_temp = z0.clone().detach()
        args_temp = copy.deepcopy(args)

        with torch.enable_grad():
            y = DEQFunc.g(func, z1ss_temp, uss_temp, z0_temp, *args_temp)

        def g(x):
            y.backward(x, retain_graph=True)  # Retain for future calls to g
            JTx = z1ss_temp.grad.clone().detach()
            z1ss_temp.grad.zero_()
            return JTx + grad

        eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
        dl_df_est = torch.zeros_like(grad)

        result_info = broyden(g,
                              dl_df_est,
                              threshold=threshold,
                              eps=eps,
                              name="backward")
        dl_df_est = result_info['result']
        nstep = result_info['nstep']
        DummyDEQFunc.backward_step = result_info['nstep']
        DummyDEQFunc.backward_reduced_ratio = result_info['reduced_ratio']
        DummyDEQFunc.backward_lowest = result_info['diff']
        DummyDEQFunc.backward_init = result_info['init']
        DummyDEQFunc.backward_obj_trace = result_info['trace']

        y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

        grad_args = [None for _ in range(len(args))]
        return (None, dl_df_est, None, None, *grad_args)
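
Two details worth noting here. Inside `g(x)`, the call `y.backward(x, retain_graph=True)` accumulates the vector-Jacobian product J^T x into `z1ss_temp.grad`, where J is the Jacobian of y with respect to z1ss_temp; Broyden then drives J^T u + grad to zero. Assuming DEQFunc.g computes the residual f(z) - z, the root is u = (I - J_f^T)^{-1} grad, which is exactly the implicit-function-theorem gradient of the equilibrium, obtained without unrolling any forward iterations. Also, `eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)` scales the stopping tolerance with the square root of the state size, so the per-element tolerance stays roughly constant across batch and sequence lengths.

The same VJP can be computed with torch.autograd.grad, which avoids the accumulate-then-zero dance on `.grad`; a self-contained sketch with a toy residual (not the repo's g):

import torch

def vjp(y, z, v):
    # Returns J^T v for J = dy/dz, leaving z.grad untouched.
    (JTv,) = torch.autograd.grad(y, z, grad_outputs=v, retain_graph=True)
    return JTv

z = torch.randn(5, requires_grad=True)
y = torch.tanh(z) - z                # toy residual g(z) = f(z) - z
v = torch.randn(5)
print(vjp(y, z, v))                  # elementwise (tanh'(z) - 1) * v
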
Example 3
    def broyden_find_root(func, z1ss, uss, z0, eps, *args):
        bsz, d_model, seq_len = z1ss.size()
        z1ss_est = z1ss.clone().detach()
        threshold = args[-2]  # Can also set this to be different, based on training/inference
        train_step = args[-1]

        g = lambda x: RootFind.g(func, x, uss, z0, *args)
        result_info = broyden(g,
                              z1ss_est,
                              threshold=threshold,
                              eps=eps,
                              name="forward")
        z1ss_est = result_info['result']

        if threshold > 100:
            torch.cuda.empty_cache()
        return z1ss_est.clone().detach()
Example 4
        def backward(ctx, grad):
            torch.cuda.empty_cache()

            # grad should have dimension (bsz x d_model x seq_len)
            bsz, d_model, seq_len = grad.size()
            grad = grad.clone()
            z1ss, uss, z0 = ctx.saved_tensors
            args = ctx.args
            threshold, train_step = args[-2:]

            func = ctx.func
            z1ss = z1ss.clone().detach().requires_grad_()
            uss = uss.clone().detach()
            z0 = z0.clone().detach()

            with torch.enable_grad():
                y = RootFind.g(func, z1ss, uss, z0, *args)

            def g(x):
                y.backward(x, retain_graph=True)  # Retain for future calls to g
                JTx = z1ss.grad.clone().detach()
                z1ss.grad.zero_()
                return JTx + grad

            eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
            dl_df_est = torch.zeros_like(grad)

            result_info = broyden(g,
                                  dl_df_est,
                                  threshold=threshold,
                                  eps=eps,
                                  name="backward")
            dl_df_est = result_info['result']
            nstep = result_info['nstep']

            if dl_df_est.get_device() == 0 and np.random.uniform(0, 1) < 1e-4:
                msg = f"{nstep} steps in Broyden backward: diff={result_info['diff']}; eps={eps}; bsz={bsz}"
                print(colored(msg, "yellow"))

            y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

            grad_args = [None for _ in range(len(args))]
            return (None, dl_df_est, None, None, *grad_args)
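
The `colored` call here presumably comes from termcolor, and the two guards keep logging cheap: `get_device() == 0` restricts printing to the replica on GPU 0 under multi-GPU training, and the `np.random.uniform(0, 1) < 1e-4` draw emits the message on roughly one call in ten thousand. A standalone sketch of the pattern (the helper name is mine, not the repo's):

import numpy as np
from termcolor import colored  # assumed source of `colored` in these snippets

def maybe_log(msg, rate=1e-4, color="yellow"):
    # Emit roughly one message per 1/rate calls, keeping solver logs sparse.
    if np.random.uniform(0, 1) < rate:
        print(colored(msg, color))

maybe_log("12 steps in Broyden backward: diff=3.1e-09", rate=1.0)  # rate=1.0 forces a print
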
Example 5
    def broyden_find_root(func, z1, u, eps, *args):
        bsz = z1[0].size(0)
        z1_est = DEQFunc2d.list2vec(z1)
        cutoffs = [(elem.size(1), elem.size(2), elem.size(3)) for elem in z1]
        threshold, train_step, writer = args[-3:]

        g = lambda x: DEQFunc2d.g(func, x, u, cutoffs, *args)
        result_info = broyden(g,
                              z1_est,
                              threshold=threshold,
                              eps=eps,
                              name="forward")
        z1_est = result_info['result']
        nstep = result_info['nstep']
        lowest_step = result_info['lowest_step']
        diff = result_info['diff']
        r_diff = min(result_info['new_trace'][1:])

        if z1_est.get_device() == 0:
            if writer is not None:
                writer.add_scalar('forward/diff', result_info['diff'],
                                  train_step)
                writer.add_scalar('forward/nstep', result_info['nstep'],
                                  train_step)
                writer.add_scalar('forward/lowest_step',
                                  result_info['lowest_step'], train_step)
                writer.add_scalar('forward/final_trace',
                                  result_info['new_trace'][lowest_step],
                                  train_step)

        status = analyze_broyden(result_info, judge=True)
        if status:
            err = {"z1": z1}
            analyze_broyden(result_info,
                            err=err,
                            judge=False,
                            name="forward",
                            save_err=False)

        if threshold > 30:
            torch.cuda.empty_cache()
        return DEQFunc2d.vec2list(z1_est.clone().detach(), cutoffs)
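
DEQFunc2d treats the equilibrium state as a list of multi-scale feature maps, so `list2vec` / `vec2list` (not shown in these snippets) must flatten each (bsz, C, H, W) tensor into one long vector for Broyden and invert that using `cutoffs`, which records each scale's (C, H, W). A plausible minimal version under that assumption; the repo's helpers may differ in layout:

import torch

def list2vec(z_list):
    # Flatten each (bsz, C, H, W) tensor and concatenate along one axis.
    bsz = z_list[0].size(0)
    return torch.cat([z.reshape(bsz, -1, 1) for z in z_list], dim=1)

def vec2list(vec, cutoffs):
    # Invert list2vec using the recorded (C, H, W) of each scale.
    bsz, z_list, start = vec.size(0), [], 0
    for (c, h, w) in cutoffs:
        n = c * h * w
        z_list.append(vec[:, start:start + n].reshape(bsz, c, h, w))
        start += n
    return z_list

z1 = [torch.randn(2, 8, 4, 4), torch.randn(2, 16, 2, 2)]
cutoffs = [(e.size(1), e.size(2), e.size(3)) for e in z1]
assert all(torch.equal(a, b) for a, b in zip(z1, vec2list(list2vec(z1), cutoffs)))
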
Example 6
        def backward(ctx, grad):
            torch.cuda.empty_cache()

            # grad should have dimension (bsz x d_model x seq_len)
            bsz, d_model, seq_len = grad.size()
            grad = grad.clone()
            z1ss, uss, z0 = ctx.saved_tensors
            args = ctx.args
            threshold, train_step = args[-2:]

            func = ctx.func
            z1ss = z1ss.clone().detach().requires_grad_()
            uss = uss.clone().detach()
            z0 = z0.clone().detach()

            with torch.enable_grad():
                y = RootFind.g(func, z1ss, uss, z0, *args)

            def g(x):
                y.backward(x, retain_graph=True)  # Retain for future calls to g
                JTx = z1ss.grad.clone().detach()
                z1ss.grad.zero_()
                return JTx + grad

            eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
            dl_df_est = torch.zeros_like(grad)

            result_info = broyden(g,
                                  dl_df_est,
                                  threshold=threshold,
                                  eps=eps,
                                  name="backward")
            dl_df_est = result_info['result']

            y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

            grad_args = [None for _ in range(len(args))]
            return (None, dl_df_est, None, None, *grad_args)
Example 7
    def broyden_find_root(func, z1ss, uss, z0, eps, *args):
        bsz, d_model, seq_len = z1ss.size()
        z1ss_est = z1ss.clone().detach()
        threshold = args[-2]  # Can also set this to be different, based on training/inference
        train_step = args[-1]

        g = lambda x: RootFind.g(func, x, uss, z0, *args)
        result_info = broyden(g,
                              z1ss_est,
                              threshold=threshold,
                              eps=eps,
                              name="forward")
        z1ss_est = result_info['result']
        nstep = result_info['nstep']

        if z1ss_est.get_device() == 0 and np.random.uniform(0, 1) < 1e-4:
            msg = f"{nstep} steps in Broyden forward: diff={result_info['diff']}; eps={eps}; bsz={bsz}"
            print(colored(msg, "cyan"))

        if threshold > 100:
            torch.cuda.empty_cache()
        return z1ss_est.clone().detach()
Example 8
    def broyden_find_root(func, z1ss, uss, z0, eps, *args):
        bsz, d_model, seq_len = z1ss.size()
        z1ss_est = z1ss.clone().detach()
        threshold = args[-2]  # Can also set this to be different, based on training/inference
        train_step = args[-1]

        g = lambda x: DEQFunc.g(func, x, uss, z0, *args)
        result_info = broyden(g, z1ss_est, threshold=threshold, eps=eps)

        g_f_x = torch.zeros_like(z1ss_est)  # stays zero unless the \nabla calc below is enabled

        z1ss_est = result_info['result']
        nstep = result_info['nstep']

        # \nabla calc =================================================
        # z1ss_est_temp = z1ss.clone().detach().requires_grad_()
        # func_copy = func
        # #func_copy = copy.deepcopy(func_copy)

        # with torch.enable_grad():
        #     y = DEQFunc.f(func_copy, z1ss_est_temp, uss, z0, *args)

        # def grad_f_x(x):
        #     y.backward(x, retain_graph=True)   # Retain for future calls to g
        #     JTx = z1ss_est_temp.grad.clone().detach()
        #     z1ss_est_temp.grad.zero_()
        #     return JTx

        # g_f_x = grad_f_x(z1ss_est)
        # =============================================================

        if threshold > 100:
            torch.cuda.empty_cache()

        return z1ss_est.clone().detach(), g_f_x
Example 9
        def backward(ctx, grad):
            # grad should have dimension (bsz x d_model x seq_len)
            bsz, d_model, seq_len = grad.size()
            grad = grad.clone()
            z1, = ctx.saved_tensors
            u = ctx.u
            factor = sum(ue.nelement() for ue in u) // z1.nelement()
            cutoffs = [(elem.size(1) // factor, elem.size(2), elem.size(3)) for elem in u]
            args = ctx.args
            threshold, train_step, writer = args[-3:]

            func = ctx.func
            z1_temp = z1.clone().detach().requires_grad_()
            u_temp = [elem.clone().detach() for elem in u]
            args_temp = args[:-1]  # drop the trailing writer argument

            # Calculate dF/dx
            # with torch.enable_grad():
            #     f_x = DEQFunc2d.f_x(func, z1_temp, u_temp, cutoffs, *args)

            # def f(x):
            #     f_x.backward(x, retain_graph=True)
            #     df_dx = z1_temp.grad.clone()
            #     z1_temp.grad.zero_()
            #     return df_dx

            # # Here is your grad_f_x
            # df_dx = f(z1_temp)
            # print('df_dx_norm: {}'.format(torch.norm(df_dx)))

            # Calculate dL/df_est

            with torch.enable_grad():
                y = DEQFunc2d.g(func, z1_temp, u_temp, cutoffs, *args_temp)

            def g(x):
                y.backward(x, retain_graph=True)  # Retain for future calls to g
                #print(torch.norm(x))
                #print(torch.norm(z1_temp.grad))
                res = z1_temp.grad + grad
                z1_temp.grad.zero_()
                return res

            eps = 2e-10 * np.sqrt(bsz * seq_len * d_model)
            dl_df_est = torch.zeros_like(grad)

            result_info = broyden(g, dl_df_est, threshold=threshold, eps=eps, name="backward")
            dl_df_est = result_info['result']
            nstep = result_info['nstep']
            lowest_step = result_info['lowest_step']
            
            if dl_df_est.get_device() == 0:
                if writer is not None:
                    writer.add_scalar('backward/diff', result_info['diff'], train_step)
                    #writer.add_scalar('backward/torch.norm(df_dx)', torch.norm(df_dx), train_step)
                    writer.add_scalar('backward/nstep', result_info['nstep'], train_step)
                    writer.add_scalar('backward/lowest_step', result_info['lowest_step'], train_step)
                    writer.add_scalar('backward/final_trace', result_info['new_trace'][lowest_step], train_step)

            status = analyze_broyden(result_info, judge=True)
            if status:
                err = {"z1": z1}
                analyze_broyden(result_info, err=err, judge=False, name="backward", save_err=False)

            if threshold > 30:
                torch.cuda.empty_cache()

            # Delete graph
            y.backward(torch.zeros_like(dl_df_est), retain_graph=False)

            grad_args = [None for _ in range(len(args))]

            return (None, dl_df_est, None, *grad_args)
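
The "# Delete graph" step above is a deliberate trick: every call to g retained the autograd graph (retain_graph=True), so one final backward with a zero gradient and retain_graph=False lets autograd free the saved buffers without changing any .grad. A self-contained demonstration of the effect:

import torch

z = torch.randn(8, requires_grad=True)
y = torch.tanh(z)

for _ in range(3):                   # repeated VJPs keep the graph alive
    y.backward(torch.randn(8), retain_graph=True)
    z.grad.zero_()

y.backward(torch.zeros(8))           # zero gradient; frees the graph buffers

try:
    y.backward(torch.randn(8))       # the graph is gone now
except RuntimeError as err:
    print("graph freed:", err)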