Example #1
0
    def run(self, 
            logLock=None, logStreamFN=None, 
            lowerBoundLock=None, lowerBoundStreamFN=None,
            latentsTimes=None, latentsLock=None, latentsStreamFN=None,
           ):
        """Build an svGPFA model from this runner's configuration and
        maximize its lower bound with SVEM.

        Args:
            logLock: optional lock serializing writes to ``logStreamFN``.
            logStreamFN: optional filename for log output.
            lowerBoundLock: optional lock serializing lower-bound stream writes.
            lowerBoundStreamFN: optional filename for the lower-bound stream.
            latentsTimes: optional times at which to stream latents estimates.
            latentsLock: optional lock serializing latents stream writes.
            latentsStreamFN: optional filename for the latents stream.

        Returns:
            tuple: ``(lowerBoundHist, elapsedTimeHist, model)`` — the lower
            bound and elapsed-time histories across EM iterations, and the
            estimated model.  (Previously these were computed and silently
            discarded.)
        """
        # quadrature points/weights for the expected log-likelihood integral
        legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(nQuad=self.getNQuad(), trialsLengths=self.getTrialsLengths())
        # initial inducing-point locations, one set per latent
        Z0 = utils.svGPFA.initUtils.getIndPointLocs0(nIndPointsPerLatent=self.getNIndPointsPerLatent(), trialsLengths=self.getTrialsLengths(), firstIndPoint=self.getFirstIndPoint())
        svPosteriorOnIndPointsParams0 = self.getSVPosteriorOnIndPointsParams0()
        svEmbeddingParams0 = self.getEmbeddingParams0()
        kernelsParams0 = self.getKernelsParams0()
        kmsParams0 = {"kernelsParams0": kernelsParams0, "inducingPointsLocs0": Z0}
        initialParams = {"svPosteriorOnIndPoints": svPosteriorOnIndPointsParams0, "svEmbedding": svEmbeddingParams0, "kernelsMatricesStore": kmsParams0}
        quadParams = {"legQuadPoints": legQuadPoints, "legQuadWeights": legQuadWeights}

        conditionalDist = self.getConditionalDist()
        linkFunction = self.getLinkFunction()
        embeddingType = self.getEmbeddingType()
        kernels = self.getKernels()

        measurements = self.getSpikesTimes()
        optimParams = self.getOptimParams()
        indPointsLocsKMSEpsilon = self.getIndPointsLocsKMSRegEpsilon()

        # maximize lower bound
        svEM = stats.svGPFA.svEM.SVEM()

        model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(conditionalDist=conditionalDist, linkFunction=linkFunction, embeddingType=embeddingType, kernels=kernels)
        lowerBoundHist, elapsedTimeHist = svEM.maximize(model=model,
                                                        measurements=measurements,
                                                        initialParams=initialParams,
                                                        quadParams=quadParams,
                                                        optimParams=optimParams,
                                                        indPointsLocsKMSEpsilon=indPointsLocsKMSEpsilon,
                                                        logLock=logLock,
                                                        logStreamFN=logStreamFN,
                                                        lowerBoundLock=lowerBoundLock,
                                                        lowerBoundStreamFN=lowerBoundStreamFN,
                                                        latentsTimes=latentsTimes,
                                                        latentsLock=latentsLock,
                                                        latentsStreamFN=latentsStreamFN)
        # return the optimization histories and the estimated model so callers
        # can use them; returning is backward compatible with callers that
        # ignored the previous implicit None
        return lowerBoundHist, elapsedTimeHist, model
Example #2
0
def main(argv):
    """Estimate an svGPFA model for a simulated spike-train dataset.

    Reads the estimation-initialization config selected by ``estInitNumber``
    and the simulation results selected by ``simResNumber``, builds a
    point-process svGPFA model with an exponential link and a linear
    embedding, maximizes its lower bound with SVEM, and saves the estimation
    metadata and the estimated model under a freshly drawn estimation number.

    Args:
        argv: unused; command-line arguments are read from ``sys.argv`` by
            argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("simResNumber",
                        help="simulation result number",
                        type=int)
    parser.add_argument("estInitNumber",
                        help="estimation init number",
                        type=int)
    args = parser.parse_args()

    simResNumber = args.simResNumber
    estInitNumber = args.estInitNumber

    estInitConfigFilename = "data/{:08d}_estimation_metaData.ini".format(
        estInitNumber)
    estInitConfig = configparser.ConfigParser()
    estInitConfig.read(estInitConfigFilename)
    nQuad = int(estInitConfig["control_variables"]["nQuad"])

    # configparser stores every value as a string; cast each optimization
    # parameter to its proper type
    optimParamsConfig = estInitConfig._sections["optim_params"]
    optimParams = {}
    optimParams["em_max_iter"] = int(optimParamsConfig["em_max_iter"])
    steps = [
        "estep", "mstep_embedding", "mstep_kernels", "mstep_indpointslocs"
    ]
    for step in steps:
        optimParams["{:s}_estimate".format(step)] = optimParamsConfig[
            "{:s}_estimate".format(step)] == "True"
        optimParams["{:s}_max_iter".format(step)] = int(
            optimParamsConfig["{:s}_max_iter".format(step)])
        optimParams["{:s}_lr".format(step)] = float(
            optimParamsConfig["{:s}_lr".format(step)])
        optimParams["{:s}_tol".format(step)] = float(
            optimParamsConfig["{:s}_tol".format(step)])
        optimParams["{:s}_niter_display".format(step)] = int(
            optimParamsConfig["{:s}_niter_display".format(step)])
        optimParams["{:s}_line_search_fn".format(step)] = optimParamsConfig[
            "{:s}_line_search_fn".format(step)]
    optimParams["verbose"] = optimParamsConfig["verbose"] == "True"

    # draw an estimation number whose metadata file does not exist yet
    estPrefixUsed = True
    while estPrefixUsed:
        estResNumber = random.randint(0, 10**8)
        estimResMetaDataFilename = "results/{:08d}_estimation_metaData.ini".format(
            estResNumber)
        if not os.path.exists(estimResMetaDataFilename):
            estPrefixUsed = False
    modelSaveFilename = "results/{:08d}_estimatedModel.pickle".format(
        estResNumber)

    # load data and initial values
    simResConfigFilename = "results/{:08d}_simulation_metaData.ini".format(
        simResNumber)
    simResConfig = configparser.ConfigParser()
    simResConfig.read(simResConfigFilename)
    simInitConfigFilename = simResConfig["simulation_params"][
        "simInitConfigFilename"]
    simResFilename = simResConfig["simulation_results"]["simResFilename"]

    simInitConfig = configparser.ConfigParser()
    simInitConfig.read(simInitConfigFilename)
    nLatents = int(simInitConfig["control_variables"]["nLatents"])
    nNeurons = int(simInitConfig["control_variables"]["nNeurons"])
    # trialsLengths is stored as "[l1,...,lR]"; strip the brackets and split
    trialsLengths = [
        float(item) for item in simInitConfig["control_variables"]
        ["trialsLengths"][1:-1].split(",")
    ]
    nTrials = len(trialsLengths)
    indPointsLocsKMSRegEpsilon = float(
        simInitConfig["control_variables"]["indPointsLocsKMSRegEpsilon"])

    with open(simResFilename, "rb") as f:
        simRes = pickle.load(f)
    spikesTimes = simRes["spikes"]
    KzzChol = simRes["KzzChol"]
    indPointsMeans = simRes["indPointsMeans"]
    randomEmbedding = estInitConfig["control_variables"][
        "randomEmbedding"].lower() == "true"
    if randomEmbedding:
        # uniform in [-1, 1); the previous "- 0.5 * 2" subtracted 1.0 due to
        # operator precedence, yielding values only in [-1, 0)
        C0 = (torch.rand(nNeurons, nLatents, dtype=torch.double) - 0.5) * 2
        d0 = (torch.rand(nNeurons, 1, dtype=torch.double) - 0.5) * 2
    else:
        # start from the stored embedding parameters perturbed with Gaussian
        # noise of std initCondEmbeddingSTD
        CFilename = estInitConfig["embedding_params"]["C_filename"]
        dFilename = estInitConfig["embedding_params"]["d_filename"]
        C, d = utils.svGPFA.configUtils.getLinearEmbeddingParams(
            CFilename=CFilename, dFilename=dFilename)
        initCondEmbeddingSTD = float(
            estInitConfig["control_variables"]["initCondEmbeddingSTD"])
        C0 = C + torch.randn(C.shape) * initCondEmbeddingSTD
        d0 = d + torch.randn(d.shape) * initCondEmbeddingSTD

    legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(
        nQuad=nQuad, trialsLengths=trialsLengths)

    res = utils.svGPFA.configUtils.getScaledKernels(nLatents=nLatents,
                                                    config=estInitConfig,
                                                    forceUnitScale=True)
    kernels = res["kernels"]
    # NOTE(review): the initial kernel parameters are used unscaled below;
    # a scaled version (params / res["kernelsParamsScales"]) used to be
    # computed but was never consumed — confirm the unscaled ones are intended
    unscaledKernelsParams0 = utils.svGPFA.initUtils.getKernelsParams0(
        kernels=kernels, noiseSTD=0.0)

    Z0 = utils.svGPFA.configUtils.getIndPointsLocs0(nLatents=nLatents,
                                                    nTrials=nTrials,
                                                    config=estInitConfig)
    nIndPointsPerLatent = [Z0[k].shape[1] for k in range(nLatents)]

    # patch to acommodate Lea's equal number of inducing points across trials
    qMu0 = [[] for k in range(nLatents)]
    for k in range(nLatents):
        qMu0[k] = torch.empty((nTrials, nIndPointsPerLatent[k], 1),
                              dtype=torch.double)
        for r in range(nTrials):
            qMu0[k][r, :, :] = indPointsMeans[r][k]
    # end patch

    srQSigma0Vecs = utils.svGPFA.initUtils.getSRQSigmaVecsFromSRMatrices(
        srMatrices=KzzChol)

    qUParams0 = {"qMu0": qMu0, "srQSigma0Vecs": srQSigma0Vecs}
    kmsParams0 = {
        "kernelsParams0": unscaledKernelsParams0,
        "inducingPointsLocs0": Z0
    }
    qKParams0 = {
        "svPosteriorOnIndPoints": qUParams0,
        "kernelsMatricesStore": kmsParams0
    }
    qHParams0 = {"C0": C0, "d0": d0}
    initialParams = {
        "svPosteriorOnLatents": qKParams0,
        "svEmbedding": qHParams0
    }
    quadParams = {
        "legQuadPoints": legQuadPoints,
        "legQuadWeights": legQuadWeights
    }

    # create model
    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels)

    model.setInitialParamsAndData(
        measurements=spikesTimes,
        initialParams=initialParams,
        quadParams=quadParams,
        indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon)

    # maximize lower bound, saving partial models along the way
    savePartialFilenamePattern = "results/{:08d}_{{:s}}_estimatedModel.pickle".format(
        estResNumber)
    svEM = stats.svGPFA.svEM.SVEM()
    lowerBoundHist, elapsedTimeHist = svEM.maximize(
        model=model,
        optimParams=optimParams,
        savePartial=True,
        savePartialFilenamePattern=savePartialFilenamePattern)

    # save estimation metadata
    estimResConfig = configparser.ConfigParser()
    estimResConfig["simulation_params"] = {"simResNumber": simResNumber}
    estimResConfig["optim_params"] = optimParams
    estimResConfig["estimation_params"] = {
        "estInitNumber": estInitNumber,
        "nIndPointsPerLatent": nIndPointsPerLatent
    }
    with open(estimResMetaDataFilename, "w") as f:
        estimResConfig.write(f)

    resultsToSave = {
        "lowerBoundHist": lowerBoundHist,
        "elapsedTimeHist": elapsedTimeHist,
        "model": model
    }
    # persist the estimated model; this save used to be commented out, so the
    # estimation results were silently discarded although the metadata file
    # referencing them was written
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)
def main(argv):
    """Estimate an svGPFA model and export its initial values for Matlab.

    Reads the estimation-initialization config selected by ``estInitNumber``
    and the simulation results selected by ``simResNumber``, saves the
    initial values to a ``.mat`` file for the Matlab implementation, builds a
    point-process svGPFA model, maximizes its lower bound and saves the
    estimation metadata and the estimated model under a freshly drawn
    estimation number.

    Args:
        argv: unused; command-line arguments are read from ``sys.argv`` by
            argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("simResNumber", help="simulation result number", type=int)
    parser.add_argument("estInitNumber", help="estimation init number", type=int)
    args = parser.parse_args()

    simResNumber = args.simResNumber
    estInitNumber = args.estInitNumber

    estInitConfigFilename = "data/{:08d}_estimation_metaData.ini".format(estInitNumber)
    estInitConfig = configparser.ConfigParser()
    estInitConfig.read(estInitConfigFilename)
    nQuad = int(estInitConfig["control_variables"]["nQuad"])
    indPointsLocsKMSRegEpsilon = float(estInitConfig["control_variables"]["indPointsLocsKMSRegEpsilon"])

    # configparser stores every value as a string; cast each optimization
    # parameter to its proper type
    optimParamsConfig = estInitConfig._sections["optim_params"]
    optimMethod = optimParamsConfig["em_method"]
    optimParams = {}
    optimParams["em_max_iter"] = int(optimParamsConfig["em_max_iter"])
    steps = ["estep", "mstep_embedding", "mstep_kernels", "mstep_indpointslocs"]
    for step in steps:
        optimParams["{:s}_estimate".format(step)] = optimParamsConfig["{:s}_estimate".format(step)]=="True"
        optimParams["{:s}_optim_params".format(step)] = {
            "max_iter": int(optimParamsConfig["{:s}_max_iter".format(step)]),
            "lr": float(optimParamsConfig["{:s}_lr".format(step)]),
            "tolerance_grad": float(optimParamsConfig["{:s}_tolerance_grad".format(step)]),
            "tolerance_change": float(optimParamsConfig["{:s}_tolerance_change".format(step)]),
            "line_search_fn": optimParamsConfig["{:s}_line_search_fn".format(step)],
        }
    optimParams["verbose"] = optimParamsConfig["verbose"]=="True"

    # load data and initial values
    simResConfigFilename = "results/{:08d}_simulation_metaData.ini".format(simResNumber)
    simResConfig = configparser.ConfigParser()
    simResConfig.read(simResConfigFilename)
    simInitConfigFilename = simResConfig["simulation_params"]["simInitConfigFilename"]
    simResFilename = simResConfig["simulation_results"]["simResFilename"]

    simInitConfig = configparser.ConfigParser()
    simInitConfig.read(simInitConfigFilename)
    nLatents = int(simInitConfig["control_variables"]["nLatents"])
    nNeurons = int(simInitConfig["control_variables"]["nNeurons"])
    # trialsLengths is stored as "[l1,...,lR]"; strip the brackets and split
    trialsLengths = [float(item) for item in simInitConfig["control_variables"]["trialsLengths"][1:-1].split(",")]
    nTrials = len(trialsLengths)

    with open(simResFilename, "rb") as f:
        simRes = pickle.load(f)
    spikesTimes = simRes["spikes"]

    randomEmbedding = estInitConfig["control_variables"]["randomEmbedding"].lower()=="true"
    if randomEmbedding:
        # random initial embedding, uniform in [0, 1)
        C0 = torch.rand(nNeurons, nLatents, dtype=torch.double).contiguous()
        d0 = torch.rand(nNeurons, 1, dtype=torch.double).contiguous()
    else:
        # start from the stored embedding parameters perturbed with Gaussian
        # noise of std initCondEmbeddingSTD
        CFilename = estInitConfig["embedding_params"]["C_filename"]
        dFilename = estInitConfig["embedding_params"]["d_filename"]
        C, d = utils.svGPFA.configUtils.getLinearEmbeddingParams(CFilename=CFilename, dFilename=dFilename)
        initCondEmbeddingSTD = float(estInitConfig["control_variables"]["initCondEmbeddingSTD"])
        C0 = (C + torch.randn(C.shape)*initCondEmbeddingSTD).contiguous()
        d0 = (d + torch.randn(d.shape)*initCondEmbeddingSTD).contiguous()

    legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(nQuad=nQuad, trialsLengths=trialsLengths)

    kernels = utils.svGPFA.configUtils.getKernels(nLatents=nLatents, config=estInitConfig, forceUnitScale=True)
    kernelsScaledParams0 = utils.svGPFA.initUtils.getKernelsScaledParams0(kernels=kernels, noiseSTD=0.0)
    Z0 = utils.svGPFA.configUtils.getIndPointsLocs0(nLatents=nLatents, nTrials=nTrials, config=estInitConfig)
    nIndPointsPerLatent = [Z0[k].shape[1] for k in range(nLatents)]

    qMu0 = utils.svGPFA.configUtils.getVariationalMean0(nLatents=nLatents, nTrials=nTrials, config=estInitConfig)

    qSigma0 = utils.svGPFA.configUtils.getVariationalCov0(nLatents=nLatents, nTrials=nTrials, config=estInitConfig)
    srQSigma0Vecs = utils.svGPFA.initUtils.getSRQSigmaVecsFromSRMatrices(srMatrices=qSigma0)

    qUParams0 = {"qMu0": qMu0, "srQSigma0Vecs": srQSigma0Vecs}
    kmsParams0 = {"kernelsParams0": kernelsScaledParams0,
                  "inducingPointsLocs0": Z0}
    qKParams0 = {"svPosteriorOnIndPoints": qUParams0,
                 "kernelsMatricesStore": kmsParams0}
    qHParams0 = {"C0": C0, "d0": d0}
    initialParams = {"svPosteriorOnLatents": qKParams0,
                     "svEmbedding": qHParams0}
    quadParams = {"legQuadPoints": legQuadPoints,
                  "legQuadWeights": legQuadWeights}

    # draw an estimation number whose metadata file does not exist yet
    estPrefixUsed = True
    while estPrefixUsed:
        estResNumber = random.randint(0, 10**8)
        estimResMetaDataFilename = "results/{:08d}_estimation_metaData.ini".format(estResNumber)
        if not os.path.exists(estimResMetaDataFilename):
           estPrefixUsed = False
    modelSaveFilename = "results/{:08d}_estimatedModel.pickle".format(estResNumber)

    # export the initial values so the Matlab implementation can be run on
    # exactly the same starting point
    kernelsTypes = [type(kernels[k]).__name__ for k in range(len(kernels))]
    qSVec0, qSDiag0 = utils.svGPFA.miscUtils.getQSVecsAndQSDiagsFromQSRSigmaVecs(srQSigmaVecs=srQSigma0Vecs)
    estimationDataForMatlabFilename = "results/{:08d}_estimationDataForMatlab.mat".format(estResNumber)
    # older simulation results stored the latents times under "times"
    if "latentsTrialsTimes" in simRes.keys():
        latentsTrialsTimes = simRes["latentsTrialsTimes"]
    elif "times" in simRes.keys():
        latentsTrialsTimes = simRes["times"]
    else:
        raise ValueError("latentsTrialsTimes or times cannot be found in {:s}".format(simResFilename))
    utils.svGPFA.miscUtils.saveDataForMatlabEstimations(
        qMu0=qMu0, qSVec0=qSVec0, qSDiag0=qSDiag0,
        C0=C0, d0=d0,
        indPointsLocs0=Z0,
        legQuadPoints=legQuadPoints,
        legQuadWeights=legQuadWeights,
        kernelsTypes=kernelsTypes,
        kernelsParams0=kernelsScaledParams0,
        spikesTimes=spikesTimes,
        indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon,
        trialsLengths=torch.tensor(trialsLengths).reshape(-1,1),
        latentsTrialsTimes=latentsTrialsTimes,
        emMaxIter=optimParams["em_max_iter"],
        eStepMaxIter=optimParams["estep_optim_params"]["max_iter"],
        mStepEmbeddingMaxIter=optimParams["mstep_embedding_optim_params"]["max_iter"],
        mStepKernelsMaxIter=optimParams["mstep_kernels_optim_params"]["max_iter"],
        mStepIndPointsMaxIter=optimParams["mstep_indpointslocs_optim_params"]["max_iter"],
        saveFilename=estimationDataForMatlabFilename)

    def getKernelParams(model):
        # record the first latent's kernel parameters at every iteration
        kernelParams = model.getKernelsParams()[0]
        return kernelParams

    # create model
    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels)

    model.setInitialParamsAndData(measurements=spikesTimes,
                                  initialParams=initialParams,
                                  eLLCalculationParams=quadParams,
                                  indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon)

    # maximize lower bound
    svEM = stats.svGPFA.svEM.SVEM()
    lowerBoundHist, elapsedTimeHist, terminationInfo, iterationsModelParams = svEM.maximize(model=model, optimParams=optimParams, method=optimMethod, getIterationModelParamsFn=getKernelParams)

    # save estimated values
    estimResConfig = configparser.ConfigParser()
    estimResConfig["simulation_params"] = {"simResNumber": simResNumber}
    estimResConfig["optim_params"] = optimParams
    estimResConfig["estimation_params"] = {"estInitNumber": estInitNumber, "nIndPointsPerLatent": nIndPointsPerLatent}
    with open(estimResMetaDataFilename, "w") as f:
        estimResConfig.write(f)

    resultsToSave = {"lowerBoundHist": lowerBoundHist, "elapsedTimeHist": elapsedTimeHist, "terminationInfo": terminationInfo, "iterationModelParams": iterationsModelParams, "model": model}
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)
    print("Saved results to {:s}".format(modelSaveFilename))
def main(argv):
    """Run additional EM iterations on a previously estimated svGPFA model.

    Loads the model saved under ``initialEstResNumber``, reuses the optim
    parameters from its original estimation-initialization config (with
    ``em_max_iter`` overridden to ``nIter``), maximizes the lower bound
    further and saves the refined model under a freshly drawn estimation
    number.

    Args:
        argv: unused; command-line arguments are read from ``sys.argv`` by
            argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "initialEstResNumber",
        help="estimation result number of the model to add iterations",
        type=int)
    parser.add_argument("nIter", help="number of iterations to add", type=int)
    parser.add_argument("--estimatedModelFilenamePattern",
                        default="results/{:08d}_estimatedModel.pickle",
                        help="estimated model filename pattern")
    parser.add_argument("--estimationMetaDataFilenamePattern",
                        default="data/{:08d}_estimation_metaData.ini",
                        help="estimation model meta data filename pattern")
    parser.add_argument("--estimatedModelMetaDataFilenamePattern",
                        default="results/{:08d}_estimation_metaData.ini",
                        help="estimated model meta data filename pattern")
    args = parser.parse_args()

    initialEstResNumber = args.initialEstResNumber
    nIter = args.nIter
    estimatedModelFilenamePattern = args.estimatedModelFilenamePattern
    estimationMetaDataFilenamePattern = args.estimationMetaDataFilenamePattern
    estimatedModelMetaDataFilenamePattern = args.estimatedModelMetaDataFilenamePattern

    # recover the estimation-init config used for the initial estimation
    initialEstModelMetaDataFilename = estimatedModelMetaDataFilenamePattern.format(
        initialEstResNumber)
    initialEstModelMetaDataConfig = configparser.ConfigParser()
    initialEstModelMetaDataConfig.read(initialEstModelMetaDataFilename)
    # configparser lowercases option names itself, so the key can be given
    # in its natural casing
    initialEstimationInitNumber = int(
        initialEstModelMetaDataConfig["estimation_params"]["estInitNumber"])
    estMetaDataFilename = estimationMetaDataFilenamePattern.format(
        initialEstimationInitNumber)

    initialEstimationMetaDataConfig = configparser.ConfigParser()
    initialEstimationMetaDataConfig.read(estMetaDataFilename)
    optimParamsDict = initialEstimationMetaDataConfig._sections["optim_params"]
    optimParams = utils.svGPFA.miscUtils.getOptimParams(
        optimParamsDict=optimParamsDict)
    # run exactly nIter additional EM iterations
    optimParams["em_max_iter"] = nIter

    initialModelFilename = estimatedModelFilenamePattern.format(
        initialEstResNumber)
    with open(initialModelFilename, "rb") as f:
        estResults = pickle.load(f)
    model = estResults["model"]

    # draw an estimation number whose metadata file does not exist yet
    estPrefixUsed = True
    while estPrefixUsed:
        finalEstResNumber = random.randint(0, 10**8)
        finalEstimResMetaDataFilename = estimatedModelMetaDataFilenamePattern.format(
            finalEstResNumber)
        if not os.path.exists(finalEstimResMetaDataFilename):
            estPrefixUsed = False
    finalModelSaveFilename = estimatedModelFilenamePattern.format(
        finalEstResNumber)

    # maximize lower bound
    svEM = stats.svGPFA.svEM.SVEM_PyTorch()
    lowerBoundHist, elapsedTimeHist, terminationInfo, iterationsModelParams = \
        svEM.maximize(model=model, optimParams=optimParams)

    # save estimated values
    finalEstimationMetaDataConfig = configparser.ConfigParser()
    finalEstimationMetaDataConfig["estimation_params"] = {
        "initialEstResNumber": initialEstResNumber,
        "nIter": nIter,
        "estInitNumber": initialEstimationInitNumber
    }
    with open(finalEstimResMetaDataFilename, "w") as f:
        finalEstimationMetaDataConfig.write(f)

    resultsToSave = {
        "lowerBoundHist": lowerBoundHist,
        "elapsedTimeHist": elapsedTimeHist,
        "terminationInfo": terminationInfo,
        "iterationModelParams": iterationsModelParams,
        "model": model
    }
    with open(finalModelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)
Example #5
0
def test_maximize_pointProcess_PyTorch():
    """End-to-end check of SVEM_PyTorch.maximize on a point-process model.

    Loads initial values and spikes saved from the Matlab implementation,
    builds the PyTorch svGPFA lower bound, runs four EM iterations and
    asserts the final lower bound exceeds the one attained by the Matlab
    implementation (``lowerBound`` stored in the fixture).
    """
    yNonStackedFilename = os.path.join(os.path.dirname(__file__),
                                       "data/YNonStacked.mat")
    dataFilename = os.path.join(os.path.dirname(__file__),
                                "data/variationalEM.mat")

    # initial values exported from Matlab; permute to (trial, point, 1) order
    mat = loadmat(dataFilename)
    nLatents = len(mat['Z0'])
    nTrials = mat['Z0'][0, 0].shape[2]
    qMu0 = [
        torch.from_numpy(mat['q_mu0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1).contiguous()
        for i in range(nLatents)
    ]
    qSVec0 = [
        torch.from_numpy(mat['q_sqrt0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1).contiguous()
        for i in range(nLatents)
    ]
    qSDiag0 = [
        torch.from_numpy(mat['q_diag0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1).contiguous()
        for i in range(nLatents)
    ]
    srQSigma0Vecs = utils.svGPFA.miscUtils.getSRQSigmaVec(qSVec=qSVec0,
                                                          qSDiag=qSDiag0)
    Z0 = [
        torch.from_numpy(mat['Z0'][(i, 0)]).type(torch.DoubleTensor).permute(
            2, 0, 1).contiguous() for i in range(nLatents)
    ]
    C0 = torch.from_numpy(mat["C0"]).type(torch.DoubleTensor).contiguous()
    b0 = torch.from_numpy(mat["b0"]).type(
        torch.DoubleTensor).squeeze().contiguous()
    indPointsLocsKMSRegEpsilon = 1e-2
    legQuadPoints = torch.from_numpy(mat['ttQuad']).type(
        torch.DoubleTensor).permute(2, 0, 1).contiguous()
    legQuadWeights = torch.from_numpy(mat['wwQuad']).type(
        torch.DoubleTensor).permute(2, 0, 1).contiguous()

    # spikes: one list of spike-time tensors per trial and neuron
    yMat = loadmat(yNonStackedFilename)
    YNonStacked_tmp = yMat['YNonStacked']
    nNeurons = YNonStacked_tmp[0, 0].shape[0]
    YNonStacked = [[[] for n in range(nNeurons)] for r in range(nTrials)]
    for r in range(nTrials):
        for n in range(nNeurons):
            YNonStacked[r][n] = torch.from_numpy(
                YNonStacked_tmp[r,
                                0][n,
                                   0][:,
                                      0]).contiguous().type(torch.DoubleTensor)

    linkFunction = torch.exp

    # build the kernels named in the fixture with their initial hyperparams
    kernelNames = mat["kernelNames"]
    hprs = mat["hprs0"]
    leasLowerBound = mat['lowerBound'][0, 0]
    kernels = [[None] for k in range(nLatents)]
    kernelsParams0 = [[None] for k in range(nLatents)]
    for k in range(nLatents):
        if np.char.equal(kernelNames[0, k][0], "PeriodicKernel"):
            kernels[k] = stats.kernels.PeriodicKernel(scale=1.0)
            kernelsParams0[k] = torch.tensor(
                [float(hprs[k, 0][0]),
                 float(hprs[k, 0][1])],
                dtype=torch.double)
        elif np.char.equal(kernelNames[0, k][0], "rbfKernel"):
            kernels[k] = stats.kernels.ExponentialQuadraticKernel(scale=1.0)
            kernelsParams0[k] = torch.tensor([float(hprs[k, 0][0])],
                                             dtype=torch.double)
        else:
            raise ValueError("Invalid kernel name: %s" % (kernelNames[k]))

    # assemble the svGPFA lower bound from its components
    qU = stats.svGPFA.svPosteriorOnIndPoints.SVPosteriorOnIndPointsChol()
    indPointsLocsKMS = stats.svGPFA.kernelsMatricesStore.IndPointsLocsKMS_Chol(
    )
    indPointsLocsAndAllTimesKMS = stats.svGPFA.kernelsMatricesStore.IndPointsLocsAndAllTimesKMS(
    )
    indPointsLocsAndAssocTimesKMS = stats.svGPFA.kernelsMatricesStore.IndPointsLocsAndAssocTimesKMS(
    )
    qKAllTimes = stats.svGPFA.svPosteriorOnLatents.SVPosteriorOnLatentsAllTimes(
        svPosteriorOnIndPoints=qU,
        indPointsLocsKMS=indPointsLocsKMS,
        indPointsLocsAndTimesKMS=indPointsLocsAndAllTimesKMS)
    qKAssocTimes = stats.svGPFA.svPosteriorOnLatents.SVPosteriorOnLatentsAssocTimes(
        svPosteriorOnIndPoints=qU,
        indPointsLocsKMS=indPointsLocsKMS,
        indPointsLocsAndTimesKMS=indPointsLocsAndAssocTimesKMS)
    qHAllTimes = stats.svGPFA.svEmbedding.LinearSVEmbeddingAllTimes(
        svPosteriorOnLatents=qKAllTimes)
    qHAssocTimes = stats.svGPFA.svEmbedding.LinearSVEmbeddingAssocTimes(
        svPosteriorOnLatents=qKAssocTimes)
    eLL = stats.svGPFA.expectedLogLikelihood.PointProcessELLExpLink(
        svEmbeddingAllTimes=qHAllTimes, svEmbeddingAssocTimes=qHAssocTimes)
    klDiv = stats.svGPFA.klDivergence.KLDivergence(
        indPointsLocsKMS=indPointsLocsKMS, svPosteriorOnIndPoints=qU)
    svlb = stats.svGPFA.svLowerBound.SVLowerBound(eLL=eLL, klDiv=klDiv)
    svlb.setKernels(kernels=kernels)
    svEM = stats.svGPFA.svEM.SVEM_PyTorch()

    qUParams0 = {"qMu0": qMu0, "srQSigma0Vecs": srQSigma0Vecs}
    kmsParams0 = {"kernelsParams0": kernelsParams0, "inducingPointsLocs0": Z0}
    qKParams0 = {
        "svPosteriorOnIndPoints": qUParams0,
        "kernelsMatricesStore": kmsParams0
    }
    qHParams0 = {"C0": C0, "d0": b0}
    initialParams = {
        "svPosteriorOnLatents": qKParams0,
        "svEmbedding": qHParams0
    }
    eLLCalculationParams = {
        "legQuadPoints": legQuadPoints,
        "legQuadWeights": legQuadWeights
    }

    # short optimization: 4 EM iterations, 20 LBFGS iterations per step
    optimParams = {
        "em_max_iter": 4,
        #
        "estep_estimate": True,
        "estep_optim_params": {
            "max_iter": 20,
            "line_search_fn": "strong_wolfe"
        },
        #
        "mstep_embedding_estimate": True,
        "mstep_embedding_optim_params": {
            "max_iter": 20,
            "line_search_fn": "strong_wolfe"
        },
        #
        "mstep_kernels_estimate": True,
        "mstep_kernels_optim_params": {
            "max_iter": 20,
            "line_search_fn": "strong_wolfe"
        },
        #
        "mstep_indpointslocs_estimate": True,
        "mstep_indpointslocs_optim_params": {
            "max_iter": 20,
            "line_search_fn": "strong_wolfe"
        },
        #
        "verbose": True
    }
    svlb.setInitialParamsAndData(
        measurements=YNonStacked,
        initialParams=initialParams,
        eLLCalculationParams=eLLCalculationParams,
        indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon)
    lowerBoundHist, _, _, _ = svEM.maximize(model=svlb,
                                            optimParams=optimParams)
    # the PyTorch implementation should beat the Matlab lower bound
    assert (lowerBoundHist[-1] > leasLowerBound)
Example #6
0
def main(argv):
    """Estimate an svGPFA model from simulated spikes and plot diagnostics.

    Loads simulation results and estimation initial conditions from config
    files, maximizes the sparse-variational lower bound and then plots the
    lower-bound history, true vs. estimated latents and model parameters,
    CIFs, KS-test time-rescaling analyses, the ISI autocorrelation and an
    ROC predictive analysis.

    Args:
        argv: unused command-line arguments (kept for script-entry parity).
    """
    simResNumber = 96784468
    estInitNumber = 5
    trialToPlot = 0
    neuronToPlot = 18
    dtCIF = 1e-3
    gamma = 10

    # load data and initial values
    simResConfigFilename = "../scripts/results/{:08d}_simulation_metaData.ini".format(simResNumber)
    simResConfig = configparser.ConfigParser()
    simResConfig.read(simResConfigFilename)
    simInitConfigFilename = simResConfig["simulation_params"]["simInitConfigFilename"]
    simInitConfigFilename = "../scripts/" + simInitConfigFilename
    simResFilename = simResConfig["simulation_results"]["simResFilename"]
    simResFilename = "../scripts/" + simResFilename

    simInitConfig = configparser.ConfigParser()
    simInitConfig.read(simInitConfigFilename)
    nLatents = int(simInitConfig["control_variables"]["nLatents"])
    nNeurons = int(simInitConfig["control_variables"]["nNeurons"])
    # "[1.0,2.0,...]" -> [1.0, 2.0, ...]; avoid shadowing the builtin `str`
    trialsLengths = [float(item) for item in simInitConfig["control_variables"]["trialsLengths"][1:-1].split(",")]
    dtSimulate = float(simInitConfig["control_variables"]["dt"])
    nTrials = len(trialsLengths)

    with open(simResFilename, "rb") as f: simRes = pickle.load(f)
    spikesTimes = simRes["spikes"]
    trueLatents = simRes["latents"]
    simCIFsValues = simRes["cifValues"]
    trueLatents = [trueLatents[r][:nLatents,:] for r in range(nTrials)]
    trueLatentsMeans = simRes["latentsMeans"]
    trueLatentsSTDs = simRes["latentsSTDs"]
    trueLatentsSTDs = [trueLatentsSTDs[r][:nLatents,:] for r in range(nTrials)]
    timesTrueValues = torch.linspace(0, torch.max(torch.tensor(trialsLengths)), trueLatents[0].shape[1])

    estInitConfigFilename = "../scripts/data/{:08d}_estimation_metaData.ini".format(estInitNumber)
    estInitConfig = configparser.ConfigParser()
    estInitConfig.read(estInitConfigFilename)
    nIndPointsPerLatent = [int(item) for item in estInitConfig["control_variables"]["nIndPointsPerLatent"][1:-1].split(",")]
    nIndPointsPerLatent = nIndPointsPerLatent[:nLatents]
    nTestPoints = int(estInitConfig["control_variables"]["nTestPoints"])
    firstIndPoint = float(estInitConfig["control_variables"]["firstIndPoint"])
    initCondEmbeddingSTD = float(estInitConfig["control_variables"]["initCondEmbeddingSTD"])
    initCondIndPointsScale = float(estInitConfig["control_variables"]["initCondIndPointsScale"])
    kernelsParams0NoiseSTD = float(estInitConfig["control_variables"]["kernelsParams0NoiseSTD"])
    indPointsLocsKMSRegEpsilon = float(estInitConfig["control_variables"]["indPointsLocsKMSRegEpsilon"])
    nQuad = int(estInitConfig["control_variables"]["nQuad"])

    optimParamsConfig = estInitConfig._sections["optim_params"]
    optimParams = {}
    optimParams["emMaxNIter"] = int(optimParamsConfig["emMaxNIter".lower()])
    #
    optimParams["eStepEstimate"] = optimParamsConfig["eStepEstimate".lower()]=="True"
    optimParams["eStepMaxNIter"] = int(optimParamsConfig["eStepMaxNIter".lower()])
    optimParams["eStepTol"] = float(optimParamsConfig["eStepTol".lower()])
    optimParams["eStepLR"] = float(optimParamsConfig["eStepLR".lower()])
    optimParams["eStepLineSearchFn"] = optimParamsConfig["eStepLineSearchFn".lower()]
    optimParams["eStepNIterDisplay"] = int(optimParamsConfig["eStepNIterDisplay".lower()])
    #
    optimParams["mStepModelParamsEstimate"] = optimParamsConfig["mStepModelParamsEstimate".lower()]=="True"
    optimParams["mStepModelParamsMaxNIter"] = int(optimParamsConfig["mStepModelParamsMaxNIter".lower()])
    optimParams["mStepModelParamsTol"] = float(optimParamsConfig["mStepModelParamsTol".lower()])
    optimParams["mStepModelParamsLR"] = float(optimParamsConfig["mStepModelParamsLR".lower()])
    optimParams["mStepModelParamsLineSearchFn"] = optimParamsConfig["mStepModelParamsLineSearchFn".lower()]
    optimParams["mStepModelParamsNIterDisplay"] = int(optimParamsConfig["mStepModelParamsNIterDisplay".lower()])
    #
    optimParams["mStepKernelParamsEstimate"] = optimParamsConfig["mStepKernelParamsEstimate".lower()]=="True"
    optimParams["mStepKernelParamsMaxNIter"] = int(optimParamsConfig["mStepKernelParamsMaxNIter".lower()])
    optimParams["mStepKernelParamsTol"] = float(optimParamsConfig["mStepKernelParamsTol".lower()])
    optimParams["mStepKernelParamsLR"] = float(optimParamsConfig["mStepKernelParamsLR".lower()])
    optimParams["mStepKernelParamsLineSearchFn"] = optimParamsConfig["mStepKernelParamsLineSearchFn".lower()]
    optimParams["mStepKernelParamsNIterDisplay"] = int(optimParamsConfig["mStepKernelParamsNIterDisplay".lower()])
    #
    # BUG FIX: original used `=` (chained assignment) instead of `==`, which
    # set this flag to the truthy string "True" regardless of the config value
    # and also mutated optimParamsConfig.
    optimParams["mStepIndPointsEstimate"] = optimParamsConfig["mStepIndPointsEstimate".lower()]=="True"
    optimParams["mStepIndPointsMaxNIter"] = int(optimParamsConfig["mStepIndPointsMaxNIter".lower()])
    optimParams["mStepIndPointsTol"] = float(optimParamsConfig["mStepIndPointsTol".lower()])
    optimParams["mStepIndPointsLR"] = float(optimParamsConfig["mStepIndPointsLR".lower()])
    optimParams["mStepIndPointsLineSearchFn"] = optimParamsConfig["mStepIndPointsLineSearchFn".lower()]
    optimParams["mStepIndPointsNIterDisplay"] = int(optimParamsConfig["mStepIndPointsNIterDisplay".lower()])
    #
    optimParams["verbose"] = optimParamsConfig["verbose"]=="True"

    testTimes = torch.linspace(0, torch.max(torch.tensor(spikesTimes[0][0])), nTestPoints)

    CFilename = "../scripts/" + simInitConfig["embedding_params"]["C_filename"]
    dFilename = "../scripts/" + simInitConfig["embedding_params"]["d_filename"]
    C, d = utils.svGPFA.configUtils.getLinearEmbeddingParams(nNeurons=nNeurons, nLatents=nLatents, CFilename=CFilename, dFilename=dFilename)
    # perturb the true embedding parameters to get initial conditions
    C0 = C + torch.randn(C.shape)*initCondEmbeddingSTD
    C0 = C0[:,:nLatents]
    d0 = d + torch.randn(d.shape)*initCondEmbeddingSTD

    legQuadPoints, legQuadWeights = demoUtils.getLegQuadPointsAndWeights(nQuad=nQuad, trialsLengths=trialsLengths)

    kernels = utils.svGPFA.configUtils.getKernels(nLatents=nLatents, nTrials=nTrials, config=simInitConfig)
    kernelsParams0 = demoUtils.getKernelsParams0(kernels=kernels, noiseSTD=kernelsParams0NoiseSTD)
    kernels = kernels[0] # the current code uses the same kernels for all trials
    kernelsParams0 = kernelsParams0[0] # the current code uses the same kernels for all trials

    qMu0, qSVec0, qSDiag0 = demoUtils.getSVPosteriorOnIndPointsParams0(nIndPointsPerLatent=nIndPointsPerLatent, nLatents=nLatents, nTrials=nTrials, scale=initCondIndPointsScale)

    Z0 = demoUtils.getIndPointLocs0(nIndPointsPerLatent=nIndPointsPerLatent,
                          trialsLengths=trialsLengths, firstIndPoint=firstIndPoint)

    qUParams0 = {"qMu0": qMu0, "qSVec0": qSVec0, "qSDiag0": qSDiag0}
    qHParams0 = {"C0": C0, "d0": d0}
    kmsParams0 = {"kernelsParams0": kernelsParams0,
                  "inducingPointsLocs0": Z0}
    initialParams = {"svPosteriorOnIndPoints": qUParams0,
                     "kernelsMatricesStore": kmsParams0,
                     "svEmbedding": qHParams0}
    quadParams = {"legQuadPoints": legQuadPoints,
                  "legQuadWeights": legQuadWeights}

    # create model
    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels,
        indPointsLocsKMSEpsilon=indPointsLocsKMSRegEpsilon)

    # maximize lower bound
    svEM = stats.svGPFA.svEM.SVEM()
    lowerBoundHist, elapsedTimeHist  = svEM.maximize(
        model=model, measurements=spikesTimes, initialParams=initialParams,
        quadParams=quadParams, optimParams=optimParams,
        plotLatentsEstimates=False)

    # plot lower bound history
    plot.svGPFA.plotUtils.plotLowerBoundHist(lowerBoundHist=lowerBoundHist)

    # plot true and estimated latents
    testMuK, testVarK = model.predictLatents(newTimes=testTimes)
    indPointsLocs = model.getIndPointsLocs()
    plot.svGPFA.plotUtils.plotTrueAndEstimatedLatents(timesEstimatedValues=testTimes, muK=testMuK, varK=testVarK, indPointsLocs=indPointsLocs, timesTrueValues=timesTrueValues, trueLatents=trueLatents, trueLatentsMeans=trueLatentsMeans, trueLatentsSTDs=trueLatentsSTDs, trialToPlot=trialToPlot)

    # plot model params
    tLatentsMeansFuncs = utils.svGPFA.configUtils.getLatentsMeansFuncs(nLatents=nLatents, nTrials=nTrials, config=simInitConfig)
    trialsTimes = utils.svGPFA.miscUtils.getTrialsTimes(trialsLengths=trialsLengths, dt=dtSimulate)
    tLatentsMeans = utils.svGPFA.miscUtils.getLatentsMeanFuncsSamples(latentsMeansFuncs=tLatentsMeansFuncs, trialsTimes=trialsTimes, dtype=C.dtype)
    kernelsParams = model.getKernelsParams()
    with torch.no_grad(): latentsMeans, _ = model.predictLatents(newTimes=trialsTimes[0])
    with torch.no_grad(): estimatedC, estimatedD = model.getSVEmbeddingParams()

    plot.svGPFA.plotUtils.plotTrueAndEstimatedKernelsParams(trueKernels=kernels, estimatedKernelsParams=kernelsParams)

    plot.svGPFA.plotUtils.plotTrueAndEstimatedLatentsMeans(trueLatentsMeans=tLatentsMeans, estimatedLatentsMeans=latentsMeans, trialsTimes=trialsTimes)

    plot.svGPFA.plotUtils.plotTrueAndEstimatedEmbeddingParams(trueC=C, trueD=d, estimatedC=estimatedC, estimatedD=estimatedD)

    title = "Trial {:d}, Neuron {:d}".format(trialToPlot, neuronToPlot)
    # CIF
    T = torch.tensor(trialsLengths).max().item()
    oneTrialCIFTimes = torch.arange(0, T, dtCIF)
    cifTimes = torch.unsqueeze(torch.ger(torch.ones(nTrials), oneTrialCIFTimes), dim=2)
    with torch.no_grad(): cifValues = model.computeMeanCIFs(times=cifTimes)
    plot.svGPFA.plotUtils.plotSimulatedAndEstimatedCIFs(times=cifTimes[trialToPlot, :, 0], simCIFValues=simCIFsValues[trialToPlot][neuronToPlot], estCIFValues=cifValues[trialToPlot][neuronToPlot], title=title)

    # KS test time rescaling with numerical correction
    spikesTimesKS = spikesTimes[trialToPlot][neuronToPlot]
    cifTimesKS = cifTimes[trialToPlot,:,0]
    cifValuesKS = cifValues[trialToPlot][neuronToPlot]

    title = "Trial {:d}, Neuron {:d} ({:d} spikes)".format(trialToPlot, neuronToPlot, len(spikesTimesKS))

    diffECDFsX, diffECDFsY, estECDFx, estECDFy, simECDFx, simECDFy, cb = stats.pointProcess.tests.KSTestTimeRescalingNumericalCorrection(spikesTimes=spikesTimesKS, cifTimes=cifTimesKS, cifValues=cifValuesKS, gamma=gamma)
    plot.svGPFA.plotUtils.plotResKSTestTimeRescalingNumericalCorrection(diffECDFsX=diffECDFsX, diffECDFsY=diffECDFsY, estECDFx=estECDFx, estECDFy=estECDFy, simECDFx=simECDFx, simECDFy=simECDFy, cb=cb, title=title)

    # KS test time rescaling with analytical correction
    t0 = math.floor(cifTimesKS.min())
    tf = math.ceil(cifTimesKS.max())
    dt = (cifTimesKS[1]-cifTimesKS[0]).item()
    utSRISIs, uCDF, cb, utRISIs = stats.pointProcess.tests.KSTestTimeRescalingAnalyticalCorrectionUnbinned(spikesTimes=spikesTimesKS, cifValues=cifValuesKS, t0=t0, tf=tf, dt=dt)
    sUTRISIs, _ = torch.sort(utSRISIs)

    plot.svGPFA.plotUtils.plotResKSTestTimeRescalingAnalyticalCorrection(sUTRISIs=sUTRISIs, uCDF=uCDF, cb=cb, title=title)

    plot.svGPFA.plotUtils.plotDifferenceCDFs(sUTRISIs=sUTRISIs, uCDF=uCDF, cb=cb)

    plot.svGPFA.plotUtils.plotScatter1Lag(x=utRISIs, title=title)

    acfRes, confint = statsmodels.tsa.stattools.acf(x=utRISIs, unbiased=True, alpha=0.05)
    plot.svGPFA.plotUtils.plotACF(acf=acfRes, Fs=1/dt, confint=confint, title=title)

    # ROC predictive analysis
    pk = cifValuesKS*dtCIF
    bins = pd.interval_range(start=0, end=T, periods=len(pk))
    cutRes, _ = pd.cut(spikesTimesKS, bins=bins, retbins=True)
    Y = torch.from_numpy(cutRes.value_counts().values)
    fpr, tpr, thresholds = sklearn.metrics.roc_curve(Y, pk, pos_label=1)
    roc_auc = sklearn.metrics.auc(fpr, tpr)
    plot.svGPFA.plotUtils.plotResROCAnalysis(fpr=fpr, tpr=tpr, auc=roc_auc, title=title)

    pdb.set_trace()
Example #7
0
def main(argv):
    """Estimate an svGPFA model for one location of a Shenoy-lab dataset.

    Parses command-line arguments, loads and preprocesses spike times
    (clipping to [from_time, to_time] and removing bad units), reads
    estimation initial conditions from a config file, builds a SciPy-based
    svGPFA model, maximizes the sparse-variational lower bound and saves
    the estimation metadata and results.

    Args:
        argv: unused command-line arguments (argparse reads sys.argv).
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("estInitNumber",
                        help="estimation init number",
                        type=int)
    parser.add_argument("--savePartial",
                        help="save partial model estimates",
                        action="store_true")
    parser.add_argument("--location",
                        help="location to analyze",
                        type=int,
                        default=0)
    parser.add_argument("--trials",
                        help="trials to analyze",
                        default="[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]")
    parser.add_argument("--nLatents",
                        help="number of latent variables",
                        type=int,
                        default=2)
    parser.add_argument("--from_time",
                        help="starting spike analysis time",
                        type=float,
                        default=0.750)
    parser.add_argument("--to_time",
                        help="ending spike analysis time",
                        type=float,
                        default=2.250)
    parser.add_argument("--unitsToRemove",
                        help="units to remove",
                        default="[39,57,73]")
    parser.add_argument("--save_partial_filename_pattern_pattern",
                        help="pattern for save partial model filename pattern",
                        default="results/{:08d}_{{:s}}_estimatedModel.pickle")
    parser.add_argument(
        "--data_filename",
        help="data filename",
        default="~/dev/research/gatsby-swc/datasets/george20040123_hnlds.mat")
    args = parser.parse_args()

    estInitNumber = args.estInitNumber
    save_partial = args.savePartial
    location = args.location
    # "[0,1,...]" -> [0, 1, ...]; avoid shadowing the builtin `str`
    trials = [int(item) for item in args.trials[1:-1].split(",")]
    nLatents = args.nLatents
    from_time = args.from_time
    to_time = args.to_time
    units_to_remove = [int(item) for item in args.unitsToRemove[1:-1].split(",")]
    save_partial_filename_pattern_pattern = args.save_partial_filename_pattern_pattern
    data_filename = args.data_filename

    mat = scipy.io.loadmat(os.path.expanduser(data_filename))
    spikesTimes = shenoyUtils.getTrialsAndLocationSpikesTimes(
        mat=mat, trials=trials, location=location)
    spikesTimes = shenoyUtils.clipSpikesTimes(spikes_times=spikesTimes,
                                              from_time=from_time,
                                              to_time=to_time)
    spikesTimes = shenoyUtils.removeUnits(spikes_times=spikesTimes,
                                          units_to_remove=units_to_remove)

    estInitConfigFilename = "data/{:08d}_estimation_metaData.ini".format(
        estInitNumber)
    estInitConfig = configparser.ConfigParser()
    estInitConfig.read(estInitConfigFilename)
    nQuad = int(estInitConfig["control_variables"]["nQuad"])
    kernelMatrixInvMethodStr = estInitConfig["control_variables"][
        "kernelMatrixInvMethod"]
    indPointsCovRepStr = estInitConfig["control_variables"]["indPointsCovRep"]
    if kernelMatrixInvMethodStr == "Chol":
        kernelMatrixInvMethod = stats.svGPFA.svGPFAModelFactory.kernelMatrixInvChol
    elif kernelMatrixInvMethodStr == "PInv":
        kernelMatrixInvMethod = stats.svGPFA.svGPFAModelFactory.kernelMatrixInvPInv
    else:
        raise RuntimeError("Invalid kernelMatrixInvMethod={:s}".format(
            kernelMatrixInvMethodStr))
    if indPointsCovRepStr == "Chol":
        indPointsCovRep = stats.svGPFA.svGPFAModelFactory.indPointsCovChol
    elif indPointsCovRepStr == "Rank1PlusDiag":
        indPointsCovRep = stats.svGPFA.svGPFAModelFactory.indPointsCovRank1PlusDiag
    else:
        raise RuntimeError(
            "Invalid indPointsCovRep={:s}".format(indPointsCovRepStr))
    indPointsLocsKMSRegEpsilon = float(
        estInitConfig["control_variables"]["indPointsLocsKMSRegEpsilon"])

    optimParamsConfig = estInitConfig._sections["optim_params"]
    optimMethod = optimParamsConfig["em_method"]
    # BUG FIX: the original passed the undefined name `optimParamsDict`
    # (NameError); pass the parsed optim_params section instead.
    optimParams = utils.svGPFA.miscUtils.getOptimParams(
        optimParamsDict=optimParamsConfig)

    nNeurons = len(spikesTimes[0])
    nTrials = len(trials)
    trialsLengths = [to_time - from_time for i in range(nTrials)]

    randomEmbedding = estInitConfig["control_variables"][
        "randomEmbedding"].lower() == "true"
    if randomEmbedding:
        C0 = torch.rand(nNeurons, nLatents, dtype=torch.double).contiguous()
        d0 = torch.rand(nNeurons, 1, dtype=torch.double).contiguous()
    else:
        CFilename = estInitConfig["embedding_params"]["C_filename"]
        dFilename = estInitConfig["embedding_params"]["d_filename"]
        C, d = utils.svGPFA.configUtils.getLinearEmbeddingParams(
            CFilename=CFilename, dFilename=dFilename)
        initCondEmbeddingSTD = float(
            estInitConfig["control_variables"]["initCondEmbeddingSTD"])
        # perturb the loaded embedding parameters to get initial conditions
        C0 = (C + torch.randn(C.shape) * initCondEmbeddingSTD).contiguous()
        d0 = (d + torch.randn(d.shape) * initCondEmbeddingSTD).contiguous()

    legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(
        nQuad=nQuad, trialsLengths=trialsLengths)

    kernels = utils.svGPFA.configUtils.getKernels(nLatents=nLatents,
                                                  config=estInitConfig,
                                                  forceUnitScale=True)
    kernelsScaledParams0 = utils.svGPFA.initUtils.getKernelsScaledParams0(
        kernels=kernels, noiseSTD=0.0)
    Z0 = utils.svGPFA.configUtils.getIndPointsLocs0(nLatents=nLatents,
                                                    nTrials=nTrials,
                                                    config=estInitConfig)
    nIndPointsPerLatent = [Z0[k].shape[1] for k in range(nLatents)]

    qMu0 = utils.svGPFA.configUtils.getVariationalMean0(nLatents=nLatents,
                                                        nTrials=nTrials,
                                                        config=estInitConfig)

    qSigma0 = utils.svGPFA.configUtils.getVariationalCov0(nLatents=nLatents,
                                                          nTrials=nTrials,
                                                          config=estInitConfig)
    srQSigma0Vecs = utils.svGPFA.initUtils.getSRQSigmaVecsFromSRMatrices(
        srMatrices=qSigma0)
    qSVec0, qSDiag0 = utils.svGPFA.miscUtils.getQSVecsAndQSDiagsFromQSRSigmaVecs(
        srQSigmaVecs=srQSigma0Vecs)

    # choose the inducing-points covariance parameterization matching the
    # requested representation
    if indPointsCovRep == stats.svGPFA.svGPFAModelFactory.indPointsCovChol:
        qUParams0 = {"qMu0": qMu0, "srQSigma0Vecs": srQSigma0Vecs}
    elif indPointsCovRep == stats.svGPFA.svGPFAModelFactory.indPointsCovRank1PlusDiag:
        qUParams0 = {"qMu0": qMu0, "qSVec0": qSVec0, "qSDiag0": qSDiag0}
    else:
        raise RuntimeError("Invalid indPointsCovRep")

    kmsParams0 = {
        "kernelsParams0": kernelsScaledParams0,
        "inducingPointsLocs0": Z0
    }
    qKParams0 = {
        "svPosteriorOnIndPoints": qUParams0,
        "kernelsMatricesStore": kmsParams0
    }
    qHParams0 = {"C0": C0, "d0": d0}
    initialParams = {
        "svPosteriorOnLatents": qKParams0,
        "svEmbedding": qHParams0
    }
    quadParams = {
        "legQuadPoints": legQuadPoints,
        "legQuadWeights": legQuadWeights
    }

    # draw a fresh estimation-result number not yet used on disk
    estPrefixUsed = True
    while estPrefixUsed:
        estResNumber = random.randint(0, 10**8)
        estimResMetaDataFilename = "results/{:08d}_estimation_metaData.ini".format(
            estResNumber)
        if not os.path.exists(estimResMetaDataFilename):
            estPrefixUsed = False
    modelSaveFilename = "results/{:08d}_estimatedModel.pickle".format(
        estResNumber)
    save_partial_filename_pattern = save_partial_filename_pattern_pattern.format(
        estResNumber)

    kernelsTypes = [type(kernels[k]).__name__ for k in range(len(kernels))]
    estimationDataForMatlabFilename = "results/{:08d}_estimationDataForMatlab.mat".format(
        estResNumber)

    dt_latents = 0.01
    oneSetLatentsTrialTimes = torch.arange(from_time, to_time, dt_latents)
    latentsTrialsTimes = [oneSetLatentsTrialTimes for k in range(nLatents)]
    utils.svGPFA.miscUtils.saveDataForMatlabEstimations(
        qMu=qMu0,
        qSVec=qSVec0,
        qSDiag=qSDiag0,
        C=C0,
        d=d0,
        indPointsLocs=Z0,
        legQuadPoints=legQuadPoints,
        legQuadWeights=legQuadWeights,
        kernelsTypes=kernelsTypes,
        kernelsParams=kernelsScaledParams0,
        spikesTimes=spikesTimes,
        indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon,
        trialsLengths=torch.tensor(trialsLengths).reshape(-1, 1),
        latentsTrialsTimes=latentsTrialsTimes,
        emMaxIter=optimParams["em_max_iter"],
        eStepMaxIter=optimParams["estep_optim_params"]["maxiter"],
        mStepEmbeddingMaxIter=optimParams["mstep_embedding_optim_params"]
        ["maxiter"],
        mStepKernelsMaxIter=optimParams["mstep_kernels_optim_params"]
        ["maxiter"],
        mStepIndPointsMaxIter=optimParams["mstep_indpointslocs_optim_params"]
        ["maxiter"],
        saveFilename=estimationDataForMatlabFilename)

    def getKernelParams(model):
        # track the first latent's kernel parameters across EM iterations
        kernelParams = model.getKernelsParams()[0]
        return kernelParams

    # create model
    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModelSciPy(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels,
        kernelMatrixInvMethod=kernelMatrixInvMethod,
        indPointsCovRep=indPointsCovRep)

    model.setInitialParamsAndData(
        measurements=spikesTimes,
        initialParams=initialParams,
        eLLCalculationParams=quadParams,
        indPointsLocsKMSRegEpsilon=indPointsLocsKMSRegEpsilon)

    # save estimation metadata
    estimResConfig = configparser.ConfigParser()
    estimResConfig["data_params"] = {
        "data_filename": data_filename,
        "location": location,
        "trials": trials,
        "nLatents": nLatents,
        "from_time": from_time,
        "to_time": to_time
    }
    estimResConfig["optim_params"] = optimParams
    estimResConfig["estimation_params"] = {
        "estInitNumber": estInitNumber,
        "nIndPointsPerLatent": nIndPointsPerLatent
    }
    with open(estimResMetaDataFilename, "w") as f:
        estimResConfig.write(f)

    # maximize lower bound
    svEM = stats.svGPFA.svEM.SVEM_scipy()
    lowerBoundHist, elapsedTimeHist, terminationInfo, iterationsModelParams = \
            svEM.maximize(model=model, optimParams=optimParams,
                          method=optimMethod,
                          getIterationModelParamsFn=getKernelParams,
                          savePartial=save_partial,
                          savePartialFilenamePattern=save_partial_filename_pattern)

    resultsToSave = {
        "lowerBoundHist": lowerBoundHist,
        "elapsedTimeHist": elapsedTimeHist,
        "terminationInfo": terminationInfo,
        "iterationModelParams": iterationsModelParams,
        "model": model
    }
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)
    print("Saved results to {:s}".format(modelSaveFilename))

    pdb.set_trace()
def main(argv):
    """Estimate an svGPFA model from a saved simulation and save the results.

    Loads spikes from a simulation-result pickle, reads estimation initial
    conditions from a config file, builds the model, maximizes the
    sparse-variational lower bound and saves the estimation metadata and
    estimated model.

    Args:
        argv: unused command-line arguments (argparse reads sys.argv).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("simResNumber",
                        help="Simuluation result number",
                        type=int)
    parser.add_argument("estInitNumber",
                        help="estimation init number",
                        type=int)
    parser.add_argument("--nLatents",
                        help="number of latents to use in the estimation",
                        type=int,
                        default=-1)
    args = parser.parse_args()

    simResNumber = args.simResNumber
    estInitNumber = args.estInitNumber
    nLatents = args.nLatents

    # load data and initial values
    simResConfigFilename = "results/{:08d}_simulation_metaData.ini".format(
        simResNumber)
    simResConfig = configparser.ConfigParser()
    simResConfig.read(simResConfigFilename)
    simInitConfigFilename = simResConfig["simulation_params"][
        "simInitConfigFilename"]
    simResFilename = simResConfig["simulation_results"]["simResFilename"]

    simInitConfig = configparser.ConfigParser()
    simInitConfig.read(simInitConfigFilename)
    # a negative --nLatents means: use the number of latents of the simulation
    if nLatents < 0:
        nLatents = int(simInitConfig["control_variables"]["nLatents"])
    nNeurons = int(simInitConfig["control_variables"]["nNeurons"])
    # "[1.0,2.0,...]" -> [1.0, 2.0, ...]; avoid shadowing the builtin `str`
    trialsLengths = [
        float(item) for item in simInitConfig["control_variables"]
        ["trialsLengths"][1:-1].split(",")
    ]
    nTrials = len(trialsLengths)

    with open(simResFilename, "rb") as f:
        simRes = pickle.load(f)
    spikesTimes = simRes["spikes"]

    estInitConfigFilename = "data/{:08d}_estimation_metaData.ini".format(
        estInitNumber)
    estInitConfig = configparser.ConfigParser()
    estInitConfig.read(estInitConfigFilename)
    nIndPointsPerLatent = [
        int(item) for item in estInitConfig["control_variables"]
        ["nIndPointsPerLatent"][1:-1].split(",")
    ]
    nIndPointsPerLatent = nIndPointsPerLatent[:nLatents]
    nTestPoints = int(estInitConfig["control_variables"]["nTestPoints"])
    firstIndPointLoc = float(
        estInitConfig["control_variables"]["firstIndPointLoc"])
    randomEmbedding = estInitConfig["control_variables"][
        "randomEmbedding"].lower() == "true"
    initCondIndPointsScale = float(
        estInitConfig["control_variables"]["initCondIndPointsScale"])
    indPointsLocsKMSEpsilon = float(
        estInitConfig["control_variables"]["indPointsLocsKMSRegEpsilon"])
    nQuad = int(estInitConfig["control_variables"]["nQuad"])

    optimParamsConfig = estInitConfig._sections["optim_params"]
    optimParams = {}
    optimParams["emMaxIter"] = int(optimParamsConfig["emMaxIter".lower()])
    #
    optimParams["eStepEstimate"] = optimParamsConfig[
        "eStepEstimate".lower()] == "True"
    optimParams["eStepMaxIter"] = int(
        optimParamsConfig["eStepMaxIter".lower()])
    optimParams["eStepTol"] = float(optimParamsConfig["eStepTol".lower()])
    optimParams["eStepLR"] = float(optimParamsConfig["eStepLR".lower()])
    optimParams["eStepLineSearchFn"] = optimParamsConfig[
        "eStepLineSearchFn".lower()]
    optimParams["eStepNIterDisplay"] = int(
        optimParamsConfig["eStepNIterDisplay".lower()])
    #
    optimParams["mStepEmbeddingEstimate"] = optimParamsConfig[
        "mStepEmbeddingEstimate".lower()] == "True"
    optimParams["mStepEmbeddingMaxIter"] = int(
        optimParamsConfig["mStepEmbeddingMaxIter".lower()])
    optimParams["mStepEmbeddingTol"] = float(
        optimParamsConfig["mStepEmbeddingTol".lower()])
    optimParams["mStepEmbeddingLR"] = float(
        optimParamsConfig["mStepEmbeddingLR".lower()])
    optimParams["mStepEmbeddingLineSearchFn"] = optimParamsConfig[
        "mStepEmbeddingLineSearchFn".lower()]
    optimParams["mStepEmbeddingNIterDisplay"] = int(
        optimParamsConfig["mStepEmbeddingNIterDisplay".lower()])
    #
    optimParams["mStepKernelsEstimate"] = optimParamsConfig[
        "mStepKernelsEstimate".lower()] == "True"
    optimParams["mStepKernelsMaxIter"] = int(
        optimParamsConfig["mStepKernelsMaxIter".lower()])
    optimParams["mStepKernelsTol"] = float(
        optimParamsConfig["mStepKernelsTol".lower()])
    optimParams["mStepKernelsLR"] = float(
        optimParamsConfig["mStepKernelsLR".lower()])
    optimParams["mStepKernelsLineSearchFn"] = optimParamsConfig[
        "mStepKernelsLineSearchFn".lower()]
    optimParams["mStepKernelsNIterDisplay"] = int(
        optimParamsConfig["mStepKernelsNIterDisplay".lower()])
    #
    optimParams["mStepIndPointsEstimate"] = optimParamsConfig[
        "mStepIndPointsEstimate".lower()] == "True"
    optimParams["mStepIndPointsMaxIter"] = int(
        optimParamsConfig["mStepIndPointsMaxIter".lower()])
    optimParams["mStepIndPointsTol"] = float(
        optimParamsConfig["mStepIndPointsTol".lower()])
    optimParams["mStepIndPointsLR"] = float(
        optimParamsConfig["mStepIndPointsLR".lower()])
    optimParams["mStepIndPointsLineSearchFn"] = optimParamsConfig[
        "mStepIndPointsLineSearchFn".lower()]
    optimParams["mStepIndPointsNIterDisplay"] = int(
        optimParamsConfig["mStepIndPointsNIterDisplay".lower()])
    #
    optimParams["verbose"] = optimParamsConfig["verbose"] == "True"

    if randomEmbedding:
        # BUG FIX: original wrote `torch.rand(...) - 0.5 * 2`, which equals
        # rand - 1.0 (range [-1, 0)); the intent is uniform in [-1, 1).
        C0 = (torch.rand(nNeurons, nLatents, dtype=torch.double) - 0.5) * 2
        d0 = (torch.rand(nNeurons, 1, dtype=torch.double) - 0.5) * 2
    else:
        CFilename = simInitConfig["embedding_params"]["C_filename"]
        dFilename = simInitConfig["embedding_params"]["d_filename"]
        C, d = utils.svGPFA.configUtils.getLinearEmbeddingParams(
            nNeurons=nNeurons,
            nLatents=nLatents,
            CFilename=CFilename,
            dFilename=dFilename)
        # BUG FIX: initCondEmbeddingSTD was used here but never defined
        # (NameError); read it from the estimation config, as the sibling
        # estimation scripts do.
        initCondEmbeddingSTD = float(
            estInitConfig["control_variables"]["initCondEmbeddingSTD"])
        C0 = C + torch.randn(C.shape) * initCondEmbeddingSTD
        C0 = C0[:, :nLatents]
        d0 = d + torch.randn(d.shape) * initCondEmbeddingSTD

    legQuadPoints, legQuadWeights = utils.svGPFA.miscUtils.getLegQuadPointsAndWeights(
        nQuad=nQuad, trialsLengths=trialsLengths)

    kernels = utils.svGPFA.configUtils.getKernels(nLatents=nLatents,
                                                  config=estInitConfig)
    kernelsParams0 = utils.svGPFA.initUtils.getKernelsParams0(kernels=kernels,
                                                              noiseSTD=0.0)

    qMu0, qSVec0, qSDiag0 = utils.svGPFA.initUtils.getSVPosteriorOnIndPointsParams0(
        nIndPointsPerLatent=nIndPointsPerLatent,
        nLatents=nLatents,
        nTrials=nTrials,
        scale=initCondIndPointsScale)

    Z0 = utils.svGPFA.initUtils.getIndPointLocs0(
        nIndPointsPerLatent=nIndPointsPerLatent,
        trialsLengths=trialsLengths,
        firstIndPointLoc=firstIndPointLoc)
    qUParams0 = {"qMu0": qMu0, "qSVec0": qSVec0, "qSDiag0": qSDiag0}
    kmsParams0 = {"kernelsParams0": kernelsParams0, "inducingPointsLocs0": Z0}
    qKParams0 = {
        "svPosteriorOnIndPoints": qUParams0,
        "kernelsMatricesStore": kmsParams0
    }
    qHParams0 = {"C0": C0, "d0": d0}
    initialParams = {
        "svPosteriorOnLatents": qKParams0,
        "svEmbedding": qHParams0
    }
    quadParams = {
        "legQuadPoints": legQuadPoints,
        "legQuadWeights": legQuadWeights
    }

    # draw a fresh estimation-result number not yet used on disk
    estPrefixUsed = True
    while estPrefixUsed:
        estResNumber = random.randint(0, 10**8)
        estimResMetaDataFilename = "results/{:08d}_estimation_metaData.ini".format(
            estResNumber)
        if not os.path.exists(estimResMetaDataFilename):
            estPrefixUsed = False
    modelSaveFilename = "results/{:08d}_estimatedModel.pickle".format(
        estResNumber)

    kernelsTypes = [type(kernels[k]).__name__ for k in range(len(kernels))]
    estimationDataForMatlabFilename = "results/{:08d}_estimationDataForMatlab.mat".format(
        estResNumber)
    utils.svGPFA.miscUtils.saveDataForMatlabEstimations(
        qMu0=qMu0,
        qSVec0=qSVec0,
        qSDiag0=qSDiag0,
        C0=C0,
        d0=d0,
        indPointsLocs0=Z0,
        legQuadPoints=legQuadPoints,
        legQuadWeights=legQuadWeights,
        kernelsTypes=kernelsTypes,
        kernelsParams0=kernelsParams0,
        spikesTimes=spikesTimes,
        indPointsLocsKMSEpsilon=indPointsLocsKMSEpsilon,
        trialsLengths=np.array(trialsLengths).reshape(-1, 1),
        emMaxIter=optimParams["emMaxIter"],
        eStepMaxIter=optimParams["eStepMaxIter"],
        mStepEmbeddingMaxIter=optimParams["mStepEmbeddingMaxIter"],
        mStepKernelsMaxIter=optimParams["mStepKernelsMaxIter"],
        mStepIndPointsMaxIter=optimParams["mStepIndPointsMaxIter"],
        saveFilename=estimationDataForMatlabFilename)

    # create model
    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels)

    # maximize lower bound
    svEM = stats.svGPFA.svEM.SVEM()
    lowerBoundHist, elapsedTimeHist = svEM.maximize(
        model=model,
        measurements=spikesTimes,
        initialParams=initialParams,
        quadParams=quadParams,
        optimParams=optimParams,
        indPointsLocsKMSEpsilon=indPointsLocsKMSEpsilon)

    # save estimated values
    estimResConfig = configparser.ConfigParser()
    estimResConfig["simulation_params"] = {"simResNumber": simResNumber}
    estimResConfig["optim_params"] = optimParams
    estimResConfig["estimation_params"] = {
        "estInitNumber": estInitNumber,
        "nIndPointsPerLatent": nIndPointsPerLatent
    }
    with open(estimResMetaDataFilename, "w") as f:
        estimResConfig.write(f)

    resultsToSave = {
        "lowerBoundHist": lowerBoundHist,
        "elapsedTimeHist": elapsedTimeHist,
        "model": model
    }
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)

    pdb.set_trace()
def main(argv):
    """Estimate a point-process svGPFA model on simulated data (cpu).

    Loads initial conditions and simulated spike times from .mat files,
    builds a point-process model with a linear embedding and exponential
    link, maximizes its lower bound with stochastic variational EM,
    pickles the results and plots the lower-bound history.

    :param argv: command-line arguments (only ``--profile`` is recognized)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--profile",
                        help="perform profiling",
                        action="store_true")
    args = parser.parse_args()
    profile = args.profile  # store_true already yields a bool

    ppSimulationFilename = os.path.join(os.path.dirname(__file__),
                                        "data/pointProcessSimulation.mat")
    initDataFilename = os.path.join(os.path.dirname(__file__),
                                    "data/pointProcessInitialConditions.mat")
    lowerBoundHistFigFilename = "figures/leasLowerBoundHist_{:s}.png".format(
        "cpu")
    modelSaveFilename = "results/estimationResLeasSimulation_{:s}.pickle".format(
        "cpu")
    profilerFilenamePattern = "results/demoPointProcessLeasSimulation_{:d}Iter.pstats"

    mat = loadmat(initDataFilename)
    nLatents = len(mat['Z0'])
    nTrials = mat['Z0'][0, 0].shape[2]
    # initial variational parameters; .permute(2, 0, 1) moves the trial
    # axis first (Matlab stores trials last)
    qMu0 = [
        torch.from_numpy(mat['q_mu0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1) for i in range(nLatents)
    ]
    qSVec0 = [
        torch.from_numpy(mat['q_sqrt0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1) for i in range(nLatents)
    ]
    qSDiag0 = [
        torch.from_numpy(mat['q_diag0'][(0, i)]).type(
            torch.DoubleTensor).permute(2, 0, 1) for i in range(nLatents)
    ]
    Z0 = [
        torch.from_numpy(mat['Z0'][(i, 0)]).type(torch.DoubleTensor).permute(
            2, 0, 1) for i in range(nLatents)
    ]
    C0 = torch.from_numpy(mat["C0"]).type(torch.DoubleTensor)
    b0 = torch.from_numpy(mat["b0"]).type(torch.DoubleTensor).squeeze()
    legQuadPoints = torch.from_numpy(mat['ttQuad']).type(
        torch.DoubleTensor).permute(2, 0, 1)
    legQuadWeights = torch.from_numpy(mat['wwQuad']).type(
        torch.DoubleTensor).permute(2, 0, 1)

    # spike times: one tensor per (trial, neuron)
    yMat = loadmat(ppSimulationFilename)
    YNonStacked_tmp = yMat['Y']
    nNeurons = YNonStacked_tmp[0, 0].shape[0]
    YNonStacked = [[[] for n in range(nNeurons)] for r in range(nTrials)]
    for r in range(nTrials):
        for n in range(nNeurons):
            YNonStacked[r][n] = torch.from_numpy(
                YNonStacked_tmp[r, 0][n, 0][:, 0]).type(torch.DoubleTensor)

    kernelNames = mat["kernelNames"]
    hprs0 = mat["hprs0"]
    indPointsLocsKMSEpsilon = 1e-4  # regularizer for inducing-point kernel matrices

    # create kernels (kernelNames comes from loadmat as a (1, nLatents) cell array)
    kernels = [None] * nLatents
    for k in range(nLatents):
        if np.char.equal(kernelNames[0, k][0], "PeriodicKernel"):
            kernels[k] = stats.kernels.PeriodicKernel()
        elif np.char.equal(kernelNames[0, k][0], "rbfKernel"):
            kernels[k] = stats.kernels.ExponentialQuadraticKernel()
        else:
            # index the cell array, not the outer array, so the error
            # message itself cannot raise an IndexError
            raise ValueError("Invalid kernel name: %s" % (kernelNames[0, k][0]))

    # create initial kernel parameters; leading 1.0 is the scale
    kernelsParams0 = [None] * nLatents
    for k in range(nLatents):
        if np.char.equal(kernelNames[0, k][0], "PeriodicKernel"):
            kernelsParams0[k] = torch.tensor(
                [1.0, float(hprs0[k, 0][0]),
                 float(hprs0[k, 0][1])],
                dtype=torch.double)
        elif np.char.equal(kernelNames[0, k][0], "rbfKernel"):
            kernelsParams0[k] = torch.tensor([1.0, float(hprs0[k, 0][0])],
                                             dtype=torch.double)
        else:
            raise ValueError("Invalid kernel name: %s" % (kernelNames[0, k][0]))

    qUParams0 = {"qMu0": qMu0, "qSVec0": qSVec0, "qSDiag0": qSDiag0}
    qHParams0 = {"C0": C0, "d0": b0}
    kmsParams0 = {"kernelsParams0": kernelsParams0, "inducingPointsLocs0": Z0}
    initialParams = {
        "svPosteriorOnIndPoints": qUParams0,
        "kernelsMatricesStore": kmsParams0,
        "svEmbedding": qHParams0
    }
    quadParams = {
        "legQuadPoints": legQuadPoints,
        "legQuadWeights": legQuadWeights
    }
    optimParams = {
        "emMaxNIter": 5,
        "eStepMaxNIter": 100,
        "mStepModelParamsMaxNIter": 100,
        "mStepKernelParamsMaxNIter": 20,
        "mStepIndPointsMaxNIter": 10,
        "mStepIndPointsLR": 1e-2
    }

    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels,
        indPointsLocsKMSEpsilon=indPointsLocsKMSEpsilon)

    # maximize lower bound (optionally under the profiler)
    svEM = stats.svGPFA.svEM.SVEM()
    if profile:
        pr = cProfile.Profile()
        pr.enable()
    tStart = time.time()
    lowerBoundHist, elapsedTimeHist = \
        svEM.maximize(model=model,
                      measurements=YNonStacked,
                      initialParams=initialParams,
                      quadParams=quadParams,
                      optimParams=optimParams)
    tElapsed = time.time() - tStart
    print("Completed maximize in {:.2f} seconds".format(tElapsed))

    if profile:
        pr.disable()
        profilerFilename = profilerFilenamePattern.format(
            optimParams["emMaxNIter"])
        # context manager guarantees the stats stream is closed
        with open(profilerFilename, "w") as s:
            ps = pstats.Stats(pr, stream=s)
            ps.strip_dirs().sort_stats("cumulative").print_stats()

    resultsToSave = {
        "lowerBoundHist": lowerBoundHist,
        "elapsedTimeHist": elapsedTimeHist,
        "model": model
    }
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)

    # plot lower bound history
    plot.svGPFA.plotUtils.plotLowerBoundHist(
        lowerBoundHist=lowerBoundHist,
        elapsedTimeHist=elapsedTimeHist,
        figFilename=lowerBoundHistFigFilename)
# Example #10
def main(argv):
    """Estimate a Cholesky-parametrized svGPFA model from Matlab exports.

    Reads a Matlab estimation config (``mEstNumber``) to locate the
    simulated spikes and initial conditions, builds a point-process model
    with a linear embedding and exponential link, maximizes its lower
    bound with stochastic variational EM, saves metadata and results
    under a freshly drawn estimation number, and plots the lower-bound
    history.

    :param argv: command-line arguments (``mEstNumber``, ``--deviceName``,
        ``--profile``)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("mEstNumber", help="Matlab's estimation number", type=int)
    parser.add_argument("--deviceName", help="name of device (cpu or cuda)", default="cpu")
    parser.add_argument("--profile", help="perform profiling", action="store_true")
    args = parser.parse_args()
    profile = args.profile  # store_true already yields a bool

    mEstNumber = args.mEstNumber
    deviceName = args.deviceName
    if not torch.cuda.is_available():
        # fall back silently to cpu when cuda was requested but is absent
        deviceName = "cpu"
    device = torch.device(deviceName)
    print("Using {:s}".format(deviceName))

    mEstConfig = configparser.ConfigParser()
    mEstConfig.read("../../matlabCode/scripts/results/{:08d}-pointProcessEstimationParams.ini".format(mEstNumber))
    mSimNumber = int(mEstConfig["data"]["simulationNumber"])
    indPointsLocsKMSEpsilon = float(mEstConfig["control_variables"]["epsilon"])
    ppSimulationFilename = os.path.join(os.path.dirname(__file__), "../../matlabCode/scripts/results/{:08d}-pointProcessSimulation.mat".format(mSimNumber))
    initDataFilename = os.path.join(os.path.dirname(__file__), "../../matlabCode/scripts/results/{:08d}-pointProcessInitialConditions.mat".format(mEstNumber))

    # draw a random estimation number whose metadata file does not exist yet
    # (NOTE(review): check-then-use is racy if several runs start at once)
    estimationPrefixUsed = True
    while estimationPrefixUsed:
        pEstNumber = random.randint(0, 10**8)
        estimMetaDataFilename = \
            "results/{:08d}_leasSimulation_estimationChol_metaData_{:s}.ini".format(pEstNumber, deviceName)
        if not os.path.exists(estimMetaDataFilename):
            estimationPrefixUsed = False
    modelSaveFilename = \
        "results/{:08d}_leasSimulation_estimatedModelChol_{:s}.pickle".format(pEstNumber, deviceName)
    profilerFilenamePattern = \
        "results/{:08d}_leaseSimulation_estimatedModelChol_{:s}.pstats".format(pEstNumber, deviceName)
    lowerBoundHistFigFilename = \
        "figures/{:08d}_leasSimulation_lowerBoundHistChol_{:s}.png".format(pEstNumber, deviceName)

    mat = scipy.io.loadmat(initDataFilename)
    nLatents = len(mat['Z0'])
    nTrials = mat['Z0'][0, 0].shape[2]
    # initial variational parameters; .permute(2, 0, 1) moves the trial
    # axis first (Matlab stores trials last)
    qMu0 = [torch.from_numpy(mat['q_mu0'][(0, k)]).type(torch.DoubleTensor).permute(2, 0, 1).to(device) for k in range(nLatents)]
    qSVec0 = [torch.from_numpy(mat['q_sqrt0'][(0, k)]).type(torch.DoubleTensor).permute(2, 0, 1).to(device) for k in range(nLatents)]
    qSDiag0 = [torch.from_numpy(mat['q_diag0'][(0, k)]).type(torch.DoubleTensor).permute(2, 0, 1).to(device) for k in range(nLatents)]
    Z0 = [torch.from_numpy(mat['Z0'][(k, 0)]).type(torch.DoubleTensor).permute(2, 0, 1).to(device) for k in range(nLatents)]
    C0 = torch.from_numpy(mat["C0"]).type(torch.DoubleTensor).to(device)
    b0 = torch.from_numpy(mat["b0"]).type(torch.DoubleTensor).squeeze().to(device)
    legQuadPoints = torch.from_numpy(mat['ttQuad']).type(torch.DoubleTensor).permute(2, 0, 1).to(device)
    legQuadWeights = torch.from_numpy(mat['wwQuad']).type(torch.DoubleTensor).permute(2, 0, 1).to(device)

    # qSigma0[k] \in nTrials x nInd[k] x nInd[k]; take its Cholesky factor
    # per trial as the initial square root of the posterior covariance
    qSigma0 = utils.svGPFA.initUtils.buildQSigmaFromQSVecAndQSDiag(qSVec=qSVec0, qSDiag=qSDiag0)
    qSRSigma0 = [None] * nLatents
    for k in range(nLatents):
        nIndPointsK = qSigma0[k].shape[1]
        qSRSigma0[k] = torch.empty((nTrials, nIndPointsK, nIndPointsK), dtype=torch.double)
        for r in range(nTrials):
            # NOTE(review): torch.cholesky is deprecated in favor of
            # torch.linalg.cholesky; kept for compatibility with older torch
            qSRSigma0[k][r, :, :] = torch.cholesky(qSigma0[k][r, :, :])

    # spike times: one tensor per (trial, neuron); empty list when the
    # neuron fired no spikes in that trial
    yMat = loadmat(ppSimulationFilename)
    YNonStacked_tmp = yMat['Y']
    nNeurons = YNonStacked_tmp[0, 0].shape[0]
    YNonStacked = [[[] for n in range(nNeurons)] for r in range(nTrials)]
    for r in range(nTrials):
        for n in range(nNeurons):
            spikesTrialNeuron = YNonStacked_tmp[r, 0][n, 0]
            if len(spikesTrialNeuron) > 0:
                YNonStacked[r][n] = torch.from_numpy(spikesTrialNeuron[:, 0]).type(torch.DoubleTensor).to(device)
            else:
                YNonStacked[r][n] = []

    kernelNames = mat["kernelNames"]
    hprs0 = mat["hprs0"]

    # create kernels (kernelNames comes from loadmat as a (1, nLatents) cell array)
    kernels = [None] * nLatents
    for k in range(nLatents):
        if np.char.equal(kernelNames[0, k][0], "PeriodicKernel"):
            kernels[k] = stats.kernels.PeriodicKernel(scale=1.0)
        elif np.char.equal(kernelNames[0, k][0], "rbfKernel"):
            kernels[k] = stats.kernels.ExponentialQuadraticKernel(scale=1.0)
        else:
            # index the cell array, not the outer array, so the error
            # message itself cannot raise an IndexError
            raise ValueError("Invalid kernel name: %s" % (kernelNames[0, k][0]))

    # create initial kernel parameters (scale is fixed at 1.0 above, so
    # only the remaining hyperparameters are free here)
    kernelsParams0 = [None] * nLatents
    for k in range(nLatents):
        if np.char.equal(kernelNames[0, k][0], "PeriodicKernel"):
            kernelsParams0[k] = torch.tensor([float(hprs0[k, 0][0]),
                                              float(hprs0[k, 0][1])],
                                             dtype=torch.double).to(device)
        elif np.char.equal(kernelNames[0, k][0], "rbfKernel"):
            kernelsParams0[k] = torch.tensor([float(hprs0[k, 0][0])],
                                             dtype=torch.double).to(device)
        else:
            raise ValueError("Invalid kernel name: %s" % (kernelNames[0, k][0]))

    qUParams0 = {"qMu0": qMu0, "qSRSigma0": qSRSigma0}
    kmsParams0 = {"kernelsParams0": kernelsParams0,
                  "inducingPointsLocs0": Z0}
    qKParams0 = {"svPosteriorOnLatents": qUParams0,
                 "kernelsMatricesStore": kmsParams0}
    qHParams0 = {"C0": C0, "d0": b0}
    initialParams = {"svPosteriorOnLatents": qKParams0,
                     "svEmbedding": qHParams0}
    quadParams = {"legQuadPoints": legQuadPoints,
                  "legQuadWeights": legQuadWeights}
    optimParams = {"emMaxIter": 50,
                   #
                   "eStepEstimate": True,
                   "eStepMaxIter": 100,
                   "eStepTol": 1e-3,
                   "eStepLR": 1e-3,
                   "eStepLineSearchFn": "strong_wolfe",
                   "eStepNIterDisplay": 1,
                   #
                   "mStepEmbeddingEstimate": True,
                   "mStepEmbeddingMaxIter": 100,
                   "mStepEmbeddingTol": 1e-3,
                   "mStepEmbeddingLR": 1e-3,
                   "mStepEmbeddingLineSearchFn": "strong_wolfe",
                   "mStepEmbeddingNIterDisplay": 1,
                   #
                   "mStepKernelsEstimate": True,
                   "mStepKernelsMaxIter": 10,
                   "mStepKernelsTol": 1e-3,
                   "mStepKernelsLR": 1e-3,
                   "mStepKernelsLineSearchFn": "strong_wolfe",
                   # duplicate "mStepKernelsNIterDisplay" entry removed
                   "mStepKernelsNIterDisplay": 1,
                   #
                   "mStepIndPointsEstimate": True,
                   "mStepIndPointsMaxIter": 20,
                   "mStepIndPointsTol": 1e-3,
                   "mStepIndPointsLR": 1e-4,
                   "mStepIndPointsLineSearchFn": "strong_wolfe",
                   "mStepIndPointsNIterDisplay": 1,
                   #
                   "verbose": True
                   }
    estimConfig = configparser.ConfigParser()
    estimConfig["data"] = {"mEstNumber": mEstNumber}
    estimConfig["optim_params"] = optimParams
    estimConfig["control_params"] = {"indPointsLocsKMSEpsilon": indPointsLocsKMSEpsilon}
    with open(estimMetaDataFilename, "w") as f:
        estimConfig.write(f)

    model = stats.svGPFA.svGPFAModelFactory.SVGPFAModelFactory.buildModel(
        conditionalDist=stats.svGPFA.svGPFAModelFactory.PointProcess,
        linkFunction=stats.svGPFA.svGPFAModelFactory.ExponentialLink,
        embeddingType=stats.svGPFA.svGPFAModelFactory.LinearEmbedding,
        kernels=kernels,
    )

    # maximize lower bound (optionally under the profiler)
    svEM = stats.svGPFA.svEM.SVEM()
    if profile:
        pr = cProfile.Profile()
        pr.enable()
    tStart = time.time()
    lowerBoundHist, elapsedTimeHist = \
        svEM.maximize(model=model,
                      measurements=YNonStacked,
                      initialParams=initialParams,
                      quadParams=quadParams,
                      optimParams=optimParams,
                      indPointsLocsKMSEpsilon=indPointsLocsKMSEpsilon,
                      )
    tElapsed = time.time() - tStart
    print("Completed maximize in {:.2f} seconds".format(tElapsed))

    if profile:
        pr.disable()
        profilerFilename = profilerFilenamePattern.format(optimParams["emMaxIter"])
        # context manager guarantees the stats stream is closed
        with open(profilerFilename, "w") as s:
            ps = pstats.Stats(pr, stream=s)
            ps.strip_dirs().sort_stats("cumulative").print_stats()

    resultsToSave = {"lowerBoundHist": lowerBoundHist, "elapsedTimeHist": elapsedTimeHist, "model": model}
    with open(modelSaveFilename, "wb") as f:
        pickle.dump(resultsToSave, f)

    # plot lower bound history
    plot.svGPFA.plotUtils.plotLowerBoundHist(lowerBoundHist=lowerBoundHist, elapsedTimeHist=elapsedTimeHist, figFilename=lowerBoundHistFigFilename)