# Starting condition u(x, 0) used by the solver below.
def getU_0(x):
    """Return the initial profile exp(a * t) * x evaluated at t = 0.

    For finite ``a`` the exponential factor is exactly 1, so the result is
    simply ``x`` (scalar or array alike).
    """
    t_start = 0
    growth = np.exp(t_start * a)
    return growth * x


def fDeriv(u):
    """Right-hand side f(u) = a * u of the linear ODE u' = a * u."""
    growth_rate = a
    return growth_rate * u

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4))


def _error_grid(T, k, N):
    """Return broadcastable (t, x) sample grids for comparing with U_exact.

    ``t`` has shape (n_t, 1) and holds 0, k, 2k, ... with the same count that
    ``np.ogrid[0:T+k:k]`` intends (up to and including T when T is a multiple
    of k). ``x`` has shape (1, N + 1) and holds 0, 1/N, ..., 1.

    A float-step ``ogrid``/``arange`` can round the sample count up and append
    a spurious extra point (e.g. T=1, k=0.1 yields 12 samples, not 11), so the
    count is computed with a small relative tolerance instead.
    """
    ratio = (T + k) / k
    n_t = int(np.ceil(ratio - 1e-9 * max(1.0, abs(ratio))))
    t = (np.arange(n_t) * k)[:, np.newaxis]
    # linspace always hits both endpoints, unlike the old "-0.0001" fudge,
    # which dropped x = 1 once 1/N < 0.0001.
    x = np.linspace(0, 1, N + 1)[np.newaxis, :]
    return t, x


for k in kListSmall:
    diffList = []
    for N in NList:
        shapeObject = Shape(dim=dim, N=N)
        # Solve with the modified Crank-Nicolson scheme; U_exact is also
        # passed as the boundary function g.
        UList, times = ModifiedCrankNicolson.modifiedCrankNicolson(getU0=getU_0, fDeriv=fDeriv,
                                                                   scheme=schemes.makeLaplacian1D(1),
                                                                   shape=shapeObject, T=T, k=k, g=U_exact)
        t, x = _error_grid(T, k, N)
        Uex = U_exact(x, t)
        # Max-norm error over every time level and grid point (NaNs ignored).
        diffList.append(np.nanmax(np.abs(UList - Uex)))
    ax1.plot(NList, diffList, label=f"$k={k}$")

ax1.set_xlabel("$N$")
ax1.set_ylabel("Error")
ax1.legend()

for k in kListLarge:
    diffList = []
    # NOTE(review): this ``def __init__`` sits inside the ``for`` loop body and
    # reads like a class constructor pasted here by mistake. As written it only
    # rebinds a local function on each iteration and is not called anywhere in
    # the visible code.
    def __init__(self,
                 getU0_S,
                 getU0_I,
                 muS,
                 muI,
                 schemeS,
                 getBeta,
                 getGamma,
                 T=10,
                 k=1,
                 N=4,
                 isBoundaryFunction=None,
                 dim=2,
                 length=1,
                 origin=0,
                 isNeumannFunc=None,
                 schemeNeumannFunc=None,
                 g=gDefault):
        """Assemble the coupled S/I system and solve it in time.

        The spatial system is built once from ``schemeS``. The I component
        reuses the same matrix, boundary vector and geometry, and the two
        blocks are scaled by ``muS`` and ``muI`` respectively. The output of
        ModifiedCrankNicolson.modifiedCrankNicolsonSolver is stored in
        ``self.UList`` and ``self.times``.

        :param getU0_S: initial values for S; called with the mesh-grid
            coordinates (a single argument when dim == 1, unpacked otherwise).
        :param getU0_I: initial values for I; called like ``getU0_S``.
        :param muS: coefficient scaling the S block of the operator and its
            boundary contribution.
        :param muI: coefficient scaling the I block of the operator and its
            boundary contribution.
        :param schemeS: scheme passed to FiniteDifference.getSystem; shared by
            the S and I components.
        :param getBeta: passed to DiseaseModel.getDiseaseModelF.
        :param getGamma: passed to DiseaseModel.getDiseaseModelF.
        :param T: passed to the solver as ``T`` (presumably the end time).
        :param k: passed to the solver as ``k`` (presumably the time step).
        :param N: grid resolution passed to Shape.
        :param isBoundaryFunction: return true if point is on boundary.
        :param dim: spatial dimension passed to Shape.
        :param length: length of sides.
        :param origin: origin passed to Shape (for plotting).
        :param isNeumannFunc: function returning true if point has Neumann
            conditions.
        :param schemeNeumannFunc: scheme for Neumann conditions on that point.
        :param g: function returning boundary conditions; forwarded to the
            solver.
        """
        self.T, self.k = T, k
        self.shapeObject = Shape(N=N,
                                 isBoundaryFunc=isBoundaryFunction,
                                 dim=dim,
                                 length=length,
                                 origin=origin)

        # Only the S system is assembled; FinternalS is not used below.
        AmatrixS, FinternalS, FboundaryS, self.geometryS = FiniteDifference.getSystem(
            shape=self.shapeObject,
            f=fZero,
            g=gOne,
            scheme=schemeS,
            isNeumannFunc=isNeumannFunc,
            schemeNeumannFunc=schemeNeumannFunc,
            interior=True)

        # I shares S's geometry and boundary vector, scaled by its own mu.
        self.geometryI = self.geometryS
        self.Fboundary = np.concatenate((muS * FboundaryS, muI * FboundaryS))

        # Block-diagonal operator: both components use AmatrixS, scaled by
        # muS / muI, with zero off-diagonal blocks (no S-I coupling here).
        self.diffOperator = sparse.bmat(
            [[AmatrixS * muS, None], [None, AmatrixS * muI]], format="csc")

        geometrySI = np.concatenate((self.geometryS, self.geometryI), axis=1)
        domainGeometrySI = FiniteDifference.getDomainGeometry(geometrySI)
        self.domainSize = len(domainGeometrySI[0])

        def initialLattice(getU0):
            # Evaluate an initial-value function on the mesh grid and convert
            # it with Shape.getVectorLattice. In 1D the mesh is passed as one
            # argument; otherwise its coordinate arrays are unpacked.
            mesh = Shape.getMeshGrid(self.shapeObject)
            if self.shapeObject.dim == 1:
                values = getU0(mesh)
            else:
                values = getU0(*mesh)
            return Shape.getVectorLattice(values, self.shapeObject)

        # Assuming R = 0 at t = 0
        I_0 = initialLattice(getU0_I)
        S_0 = initialLattice(getU0_S)
        # Stack S above I and keep only entries flagged by either row of the
        # combined geometry.
        self.U_0 = np.concatenate(
            (S_0, I_0))[np.logical_or(geometrySI[0], geometrySI[1])]

        diseaseModelF = DiseaseModel.getDiseaseModelF(getBeta, getGamma,
                                                      self.shapeObject)

        self.UList, self.times = ModifiedCrankNicolson.modifiedCrankNicolsonSolver(
            self.U_0,
            f=diseaseModelF,
            T=T,
            k=k,
            diffOperator=self.diffOperator,
            boundaryCoeffs=-self.Fboundary,
            geometry=geometrySI,
            # Forward the caller's boundary function; previously the
            # hard-coded gDefault silently ignored the ``g`` parameter.
            g=g)