Exemplo n.º 1
0
def solve_newton_with_dt(func, u0, args, kargs, dt,
          max_iter, abs_tol, rel_tol, verbose):
    """
    Newton iteration on func with an optional (u - u0) / dt term.

    When dt equals np.inf the residual is func(u, *args, **kargs);
    otherwise it is (u - u0) / dt + func(u, *args, **kargs).  The
    iteration stops successfully once the infinity norm of the residual
    falls below max(abs_tol, rel_tol * first_norm), and gives up when the
    norm is not finite, exceeds 1E6 times the first norm, or max_iter
    iterations have been used.

    Returns adsolution(u, residual, n): on success n is the number of
    residual evaluations performed; on failure u is reset to a fresh copy
    of u0, the residual is func evaluated there (without the dt term) and
    n is np.inf.
    """
    state = adarray(value(u0).copy())
    _DEBUG_perturb_new(state)

    for iteration in range(max_iter):
        residual = (func(state, *args, **kargs) if dt == np.inf
                    else (state - u0) / dt + func(state, *args, **kargs))
        flat_residual = residual._value.reshape(residual.size)
        norm_now = np.linalg.norm(flat_residual, np.inf)
        if verbose:
            print('    ', iteration, norm_now)

        # the first residual norm is the reference for the relative
        # tolerance and for the divergence test
        if iteration == 0:
            norm_first = norm_now
        tolerance = max(abs_tol, rel_tol * norm_first)
        if norm_now < tolerance:
            return adsolution(state, residual, iteration + 1)
        if not np.isfinite(norm_now) or norm_now > norm_first * 1E6:
            break

        # Newton update: solve jacobian * step = residual, then subtract
        jacobian = residual.diff(state).tocsr()
        if jacobian.shape[0] <= 1:
            # single unknown: plain scalar division instead of a sparse solve
            step = residual._value / jacobian.toarray()[0, 0]
        else:
            step = splinalg.spsolve(jacobian, np.ravel(residual._value),
                                    use_umfpack=False)
        state._value -= step.reshape(state.shape)
        state = adarray(state._value)  # unlink operation history if any
        _DEBUG_perturb_new(state)

    # not converged: hand back the initial guess and its residual
    state = adarray(value(u0).copy())
    residual = func(state, *args, **kargs)
    return adsolution(state, residual, np.inf)
Exemplo n.º 2
0
def solve(func, u0, args=(), kargs={},
          max_iter=10, abs_tol=1E-6, rel_tol=1E-6, verbose=True):
    """
    Solve func(u, *args, **kargs) = 0 for u by Newton's method.

    func is first passed through replace__globals__.  Starting from a
    copy of u0, each iteration evaluates the residual, differentiates it
    with respect to u and applies the Newton update.  The iteration
    succeeds once the infinity norm of the flattened residual drops below
    max(abs_tol, rel_tol * first_norm).

    kargs is only unpacked with ** and never mutated, so the shared
    default dict is harmless here.

    Returns adsolution(u, res, n) where n is the number of residual
    evaluations on success, or np.inf when the residual norm became
    non-finite or max_iter was exhausted.
    """
    u = adarray(value(u0).copy())
    _DEBUG_perturb_new(u)

    func = replace__globals__(func)
    res = None
    for i_Newton in range(max_iter):
        res = func(u, *args, **kargs)
        # Flatten before taking the norm: on a 2-D array
        # np.linalg.norm(..., np.inf) is the matrix norm (max abs row sum),
        # and it raises for 0-d or 3+-D input.
        res_norm = np.linalg.norm(np.ravel(res._value), np.inf)
        if verbose:
            print('    ', i_Newton, res_norm)
        if not np.isfinite(res_norm):
            break

        if i_Newton == 0:
            res_norm0 = res_norm
        if res_norm < max(abs_tol, rel_tol * res_norm0):
            return adsolution(u, res, i_Newton + 1)

        # Newton update
        J = res.diff(u).tocsr()
        if J.shape[0] > 1:
            minus_du = splinalg.spsolve(J, np.ravel(res._value),
                                        use_umfpack=False)
        else:
            # single unknown: plain scalar division instead of a sparse solve
            minus_du = res._value / J.toarray()[0,0]
        u._value -= minus_du.reshape(u.shape)
        u = adarray(u._value)  # unlink operation history if any
        _DEBUG_perturb_new(u)
    # not converged
    if res is None:
        # max_iter <= 0: the loop never ran, so evaluate the residual once
        # instead of failing on an unbound name
        res = func(u, *args, **kargs)
    return adsolution(u, res, np.inf)
Exemplo n.º 3
0
    def testPoisson2d(self):
        """
        Solve self.residual on an (N-1) x (M-1) grid and check derivatives.

        Checks the tangent derivatives of the solution with respect to dx
        and dy, and the adjoint derivative of sum(u) with respect to f.
        """
        N, M = 256, 64  # a 256 x 512 grid was also tried
        dx, dy = adarray([1. / N, 1. / M])

        f = ones((N-1, M-1))
        u = ones((N-1, M-1))

        u = solve(self.residual, u, (f, dx, dy))

        # solve tangent equation
        dudx = np.array(u.diff(dx).todense()).reshape(u.shape)
        dudy = np.array(u.diff(dy).todense()).reshape(u.shape)

        # expected identity: dudx * dx + dudy * dy == 2 * u
        self.assertAlmostEqual(0,
            abs(2 * u._value - (dudx * dx._value + dudy * dy._value)).max())

        # solve adjoint equation
        J = u.sum()
        dJdf = J.diff(f)

        # expected identity: d(sum u)/df equals u itself, flattened
        self.assertAlmostEqual(0, abs(np.ravel(u._value) - dJdf).max())
Exemplo n.º 4
0
def spsolve(A, b):
    """
    AD counterpart of scipy.sparse.linalg.spsolve.

    Solves the system numerically from the raw values of A (converted to
    CSR) and b, then returns adsolution(x, A * x - b, 1) so that the
    result is tied to the residual it satisfies.
    """
    csr_values = A._value.tocsr()
    solution = adarray(sp.linalg.spsolve(csr_values, b._value))
    residual = A * solution - b
    return adsolution(solution, residual, 1)
Exemplo n.º 5
0
def solve(A, b):
    '''
    AD counterpart of np.linalg.solve.

    A must be two-dimensional and b must have as many leading entries as
    A has rows.  The system is solved numerically on value(A) and
    value(b); the result is returned as adsolution(x, dot(A, x) - b, 1).
    '''
    assert A.ndim == 2
    assert b.shape[0] == A.shape[0]
    solution = adarray(np.linalg.solve(value(A), value(b)))
    residual = dot(A, solution) - b
    return adsolution(solution, residual, 1)
Exemplo n.º 6
0
def solve(func, u0, args=(), kargs={},
          max_iter=10, abs_tol=1E-6, rel_tol=1E-6, verbose=True):
    """
    Solve func(u, *args, **kargs) = 0 for u by Newton's method.

    Starting from a copy of u0, each iteration evaluates the residual,
    differentiates it with respect to u and applies the Newton update
    using a direct sparse solve.  The iteration succeeds once the
    infinity norm of the flattened residual drops below
    max(abs_tol, rel_tol * first_norm).

    When verbose is true, the iteration index and residual norm are
    printed, followed by the wall-clock time of the linear solve and of
    the whole iteration.

    kargs is only unpacked with ** and never mutated, so the shared
    default dict is harmless here.

    Returns adsolution(u, res, n) where n is the number of residual
    evaluations on success, or np.inf when the residual norm became
    non-finite or max_iter was exhausted.
    """
    u = adarray(base(u0).copy())
    _DEBUG_perturb_new(u)

    res = None
    for i_Newton in range(max_iter):
        start = time.time()
        res = func(u, *args, **kargs)  # TODO: how to put into adarray context?
        # Flatten before taking the norm: on a 2-D array
        # np.linalg.norm(..., np.inf) is the matrix norm (max abs row sum),
        # and it raises for 0-d or 3+-D input.
        res_norm = np.linalg.norm(np.ravel(res._base), np.inf)
        if verbose:
            print('    ', i_Newton, res_norm)
        if not np.isfinite(res_norm):
            break

        if i_Newton == 0:
            res_norm0 = res_norm
        if res_norm < max(abs_tol, rel_tol * res_norm0):
            return adsolution(u, res, i_Newton + 1)

        # Newton update
        J = res.diff(u).tocsr()
        start2 = time.time()
        minus_du = splinalg.spsolve(J, np.ravel(res._base))
        if verbose:
            # time spent in the linear solve; the single-argument call form
            # prints the same text on Python 2 and Python 3
            print(time.time() - start2)

        # NOTE: an ILU-preconditioned GMRES solve (splinalg.spilu +
        # splinalg.gmres) was tried here as an alternative to the direct
        # solve above.

        u._base -= minus_du.reshape(u.shape)
        u = adarray(u._base)  # unlink operation history if any
        _DEBUG_perturb_new(u)

        if verbose:
            # time spent in the whole Newton iteration
            print(time.time() - start)
    # not converged
    if res is None:
        # max_iter <= 0: the loop never ran, so evaluate the residual once
        # instead of failing on an unbound name
        res = func(u, *args, **kargs)
    return adsolution(u, res, np.inf)
Exemplo n.º 7
0
    def __mul__(self, b):
        """
        Multiply self by b.

        A one-dimensional b is multiplied directly, and two next_state
        calls are recorded on the result: one pairing self._value with b,
        and one pairing a sparse matrix built from b._value[self.j],
        self.i and the entry positions with self.data.

        For any other b, the axes after the first are flattened, each
        column is multiplied separately through this same operator, and
        the products are reassembled into b's trailing shape.
        """
        if b.ndim != 1:
            trailing_shape = b.shape[1:]
            b_columns = b.reshape([b.shape[0], -1]).T
            stacked = transpose([self * column for column in b_columns])
            return stacked.reshape((stacked.shape[0],) + trailing_shape)

        product = adarray(self._value * b._value)
        product.next_state(self._value, b, "*")

        # one column per stored entry of self, holding the matching b value
        entry_columns = np.arange(self.data.size)
        scatter = sp.csr_matrix((b._value[self.j], (self.i, entry_columns)))
        product.next_state(scatter, self.data, "*")
        return product
Exemplo n.º 8
0
    def testPoisson1d(self):
        """
        Solve self.residual on N-1 interior points and check derivatives.

        The solution is compared with 0.5 * x * (1 - x); the tangent
        derivative with respect to dx and the adjoint derivative of sum(u)
        with respect to f are then checked against closed-form expressions.
        """
        n_cells = 256
        dx = adarray(1. / n_cells)

        source = ones(n_cells - 1)
        initial_guess = zeros(n_cells - 1)
        u = solve(self.residual, initial_guess, (source, dx))

        # interior points of a uniform grid on [0, 1]
        x_interior = np.linspace(0, 1, n_cells + 1)[1:-1]
        expected = 0.5 * x_interior * (1 - x_interior)
        self.assertAlmostEqual(0, np.abs(u._value - expected).max())

        # solve tangent equation
        du_ddx = np.array(u.diff(dx).todense()).reshape(u.shape)
        tangent_error = np.abs(du_ddx - 2 * u._value / dx._value)
        self.assertAlmostEqual(0, tangent_error.max())

        # solve adjoint equation
        total = u.sum()
        dtotal_df = total.diff(source)
        self.assertAlmostEqual(0, np.abs(dtotal_df - u._value).max())