def newton_linear(P0, P1, P2, f, alpha, PSring, prec):
    """Integrate x1(T), x2(T) as power series driven by x0(T) = P0[0] + T.

    Let y_i = b_i * sqrt(f(x_i)), where the sign b_i selects the square-root
    branch on which the input point P_i lies.  Let row0, row1 be the
    polynomials whose coefficient lists are alpha.row(0), alpha.row(1).  The
    returned series satisfy (this is what the update formulas below rearrange
    to), with dx0 = dT:

        dx1/y1      + dx2/y2      = row0(x0) * dx0/y0
        x1 * dx1/y1 + x2 * dx2/y2 = row1(x0) * dx0/y0

    with initial values x1(0) = P1[0] and x2(0) = P2[0].
    NOTE(review): this looks like pulling back the differentials dx/y, x dx/y
    on the curve y^2 = f(x) -- confirm against the callers.

    Parameters:
        P0, P1, P2 -- points indexable as (x, y), all over the same base ring.
        f          -- polynomial over QQ, ZZ, or that base ring.
        alpha      -- matrix over the base ring; only rows 0 and 1 are used.
        PSring     -- power series ring (generator T) over the base ring.
        prec       -- number of terms to compute; the results are truncated
                      at O(T**prec) (at O(T) when prec <= 1).

    Returns:
        (x1, x2) -- power series in PSring.  When P1[0] == P2[0], x2 is x1.

    Raises:
        AssertionError    -- base rings disagree, or some P_i[1] is not
                             (numerically) a square root of f(P_i[0]).
        ZeroDivisionError -- P1[0] == P2[0], but the second equation is not
                             satisfied with x1 == x2.
    """
    assert P0.base_ring() is P1.base_ring()
    assert P1.base_ring() is P2.base_ring()
    assert P0.base_ring() is alpha.base_ring()
    assert P0.base_ring() is PSring.base_ring()
    assert f.base_ring() in [QQ, ZZ, P0.base_ring()]
    T = PSring.gen()

    # Constant terms x_i(0).
    x0_0 = PSring(P0[0])
    x1_0 = PSring(P1[0])
    x2_0 = PSring(P2[0])

    base = PSring.base_ring()
    exact = base.is_exact()
    # Only meaningful for inexact rings (exact rings need not have prec()).
    closetozero = None if exact else 2**(-0.8 * base.prec())

    # Branch signs: y_i == b_i * (constant term of sqrt(f(x_i(0) + T))).
    b0 = _newton_linear_branch_sign(
        P0[1], sqrt(f(x0_0 + T)).list()[0], exact, closetozero, "P0")
    b1 = _newton_linear_branch_sign(
        P1[1], sqrt(f(x1_0 + T)).list()[0], exact, closetozero, "P1")
    b2 = _newton_linear_branch_sign(
        P2[1], sqrt(f(x2_0 + T)).list()[0], exact, closetozero, "P2")

    x1 = x1_0 + O(T)
    x2 = x2_0 + O(T)

    p = 1

    # x0 = x0(0) + T, and the quantities along it that stay fixed throughout.
    Tsub = x0_0 + T
    fps = PSring(f)(Tsub)
    sqrtfps = sqrt(fps)

    row0 = PSring(alpha.row(0).list())(Tsub)
    row1 = PSring(alpha.row(1).list())(Tsub)

    if x1_0 != x2_0:
        # Picard-style iteration; each pass gains one correct term.
        # Alternatively a linear ODE solve could double the terms per pass.
        while p < prec:
            m = b0 / ((x2 - x1) * sqrtfps)
            dx1 = m * b1 * sqrt(f(x1)) * (row0 * x2 - row1)
            dx2 = m * b2 * sqrt(f(x2)) * (row1 - row0 * x1)
            p += 1
            x1 = x1_0 + dx1.integral() + O(T**p)
            x2 = x2_0 + dx2.integral() + O(T**p)
    else:
        # Degenerate case x1 == x2: the first equation becomes
        # 2 * dx1/y1 = row0(x0) * dx0/y0.
        row0_y = row0 * b0 / sqrtfps
        # R(1) / 2 rather than R(1 / 2): the latter is R(0) under Python 2
        # integer division when this file is not run through the preparser.
        half = P0.base_ring()(1) / 2
        while p < prec:
            dx1 = half * b1 * sqrt(f(x1)) * row0_y
            # Bump p before truncating (as in the branch above) so that every
            # pass gains a term and the result reaches O(T**prec).
            p += 1
            x1 = x1_0 + dx1.integral() + O(T**p)
        # With x1 == x2, the second equation must hold on its own.
        if 2 * x1 * x1.derivative() * b1 / sqrt(f(x1)) != row1 * b0 / sqrtfps:
            raise ZeroDivisionError("x1(0) = x2(0), but x1 != x2")

        x2 = x1

    return x1, x2


def _newton_linear_branch_sign(y, py, exact, closetozero, label):
    """Return b in {1, -1} such that y matches b * py; assert the match.

    y is a point's stored y-coordinate and py a computed square root of
    f(x) at that point.  For exact rings the match must be equality.  For
    inexact rings it is "much closer to b*py than to -b*py", with relative
    tolerance closetozero.  label names the point in the exact-ring message.
    """
    if exact:
        b = -1 if y != py else 1
        assert y == b * py, "wrong branch at %s? %s != %s" % (label, y,
                                                               b * py)
    else:
        b = -1 if norm(y - py) > norm(y + py) else 1
        assert norm(y + b * py) > norm(y), "%.3e vs %.3e wrong branch?" % (
            norm(y - b * py), norm(y + b * py))
        assert norm(y - b * py) < norm(y) * closetozero, (
            "%.3e vs %.3e wrong branch?" % (norm(y - b * py),
                                            norm(y + b * py)))
    return b