def NetInputMap(netName,cellName,plane=None,save=False):
    """Plot the norm of the standardised network input on one grid plane.

    Uses the precalculated fingerprints and cell data for ``cellName``. The
    fingerprints on the selected z-plane are centred with the network's
    stored ``datamean`` and scaled by ``datastd``. Each (x, y) grid point is
    then coloured by the Euclidean norm of its standardised fingerprint.

    Args:
        netName: name of the saved network to load.
        cellName: file name of the precalculated cell/fingerprint data.
        plane: z coordinate to plot; the closest grid plane is used. When
            None, the plane of the second grid point (``grid[1, 2]``) is used.
        save: if True, also save the figure as ``<netName>-<cellName>_In.pdf``.
    """
    fp = fingerprints(lmax=4,nmax=5,r_c=4.5)

    N = NetworkHandler(fp,name=netName)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()
    C.get_cell_data(cellName)

    grid = C.supercell.grid

    if plane is None:
        plane = grid[1,2]
    else:
        # snap to the nearest existing grid plane (absolute distance)
        plane = grid[np.argmin(np.abs(grid[:,2]-plane)),2]

    # exact float comparison is fine: `plane` is a value taken from the grid
    plane_idx = np.where(grid[:,2]==plane)[0]
    plane_FPs = C.supercell.FP[plane_idx,:]

    mean = N.datamean
    std = N.datastd

    # standardise every fingerprint component; components with (near) zero
    # spread are only centred so we never divide by ~0
    safe_std = np.where(std>1e-5,std,1.0)
    plane_FPs = (plane_FPs-mean)/safe_std[None,:]

    x = grid[plane_idx,0]
    y = grid[plane_idx,1]
    x,y = shiftxy(x,y)

    z = np.linalg.norm(plane_FPs,axis=1)

    sc = plt.scatter(x,y,c=z,cmap='RdYlBu',marker=',',s=s_,alpha=1)
    plt.colorbar(sc)
    if save:
        plt.savefig("{}-{}_In.pdf".format(netName,cellName))
    plt.show()
    return
def std_2d(netName,cellName,plane=None,rel_std=True,rel_cap=2.0,save=False):
    """Plot the ensemble uncertainty of the network on one grid plane.

    Args:
        netName: name of the saved network to load.
        cellName: file name of the precalculated cell/fingerprint data.
        plane: z coordinate to plot; the closest grid plane is used. When
            None, the plane of the second grid point (``grid[1, 2]``) is used.
        rel_std: if True, plot ``|std| / |mean|`` instead of the raw ensemble
            std returned by ``ensemble_predict``.
        rel_cap: upper clip applied to the relative std (ignored when None
            or when ``rel_std`` is False).
        save: if True, also save the figure as ``<netName>-<cellName>_Std.pdf``.
    """
    fp = fingerprints(lmax=4,nmax=5,r_c=4.5)

    N = NetworkHandler(fp,name=netName)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()
    C.get_cell_data(cellName)

    grid = C.supercell.grid

    if plane is None:
        plane = grid[1,2]
    else:
        # snap to the nearest existing grid plane (absolute distance)
        plane = grid[np.argmin(np.abs(grid[:,2]-plane)),2]

    plane_idx = np.where(grid[:,2]==plane)[0]
    plane_FPs = C.supercell.FP[plane_idx,:]

    x = grid[plane_idx,0]
    y = grid[plane_idx,1]
    x,y = shiftxy(x,y)

    ens_mean, z = C.ensemble_predict(plane_FPs)
    if rel_std:
        # points where the mean vanishes give inf (clipped by rel_cap when
        # it is set) or nan; silence the expected numpy warnings
        with np.errstate(divide='ignore',invalid='ignore'):
            z = np.abs(z)/np.abs(ens_mean)
        if rel_cap is not None:
            z = np.where(z>rel_cap,rel_cap,z)

    sc = plt.scatter(x,y,c=z,cmap='RdYlBu',marker=',',s=s_,alpha=1)
    plt.colorbar(sc)
    if save:
        plt.savefig("{}-{}_Std.pdf".format(netName,cellName))
    plt.show()

    return
def NetOutputMap(netName,cellName,plane=None,save=False):
    """Plot the ensemble-mean network prediction on one grid plane.

    Args:
        netName: name of the saved network to load.
        cellName: file name of the precalculated cell/fingerprint data.
        plane: z coordinate to plot; the closest grid plane is used. When
            None, the plane of the second grid point (``grid[1, 2]``) is used.
        save: if True, also save the figure as ``<netName>-<cellName>_Out.pdf``.
    """
    # NOTE(review): descriptor settings differ from NetInputMap/std_2d
    # (lmax=3, nmax=3, r_c=5.5) -- confirm they match the network being loaded
    fp = fingerprints(lmax=3,nmax=3,r_c=5.5)

    N = NetworkHandler(fp,name=netName)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()
    C.get_cell_data(cellName)

    grid = C.supercell.grid

    if plane is None:
        plane = grid[1,2]
    else:
        # snap to the nearest existing grid plane (absolute distance)
        plane = grid[np.argmin(np.abs(grid[:,2]-plane)),2]

    plane_idx = np.where(grid[:,2]==plane)[0]
    plane_FPs = C.supercell.FP[plane_idx,:]

    x = grid[plane_idx,0]
    y = grid[plane_idx,1]
    x,y = shiftxy(x,y)

    # only the plane is predicted; the ensemble std is not used here
    z, _ = C.ensemble_predict(plane_FPs)

    sc = plt.scatter(x,y,c=z,cmap='RdYlBu',marker=',',s=s_,alpha=1)
    plt.colorbar(sc)
    if save:
        plt.savefig("{}-{}_Out.pdf".format(netName,cellName))
    plt.show()
    return
def ErrorStdPlot(netname,cellname,save=False):
    """Scatter the absolute ensemble error against the ensemble std.

    Predicts every grid point of ``cellname`` and compares the ensemble mean
    with the reference ``"density"`` stored in ``FP_data/<cellname>``.

    Args:
        netname: name of the saved network to load.
        cellname: file name of the precalculated cell/fingerprint data.
        save: if True, also save the figure as
            ``<netname>-<cellname>_ErrStd.pdf``.
    """
    fp = fingerprints(lmax=3,nmax=3,r_c=5.5)

    N = NetworkHandler(fp,name=netname)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()
    C.get_cell_data(cellname)

    z, std_ = C.ensemble_predict(C.supercell.FP)

    # NOTE(review): pickle.load runs arbitrary code -- only use trusted files
    filename = "{}/{}".format("FP_data",cellname)
    with open(filename,'rb') as f:
        fp_data = pickle.load(f)
    z_ = fp_data["density"]
    print(np.max(z_))

    err = np.abs(z-z_)

    plt.plot(std_,err,'b.')
    plt.xlabel("Ensemble Predicted Variance")
    plt.ylabel("Absolue Ensemble Error")
    if save:
        plt.savefig("{}-{}_ErrStd.pdf".format(netname,cellname))
    plt.show()
    return
def NetLossPlot(netName,save=False):
    """Plot the per-network test RMSE recorded during training.

    Args:
        netName: name of the saved network to load.
        save: if True, also save the RMSE figure as ``NetRMSE.pdf``.
    """
    fp = fingerprints()
    N = NetworkHandler(fp,name=netName)
    N.load()

    # training-loss curves: drawn and then closed without being shown
    # (the show() call was disabled by the author)
    x = np.linspace(0.0,N.nEpochs,len(N.loss[0]))
    for loss in N.loss:
        plt.plot(x,loss)
    plt.close()

    # test RMSE of each network, spread evenly over the training epochs
    x_ = np.linspace(0.0,N.nEpochs,len(N.test_rmse[0]))
    leg = []
    for i, rmse in enumerate(N.test_rmse):
        print(len(rmse))
        plt.plot(x_,rmse,colour[i])
        leg.append("Net_{}".format(i))
    plt.legend(leg)
    if save:
        plt.savefig("NetRMSE.pdf")
    plt.show()
    plt.close()
    return
def NetDensityWrite(netName,FP_dir,Cell_dir,files=None):
    """Write network-predicted densities for selected cells to disk.

    For every file name containing ``"0.5"`` (taken from ``files`` or, when
    None, from ``os.listdir(FP_dir)``), the cell/fingerprint data are loaded
    from ``Cell_dir``/``FP_dir``. The untapered ensemble prediction is then
    written via ``Castep_density.setCellDensities`` to
    ``<netName>_data/<file without last 4 chars>initial_den``.

    Args:
        netName: name of the saved network to load.
        FP_dir: directory holding the fingerprint pickles.
        Cell_dir: directory holding the cell pickles.
        files: optional explicit list of file names to consider.
    """
    if(not os.path.isdir(FP_dir) or not os.path.isdir(Cell_dir)):
        print("invalid fingerrpint or cell directory")
        return

    out_dir = "{}_data".format(netName)
    if not os.path.isdir(out_dir):
        os.mkdir(out_dir)

    fp = fingerprints(lmax=4,nmax=5,r_c=4.5)

    N = NetworkHandler(fp,name=netName)
    N.load()

    C = Castep_density(fp,N)
    C.setupNetwork()

    # get_cell_data builds paths as "<dir><file>", so make sure both
    # directories end with a path separator
    fp_dir = os.path.join(FP_dir,"")
    cell_dir = os.path.join(Cell_dir,"")

    if files is None:
        files = os.listdir(FP_dir)
    for fname in files:
        if "0.5" not in fname:
            continue
        den_file = "{}/{}{}".format(out_dir,fname[:-4],"initial_den")
        print("Writing density to: ", den_file)
        # reset so a failed load cannot silently write the previous cell
        C.supercell = None
        C.get_cell_data(fname,cell_dir=cell_dir,fp_dir=fp_dir,include_density=True)
        C.setCellDensities(den_file,taper=False)

    return
def results_along_a_line(netName,cellName,pltline=None,ensembleplot=True,wDensity=True,save=False):
    """Plot network predictions along a line of grid points parallel to x.

    The line consists of the grid points with ``y == pltline[0]`` and
    ``z == pltline[1]``.

    Args:
        netName: name of the saved network to load.
        cellName: file name of the precalculated cell/fingerprint data.
        pltline: (y, z) of the line. When None, ``(grid[1, 0], grid[2, 0])``
            is used.
        ensembleplot: if True, plot the ensemble mean with std error bars;
            otherwise plot every ensemble member separately.
        wDensity: if True, overlay the stored reference density.
        save: if True, also save the figure as ``<netName>-<cellName>_Line.pdf``.
    """
    fp = fingerprints(lmax=4,nmax=5,r_c=4.5)

    N = NetworkHandler(fp,name=netName)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()

    C.get_cell_data(cellName,include_density=True)
    print(np.mean(C.supercell.train_density))

    grid = C.supercell.grid

    if pltline is None:
        # NOTE(review): uses the x coordinates of grid points 1 and 2 as the
        # line's (y, z); only meaningful if every axis shares the same grid
        # values (the author's note: "assumes spherical cell")
        pltline = np.zeros(2)
        pltline[0] = grid[1,0]
        pltline[1] = grid[2,0]

    # select on y first, then on z *within that subset*, so that line_idx
    # stays a list of indices into the full grid
    on_y = np.where(grid[:,1]==pltline[0])[0]
    line_idx = on_y[grid[on_y,2]==pltline[1]]

    line_FPs = C.supercell.FP[line_idx,:]
    x = grid[line_idx,0]

    if wDensity:
        # reference density stored with the fingerprint data
        density = C.supercell.train_density[line_idx]

    if ensembleplot:
        ensemble_mean,ensemble_std = C.ensemble_predict(line_FPs)
        plt.errorbar(x,ensemble_mean,yerr=ensemble_std,fmt='o',label="Ensemble Prediction")
        if wDensity:
            plt.plot(x[:-1],density[:-1],'r--',label="Density Correction")
            plt.legend()
    else:
        mean,std = C.NetHandler.predict(line_FPs,ensemble=False,standard=False)
        print(mean.shape)
        # one column per ensemble member
        for i in range(mean.shape[1]):
            plt.plot(x[:-1],mean[:-1,i],colour[i],label="Net_{}".format(i))
        if wDensity:
            plt.plot(x[:-1],density[:-1],'r--',label="Density Correction")
        plt.legend()
    plt.xlabel(r'x/$\AA$')
    plt.ylabel(r"Density Correction e/$\AA ^3$")
    if save:
        plt.savefig("{}-{}_Line.pdf".format(netName,cellName))
    plt.show()
    plt.close()
    return
def VarWithR(netName,FP_dir,Cell_dir):
    """Plot ensemble entropy H and prediction error against atom distance.

    Each file in ``FP_dir`` yields four quantities:
        r: the distance between the first two atoms.
        rmse, mae: errors of the ensemble mean versus the stored density.
        H: ``Castep_density.getH`` of the ensemble std.

    Results are cached as ``<netName>_data/{r,rmse,H,mae}.npy`` and reloaded
    when all four files exist. Points are split into train/test sets by
    matching r (within 1e-5) against the module-level
    ``trainsets[netName]``. Three figures are shown: H vs r, MAE vs r and
    H vs MAE.

    Args:
        netName: name of the saved network to load.
        FP_dir: directory holding the fingerprint pickles.
        Cell_dir: directory holding the cell pickles.
    """
    if(not os.path.isdir(FP_dir) or not os.path.isdir(Cell_dir)):
        print("invalid fingerrpint or cell directory")
        return

    out_dir = "{}_data".format(netName)
    paths = {key: "{}/{}.npy".format(out_dir,key) for key in ("r","rmse","H","mae")}

    if all(os.path.isfile(path) for path in paths.values()):
        r = np.load(paths["r"])
        rmse = np.load(paths["rmse"])
        H = np.load(paths["H"])
        mae = np.load(paths["mae"])
    else:
        if not os.path.isdir(out_dir):
            os.mkdir(out_dir)

        fp = fingerprints(lmax=4,nmax=5,r_c=4.5)

        N = NetworkHandler(fp,name=netName)
        N.load()

        C = Castep_density(fp,N)
        C.setupNetwork()

        # get_cell_data builds paths as "<dir><file>", so make sure both
        # directories end with a path separator
        fp_dir = os.path.join(FP_dir,"")
        cell_dir = os.path.join(Cell_dir,"")

        H, rmse, r, mae = [], [], [], []
        for fname in os.listdir(FP_dir):
            # reset so a failed load cannot silently reuse the previous cell
            C.supercell = None
            C.get_cell_data(fname,cell_dir=cell_dir,fp_dir=fp_dir,include_density=True)
            if C.supercell is None:
                print("skipping", fname)
                continue
            mean,std = C.ensemble_predict()
            density = C.supercell.train_density
            coords = C.supercell.cart_coords
            r.append(np.linalg.norm(coords[0,:]-coords[1,:]))
            if len(density)==len(mean):
                err = density-mean
                rmse.append(np.sqrt(np.mean(np.square(err))))
                mae.append(np.mean(np.abs(err)))
            else:
                print("Incompatible mean and density")
                # NaN placeholders keep rmse/mae aligned with r and H
                rmse.append(np.nan)
                mae.append(np.nan)
            H.append(C.getH(std))

        H = np.asarray(H)
        rmse = np.asarray(rmse)
        r = np.asarray(r)
        mae = np.asarray(mae)

        np.save(paths["r"],r)
        np.save(paths["rmse"],rmse)
        np.save(paths["H"],H)
        np.save(paths["mae"],mae)

    # a point belongs to the training set if its distance matches one of
    # the training distances recorded for this network
    train_r_values = np.asarray(trainsets[netName])
    is_train = np.array([np.any(np.abs(train_r_values-r_i)<1e-5) for r_i in r],dtype=bool)
    is_test = ~is_train

    test_H = H[is_test]
    train_H = H[is_train]
    test_r = r[is_test]
    train_r = r[is_train]
    train_mae = mae[is_train]
    test_mae = mae[is_test]

    plt.plot(train_r,train_H,'bx')
    plt.plot(test_r,test_H,'rx')
    plt.legend(["Train set","Test set"])
    plt.xlabel(r"Interatomic Distance ($\AA$)")
    plt.ylabel("H")
    plt.show()
    plt.close()

    plt.plot(train_r,train_mae,'bx')
    plt.plot(test_r,test_mae,'rx')
    plt.legend(["Train set","Test set"])
    plt.title("Ensemble 2")
    plt.xlabel(r"Interatomic Distance ($\AA$)")
    plt.ylabel(r"Mean Absolue Error (e/$\AA^3$)")
    plt.show()
    plt.close()

    plt.plot(train_mae,train_H,'bx')
    plt.plot(test_mae,test_H,'rx')
    plt.legend(["Train set","Test set"])
    plt.ylabel("H")
    plt.xlabel(r"Mean Absolute Error (e/$\AA^3$)")
    plt.show()
    plt.close()

    return
def NetErrorMap(netname,cellname,plane=None,relative=True,save=False):
    """Plot the ensemble prediction error on one grid plane.

    The reference density comes from ``FP_data/<cellname>`` (key
    ``"density"``).

    Args:
        netname: name of the saved network to load.
        cellname: file name of the precalculated cell/fingerprint data.
        plane: z coordinate to plot; the closest grid plane is used. When
            None, the plane of the second grid point (``grid[1, 2]``) is used.
        relative: if True, plot ``|error / ensemble std|`` with the colour
            scale capped at 3; otherwise plot ``|error|``.
        save: if True, also save the figure as ``<netname>-<cellname>_Err.pdf``.
    """
    fp = fingerprints(lmax=3,nmax=3,r_c=5.5)

    N = NetworkHandler(fp,name=netname)
    N.load()

    # Castep_density is only used because it already knows how to load the
    # cell/fingerprint data into a supercell
    C = Castep_density(fp,N)
    C.setupNetwork()
    C.get_cell_data(cellname)

    grid = C.supercell.grid

    if plane is None:
        plane = grid[1,2]
    else:
        # snap to the nearest existing grid plane (absolute distance)
        plane = grid[np.argmin(np.abs(grid[:,2]-plane)),2]

    plane_idx = np.where(grid[:,2]==plane)[0]
    plane_FPs = C.supercell.FP[plane_idx,:]

    x = grid[plane_idx,0]
    y = grid[plane_idx,1]
    x,y = shiftxy(x,y)

    z, std_ = C.ensemble_predict(plane_FPs)

    # NOTE(review): pickle.load runs arbitrary code -- only use trusted files
    filename = "{}/{}".format("FP_data",cellname)
    with open(filename,'rb') as f:
        fp_data = pickle.load(f)
    z_ = fp_data["density"][plane_idx]

    if relative:
        z = np.abs((z-z_)/std_)
        vmax = min([3,np.max(z)])
    else:
        z = np.abs(z-z_)
        vmax = np.max(z)

    sc = plt.scatter(x,y,c=z,cmap='RdYlBu',marker=',',s=s_,alpha=1,vmax=vmax)
    plt.colorbar(sc,label='|Prediction Error/Uncertainty|')
    plt.ylabel(r"y ($\AA$)")
    plt.xlabel(r"x ($\AA$)")
    if save:
        plt.savefig("{}-{}_Err.pdf".format(netname,cellname))
    plt.show()
    return
 # NOTE(review): module-level duplicate of Castep_density.set_network_handler;
# kept at column 0 so the module parses (a one-space indent here is an
# IndentationError). It is probably dead code -- confirm and remove.
def set_network_handler(self, network_handler):
    """Attach a network handler to ``self`` (any object with the expected attributes).

    Uses ``network_handler`` when it is a ``NetworkHandler``. Otherwise it
    builds a new one from ``self.descriptor`` and ``self.traindir``.
    """
    if isinstance(network_handler, NetworkHandler):
        self.NetHandler = network_handler
    else:
        self.NetHandler = NetworkHandler(self.descriptor,
                                         train_dir=self.traindir)
class Castep_density():
    """Bridge between a trained ``NetworkHandler`` and precalculated cell data.

    Loads pickled cell/fingerprint data into a ``supercell`` object
    (``get_cell_data``). Runs the network ensemble on the stored fingerprints
    (``ensemble_predict``). Can place the predicted density on the full
    cuboid grid and write it out (``setCellDensities``).
    """

    def __init__(self,
                 descriptor=None,
                 network_handler=None,
                 calc_FP=False,
                 trainNet=False,
                 data_dir='./data/',
                 train_dir='../CastepCalculations/DenNCells/'):
        """
        Args:
            descriptor: a ``fingerprints`` instance; anything else is replaced
                by a default ``fingerprints()``.
            network_handler: a ``NetworkHandler``; anything else makes a new
                handler from the descriptor and ``train_dir``.
            calc_FP: flag checked when fingerprint files are missing
                (computing them is not implemented).
            trainNet: if True, ``setupNetwork`` trains a network when none
                can be loaded.
            data_dir: fallback directory checked by ``get_cell_data``.
            train_dir: training directory handed to a new ``NetworkHandler``.
        """
        self.traindir = train_dir
        self.datadir = data_dir
        self.trainNet = trainNet
        self.calc_FP = calc_FP
        self.supercell = None
        # parameters of taper(): cut-off value and transition width
        self.xcut = -6.7
        self.scale = 0.2
        self.set_descriptor(descriptor)
        self.set_network_handler(network_handler)

    def set_network_handler(self, network_handler):
        """Use ``network_handler`` if it is a ``NetworkHandler``, else build one."""
        if isinstance(network_handler, NetworkHandler):
            self.NetHandler = network_handler
        else:
            self.NetHandler = NetworkHandler(self.descriptor,
                                             train_dir=self.traindir)

    def set_descriptor(self, descriptor):
        """Store the descriptor (default ``fingerprints()`` if invalid).

        Also copies its ``bilength`` and ``powerlength`` onto ``self``.
        """
        if isinstance(descriptor, fingerprints):
            self.descriptor = descriptor
        else:
            print(
                "descriptor passed was invalid, creating default fingerprints class"
            )
            self.descriptor = fingerprints()
        self.bilength = self.descriptor.bilength
        self.powerlength = self.descriptor.powerlength

        return

    def setupNetwork(self):
        """Make sure the network is usable: loaded, or else trained.

        Does nothing if the handler is already trained or loaded. Otherwise
        it tries ``load()``, and trains when that fails and ``trainNet`` is
        set. If neither works, it prints a warning.
        """
        if self.NetHandler.trained or self.NetHandler.loaded:
            return
        self.NetHandler.load()
        if not self.NetHandler.loaded and self.trainNet:
            self.NetHandler.get_data()
            self.NetHandler.train()
        elif not self.NetHandler.loaded and not self.trainNet:
            print(
                "Warning: Network can't be loaded and hasn't been trained results will be bad"
            )
        return

    def get_frac_coords(self, at_posns, cell):
        """Convert Cartesian atom positions to fractional coordinates.

        Row ``i`` of the result is ``inv(cell) @ at_posns[i, :]``.
        """
        inv_cell = np.linalg.inv(cell)
        # vectorised form of applying inv_cell to every position row
        return np.asarray(at_posns, dtype=float) @ inv_cell.T

    def taper(self, x):
        """Smooth weight ``t**4 / (1 + t**4)`` with ``t = max((xcut - x) / scale, 0)``.

        Returns 0 for ``x >= xcut`` and approaches 1 as ``x`` falls below
        ``xcut``. ``scale`` sets the width of the transition.
        """
        x_prime = (self.xcut - x) / self.scale
        x_prime4 = (np.where(x_prime >= 0, x_prime, 0))**4
        return x_prime4 / (1 + x_prime4)

    def get_cell_data(self,
                      filename,
                      cell_dir="./Cell_data/",
                      fp_dir="./FP_data/",
                      include_density=False):
        """Load the cell and fingerprint pickles for ``filename`` into a supercell.

        Paths are built as ``cell_dir + filename`` and ``fp_dir + filename``,
        so both directories should end with a separator. When
        ``include_density`` is True, ``supercell.train_density`` and
        ``supercell.fin_density`` are also set. Missing data is reported with
        a print and ``self.supercell`` is left unchanged.
        """
        cell_keys = ["cell", "at_posns", "grid", "fin_density"]
        fp_keys = ["fingerprints", "density"]
        cell_path = "{}{}".format(cell_dir, filename)
        fp_path = "{}{}".format(fp_dir, filename)

        if os.path.isdir(cell_dir) and os.path.isfile(cell_path):
            # NOTE(review): pickle.load runs arbitrary code -- only use
            # trusted data files
            with open(cell_path, 'rb') as f:
                cell_dict = pickle.load(f)
            if all(key in cell_dict for key in cell_keys):
                print("cell data loaded")
            else:
                print("cell data incomplete")

            if os.path.isdir(fp_dir) and os.path.isfile(fp_path):
                with open(fp_path, 'rb') as f:
                    fp_dict = pickle.load(f)
                if all(key in fp_dict for key in fp_keys):
                    print("fp data loaded")
                else:
                    print("fp data incomplete")

                self.set_supercell(cell_dict["cell"], cell_dict["at_posns"],
                                   cell_dict["grid"], fp_dict["fingerprints"])
                if include_density:
                    self.supercell.train_density = fp_dict["density"]
                    self.supercell.fin_density = cell_dict["fin_density"]

            elif self.calc_FP:
                print("not implemented yet")
            else:
                print("fp data not loaded")

        elif os.path.isdir(self.datadir):
            print("not implemented yet")

        else:
            print("no cell data, aborting load")

        return

    def set_supercell(self, cell, at_posns, grid, FP):
        """Create ``self.supercell`` from cell vectors, positions, grid and fingerprints.

        Every atom is labelled ``'H'``. ``grid``/``FP`` are stored only when
        their lengths agree.
        """
        # Andrew's supercell class stores everything about the unit cell
        self.supercell = supercell()
        self.supercell.set_cell(cell)
        frac_coords = self.get_frac_coords(at_posns, cell)
        self.supercell.cart_coords = at_posns
        self.supercell.set_positions(frac_coords)
        self.supercell.set_species(['H'] * at_posns.shape[0])
        if len(grid) == FP.shape[0]:
            self.supercell.grid = grid
            self.supercell.FP = FP
        else:
            print("incompatible grid and fingerprints")
        return

    def ensemble_predict(self, X=None):
        """Return ``(ensemble_mean, ensemble_std)`` for ``X`` (default: the supercell FPs)."""
        if X is None:
            X = self.supercell.FP
        ensemble_mean, ensemble_std = self.NetHandler.predict(X, standard=False)
        return ensemble_mean, ensemble_std

    def get_full_grid(self):
        """Reconstruct the full regular grid of a cuboid cell.

        The grid spacing along each axis is taken as the most common non-zero
        step between consecutive stored grid points; ties go to the smallest
        step. This assumes the points mostly come in contiguous clumps.

        Returns:
            Array of shape ``(3, N)`` holding the flattened
            ``np.meshgrid`` (default 'xy' indexing) x, y and z coordinates.

        Raises:
            ValueError: if no spacing can be inferred along an axis, or a
                cell length is not (within 1e-4) a whole number of spacings.
        """
        grid = self.supercell.grid
        diff = grid[1:, :] - grid[:-1, :]
        interval = []
        for i in range(3):
            steps = diff[diff[:, i] != 0, i]
            if steps.size == 0:
                raise ValueError(
                    "cannot infer grid spacing along axis {}".format(i))
            values, counts = np.unique(steps, return_counts=True)
            interval.append(values[np.argmax(counts)])
        print("interval: ", interval)

        # each cell edge must hold a whole number of grid intervals
        grid_lines = []
        for i in range(3):
            length = self.supercell.cell[i, i]
            n_points = length / interval[i]
            if abs(n_points - round(n_points)) >= 0.0001:
                raise ValueError(
                    "invalid interval along axis {}: {} does not divide {}".format(
                        i, interval[i], length))
            grid_lines.append(
                np.linspace(0.0, length - interval[i], int(round(n_points))))

        mesh = np.meshgrid(grid_lines[0], grid_lines[1], grid_lines[2])
        return np.asarray([np.asarray(m).flatten() for m in mesh])

    def getH(self, std):
        """Return the mean of ``log(std)`` (log of the geometric-mean std)."""
        return np.mean(np.log(std))

    def getnonzero_index(self, big_grid, small_grid):
        """Map each point of ``small_grid`` to its row index in ``big_grid``.

        ``big_grid`` is the transposed output of ``get_full_grid``. The
        flattened 'xy' meshgrid ordering gives index
        ``n*ix + n**2*iy + iz``.
        NOTE(review): assumes the same number of points ``n`` along every
        axis and that the largest y coordinate equals the largest x/z
        coordinate -- i.e. a cubic grid.
        """
        grid_len = int(np.cbrt(big_grid.shape[0] + 1))
        frac_along = small_grid / big_grid[-1, 1]
        # small epsilon guards against 0.999... truncating to the wrong index
        grid_index = ((grid_len - 1) * frac_along + 0.00001).astype(int)
        idx = (grid_len * grid_index[:, 0] + grid_len**2 * grid_index[:, 1] +
               grid_index[:, 2]).astype(int)
        return idx

    def setCellDensities(self, filename=None, taper=True):
        """Place the ensemble prediction on the full grid and store it in the supercell.

        Args:
            filename: if given, also write ``Hnet_test.cell`` and the
                unformatted density to ``filename`` via ``wrap_inhouse``.
            taper: if True, multiply the density by ``self.taper(H)``, where
                ``H = getH(ensemble_std)``.
        """
        if self.supercell is None:
            print("set the cell values before continuing")
            return

        ensemble_mean, ensemble_std = self.ensemble_predict()
        H = self.getH(ensemble_std)
        # separate name so the `taper` flag is not overwritten by the factor
        taper_factor = self.taper(H)

        fullgrid = self.get_full_grid().T
        density = np.zeros((fullgrid.shape[0]))

        nonzero_idx = self.getnonzero_index(fullgrid, self.supercell.grid)
        if len(nonzero_idx) == len(ensemble_mean):
            density[nonzero_idx] = ensemble_mean[:len(nonzero_idx)]
        else:
            print("messed up the assigning")
        if taper:
            density *= taper_factor
        self.supercell.set_edensity({"xyz": fullgrid, "density": density})

        if filename is not None:
            wrapper = wrap_inhouse(self.supercell)
            wrapper.write_cell(mp_grid_spacing=[2, 2, 2],
                               fname="Hnet_test.cell")
            wrapper.write_unformatted_density(fname=filename)
        return