def modifiedCrankNicolson(getU0=None,
                          fDeriv=fZero,
                          scheme=None,
                          T=10,
                          k=1,
                          shape=None,
                          isNeumannFunc=None,
                          schemeNeumannFunc=None,
                          g=gDefault):
        """Discretize ``scheme`` on ``shape`` and time-step it with the modified Crank-Nicolson solver.

        The spatial system is built by ``FiniteDifference.getSystem`` (interior
        unknowns only). The initial state and the time-dependent boundary data
        are sampled on ``Shape.getMeshGrid(shape)`` and converted with
        ``Shape.getVectorLattice``. Everything is then passed to
        ``ModifiedCrankNicolson.modifiedCrankNicolsonSolver``.

        :param getU0: callable giving the initial values on the mesh. It receives
            the mesh grid as a single argument when ``shape.dim == 1`` and the
            unpacked coordinate arrays otherwise. Required despite the ``None``
            default.
        :param fDeriv: forwarded unchanged to the solver as its second argument.
        :param scheme: scheme passed to ``FiniteDifference.getSystem``.
        :param T: final time, forwarded to the solver.
        :param k: time step, forwarded to the solver.
        :param shape: ``Shape`` describing the grid. Required despite the
            ``None`` default.
        :param isNeumannFunc: forwarded to ``FiniteDifference.getSystem``.
        :param schemeNeumannFunc: forwarded to ``FiniteDifference.getSystem``.
        :param g: boundary data, called like ``getU0`` with the time ``t``
            appended as the last positional argument.
        :return: whatever the solver returns (callers unpack it as
            ``UList, times``).
        """
        # The system is assembled with g=gOne; the real time-dependent boundary
        # values reach the solver through getBoundaryVals below (presumably the
        # solver combines them with the boundary coefficients each step - TODO
        # confirm). The interior right-hand side is not needed here.
        diffMatrix, _interiorTerms, boundaryTerms, geometry = FiniteDifference.getSystem(
            shape=shape,
            f=fZero,
            g=gOne,
            scheme=scheme,
            isNeumannFunc=isNeumannFunc,
            schemeNeumannFunc=schemeNeumannFunc,
            interior=True)

        # 1D callables take the mesh grid itself; higher dimensions take the
        # coordinate arrays as separate positional arguments.
        unpackMesh = shape.dim != 1

        def sampleOnMesh(func, *trailing):
            # A fresh mesh grid is requested on every call.
            mesh = Shape.getMeshGrid(shape)
            if unpackMesh:
                values = func(*mesh, *trailing)
            else:
                values = func(mesh, *trailing)
            return Shape.getVectorLattice(values, shape)

        def getBoundaryVals(t):
            return sampleOnMesh(g, t)

        initialState = sampleOnMesh(getU0)

        return ModifiedCrankNicolson.modifiedCrankNicolsonSolver(
            initialState,
            fDeriv,
            T=T,
            k=k,
            diffOperator=diffMatrix,
            boundaryCoeffs=-boundaryTerms,
            g=getBoundaryVals,
            geometry=geometry)
    def applyDiseaseAnimation(self, axS, axI, animationLength=10, **kvargs):
        """Animate the S and I fields side by side over the stored time steps.

        Rows of ``self.UList`` are sampled with a fixed stride so that about
        ``animationLength * 24`` frames are produced. The actual count depends on
        how the number of stored steps divides. The animation is played at 24 fps.
        A single colorbar, created from the first I frame, is attached to both
        axes.

        :param axS: axes for the susceptible (first half of the state) field.
        :param axI: axes for the infected (second half of the state) field; its
            figure hosts the animation and the colorbar.
        :param animationLength: target animation length in seconds.
        :param kvargs: extra keyword arguments forwarded to Shape.plotImage2d
            for every frame.
        :return: ``animation.ArtistAnimation`` over the collected frames.
        """
        fps = 24
        # Clamp to at least one frame and force an int, so a float or zero
        # animationLength cannot produce a float or zero divisor.
        numFrames = max(1, int(animationLength * fps))
        numTimes = self.UList.shape[0]
        # With fewer stored steps than requested frames the floor division is 0,
        # which range() rejects; fall back to drawing every step.
        step = max(1, numTimes // numFrames)

        half = self.domainSize // 2
        artistlist = []
        for timeIndex in range(0, len(self.times), step):
            S = self.UList[timeIndex, :half]
            I = self.UList[timeIndex, half:]

            artistS = Shape.plotImage2d(S,
                                        self.shapeObject,
                                        ax=axS,
                                        geometry=self.geometryS,
                                        title="Susceptible",
                                        animated=True,
                                        colorbar=False,
                                        **kvargs)

            artistI = Shape.plotImage2d(I,
                                        self.shapeObject,
                                        ax=axI,
                                        geometry=self.geometryI,
                                        title="Infected",
                                        animated=True,
                                        colorbar=False,
                                        **kvargs)

            artistlist.append([artistS, artistI])
            if timeIndex == 0:
                # One shared colorbar, built from the first I frame.
                # NOTE(review): no common vmin/vmax is passed to plotImage2d, so
                # the colorbar presumably reflects only that frame's scale.
                axI.figure.colorbar(artistI, ax=[axS, axI])

        return animation.ArtistAnimation(axI.figure,
                                         artistlist,
                                         blit=True,
                                         interval=1000 / fps)
    def plotImage(self,
                  timeIndex=0,
                  ax=None,
                  group="I",
                  show=False,
                  animated=False,
                  **kwargs):
        """Draw one population at a single stored time step with Shape.plotImage2d.

        :param timeIndex: row of ``self.UList`` to draw (converted with ``int``).
        :param ax: axes forwarded to Shape.plotImage2d.
        :param group: ``"S"`` selects the first half of the state vector,
            ``"I"`` the second half.
        :param show: forwarded to Shape.plotImage2d.
        :param animated: forwarded to Shape.plotImage2d.
        :param kwargs: extra keyword arguments for Shape.plotImage2d.
        :return: the result of Shape.plotImage2d, or ``None`` (after printing a
            message) when ``group`` is not recognised.
        """
        if group == "S":
            U = self.UList[int(timeIndex), :self.domainSize // 2]
            geometry = self.geometryS
        elif group == "I":
            U = self.UList[int(timeIndex), self.domainSize // 2:]
            # The constructor makes geometryI the same object as geometryS; the
            # I-specific attribute is used for symmetry with applyDiseaseAnimation.
            geometry = self.geometryI
        else:
            print(f"Group: {group} not found")
            return None
        return Shape.plotImage2d(U,
                                 self.shapeObject,
                                 ax=ax,
                                 show=show,
                                 geometry=geometry,
                                 animated=animated,
                                 **kwargs)
 def plot(self,
          timeIndex=0,
          ax=None,
          group="I",
          show=False,
          view=None,
          **kwargs):
     """Draw one population at a single stored time step with Shape.plotOnShape.

     :param timeIndex: row of ``self.UList`` to draw (converted with ``int``);
         the title shows ``timeIndex * self.k`` as the time.
     :param ax: axes forwarded to Shape.plotOnShape.
     :param group: ``"S"`` selects the first half of the state vector,
         ``"I"`` the second half.
     :param show: forwarded to Shape.plotOnShape.
     :param view: forwarded to Shape.plotOnShape.
     :param kwargs: extra keyword arguments for Shape.plotOnShape.
     :return: the result of Shape.plotOnShape, or ``None`` (after printing a
         message) when ``group`` is not recognised.
     """
     if group not in ("S", "I"):
         print(f"Group: {group} not found")
         return None

     row = int(timeIndex)
     half = self.domainSize // 2
     if group == "S":
         values = self.UList[row, :half]
         heading = f"Susceptible at t = {timeIndex * self.k}"
     else:
         values = self.UList[row, half:]
         heading = f"Infected at t = {timeIndex * self.k}"

     return Shape.plotOnShape(values,
                              self.shapeObject,
                              ax=ax,
                              show=show,
                              view=view,
                              geometry=self.geometryS,
                              title=heading,
                              **kwargs)
    def getDiseaseModelF(getBeta, getGamma, shape):
        """Return the SI reaction term F(U) for the stacked state U = [S, I].

        ``getBeta`` and ``getGamma`` are each evaluated on
        ``Shape.getMeshGrid(shape)``. They receive the grid itself when
        ``shape.dim == 1`` and the unpacked coordinate arrays otherwise. A
        size-1 result is used directly as a scalar; anything else is converted
        with ``Shape.getVectorLattice``.

        The returned closure splits U into two equal halves S and I and returns
        ``concatenate((-beta*S*I, beta*S*I - gamma*I))``.

        NOTE(review): if the solver state has been masked to the domain geometry
        (as the constructor that calls this does for U_0), a spatially varying
        beta/gamma from getVectorLattice may not match the length of S and I -
        confirm the two sizes agree.
        """
        def evaluateOnMesh(func):
            # A fresh mesh grid per coefficient, matching the original calls.
            mesh = Shape.getMeshGrid(shape)
            return func(mesh) if shape.dim == 1 else func(*mesh)

        def toCoefficient(values):
            # Size-1 results stay as-is so they broadcast like scalars.
            if np.size(values) == 1:
                return values
            return Shape.getVectorLattice(values, shape)

        rawBeta = evaluateOnMesh(getBeta)
        rawGamma = evaluateOnMesh(getGamma)
        beta = toCoefficient(rawBeta)
        gamma = toCoefficient(rawGamma)

        def F(U):
            S, I = np.split(U, 2)
            transfer = beta * S * I
            return np.concatenate((-transfer, transfer - gamma * I))

        return F
# Initial condition u(x, 0) for the linear test problem solved below.
def getU_0(x):
    """Return exp(0 * a) * x, i.e. x scaled by 1.0 for finite ``a``.

    ``a`` is read from module scope.
    """
    startFactor = np.exp(0 * a)
    return startFactor * (x)


def fDeriv(u):
    """Reaction term of the test problem: ``a * u``, with ``a`` read from module scope."""
    scaled = a * u
    return scaled

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4))


# For every small time step k, sweep the grid size N and record the largest
# absolute deviation (NaNs ignored) between the numerical solution and U_exact.
for k in kListSmall:
    diffList = []
    for N in NList:
        shapeObject = Shape(dim=dim, N=N)
        # Solve with the modified Crank-Nicolson wrapper; U_exact supplies the
        # boundary data and isNeumannFunc is left at its default (None).
        UList, times = ModifiedCrankNicolson.modifiedCrankNicolson(getU0=getU_0, fDeriv=fDeriv,
                                                                   scheme=schemes.makeLaplacian1D(1),
                                                                   shape=shapeObject, T=T, k=k, g=U_exact)
        # Evaluate the exact solution on the solver's own time levels and on
        # N + 1 evenly spaced points in [0, 1]. Float-step ranges (np.ogrid /
        # np.arange with step k or 1/N) can gain or lose an endpoint through
        # rounding, which breaks the subtraction below.
        # NOTE(review): assumes Shape(dim=1, N=N) has N + 1 points per row, as
        # the previous ogrid-based grid did.
        t = np.asarray(times, dtype=float).reshape(-1, 1)
        x = np.linspace(0, 1, N + 1).reshape(1, -1)
        Uex = U_exact(x, t)
        diffList.append(np.nanmax(np.abs(UList - Uex)))
    ax1.plot(NList, diffList, label=f"$k={k}$")

ax1.set_xlabel("$N$")
ax1.set_ylabel("Error")
ax1.legend()

# Second sweep, over the time steps in kListLarge.
# NOTE(review): the body of this loop continues beyond this view; presumably it
# mirrors the kListSmall error study and plots on ax2 - TODO confirm.
for k in kListLarge:
    diffList = []
 def getBoundaryVals(t):
     """Sample the boundary function ``g`` on the mesh of ``shape`` at time ``t``.

     ``g`` receives the unpacked mesh-grid coordinates followed by ``t``. The
     result is converted with ``Shape.getVectorLattice``. ``g`` and ``shape``
     come from the enclosing scope.
     """
     coords = Shape.getMeshGrid(shape)
     return Shape.getVectorLattice(g(*coords, t), shape)
    def __init__(self,
                 getU0_S,
                 getU0_I,
                 muS,
                 muI,
                 schemeS,
                 getBeta,
                 getGamma,
                 T=10,
                 k=1,
                 N=4,
                 isBoundaryFunction=None,
                 dim=2,
                 length=1,
                 origin=0,
                 isNeumannFunc=None,
                 schemeNeumannFunc=None,
                 g=gDefault):
        """
        Build the coupled S/I finite-difference system and integrate it in time.

        One call to FiniteDifference.getSystem with ``schemeS`` provides the
        spatial operator. The S and I copies are scaled by ``muS`` and ``muI``
        and placed on the diagonal of a sparse block matrix. The system is then
        integrated with ModifiedCrankNicolson.modifiedCrankNicolsonSolver, and
        the result is stored on the instance.

        :param getU0_S: function evaluated on Shape.getMeshGrid(...) (the grid
            itself when dim == 1, unpacked coordinates otherwise) giving the
            initial S values.
        :param getU0_I: same as getU0_S, for the initial I values.
        :param muS: factor multiplying the spatial operator and boundary terms
            of the S component (presumably a diffusion coefficient).
        :param muI: same as muS, for the I component.
        :param schemeS: scheme passed to FiniteDifference.getSystem; the
            resulting operator is used for both S and I.
        :param getBeta: coefficient function beta, passed to
            DiseaseModel.getDiseaseModelF.
        :param getGamma: coefficient function gamma, passed to
            DiseaseModel.getDiseaseModelF.
        :param T: final time, forwarded to the solver.
        :param k: time step, forwarded to the solver.
        :param N: forwarded to Shape (presumably the grid resolution).
        :param isBoundaryFunction: forwarded to Shape as isBoundaryFunc; returns
            true if a point is on the boundary.
        :param dim: spatial dimension, forwarded to Shape.
        :param length: length of sides, forwarded to Shape.
        :param origin: forwarded to Shape (for plotting).
        :param isNeumannFunc: function returning true if a point has Neumann
            conditions.
        :param schemeNeumannFunc: scheme for Neumann conditions on that point.
        :param g: boundary-condition function. NOTE(review): currently unused;
            the solver call below always receives ``gDefault``. Confirm whether
            ``g=g`` was intended.

        Sets T, k, shapeObject, geometryS, geometryI, Fboundary, diffOperator,
        domainSize, U_0, UList and times.
        """
        self.T, self.k = T, k
        self.shapeObject = Shape(N=N,
                                 isBoundaryFunc=isBoundaryFunction,
                                 dim=dim,
                                 length=length,
                                 origin=origin)

        # Only the operator, the boundary coefficients and the geometry are
        # used below; FinternalS is not needed.
        AmatrixS, FinternalS, FboundaryS, self.geometryS = FiniteDifference.getSystem(
            shape=self.shapeObject,
            f=fZero,
            g=gOne,
            scheme=schemeS,
            isNeumannFunc=isNeumannFunc,
            schemeNeumannFunc=schemeNeumannFunc,
            interior=True)

        # S and I live on the same grid, so they share one geometry object.
        self.geometryI = self.geometryS
        # Boundary coefficients for the stacked [S, I] vector, scaled per component.
        self.Fboundary = np.concatenate((muS * FboundaryS, muI * FboundaryS))

        # Block-diagonal operator: S and I are coupled only through the
        # reaction term diseaseModelF, not through the spatial operator.
        self.diffOperator = sparse.bmat(
            [[AmatrixS * muS, None], [None, AmatrixS * muI]], format="csc")

        geometrySI = np.concatenate((self.geometryS, self.geometryI), axis=1)
        domainGeometrySI = FiniteDifference.getDomainGeometry(geometrySI)
        # The plotting methods split each UList row at domainSize // 2 into S and I.
        self.domainSize = len(domainGeometrySI[0])

        # Assuming R = 0 at t = 0
        if self.shapeObject.dim == 1:
            I_0 = Shape.getVectorLattice(
                getU0_I(Shape.getMeshGrid(self.shapeObject)), self.shapeObject)
            S_0 = Shape.getVectorLattice(
                getU0_S(Shape.getMeshGrid(self.shapeObject)), self.shapeObject)
        else:
            I_0 = Shape.getVectorLattice(
                getU0_I(*Shape.getMeshGrid(self.shapeObject)),
                self.shapeObject)
            S_0 = Shape.getVectorLattice(
                getU0_S(*Shape.getMeshGrid(self.shapeObject)),
                self.shapeObject)
        # Stack S and I, then keep only entries where geometrySI[0] or
        # geometrySI[1] is set (presumably the solver's unknowns - TODO confirm).
        self.U_0 = np.concatenate(
            (S_0, I_0))[np.logical_or(geometrySI[0], geometrySI[1])]

        diseaseModelF = DiseaseModel.getDiseaseModelF(getBeta, getGamma,
                                                      self.shapeObject)

        # NOTE(review): g=gDefault ignores the constructor's g argument.
        self.UList, self.times = ModifiedCrankNicolson.modifiedCrankNicolsonSolver(
            self.U_0,
            f=diseaseModelF,
            T=T,
            k=k,
            diffOperator=self.diffOperator,
            boundaryCoeffs=-self.Fboundary,
            geometry=geometrySI,
            g=gDefault)