def PlotData(dataDf, feature='PSA', titleStr="", drugBarPosition=0.85,
             xlim=2e3, ylim=1.3, y2lim=1, decorateX=True, decorateY=True, decoratey2=False, markersize=10,
             ax=None, figsize=(10, 8), outName=None, **kwargs):
    """Plot a patient's measurements with the treatment schedule drawn as a bar.

    Measurements ``dataDf[feature]`` are drawn against ``dataDf.Time`` as black
    crosses on ``ax``. The treatment schedule (from
    ``utils.ExtractTreatmentFromDf`` / ``utils.TreatmentListToTS``) is drawn as
    a filled step band on a hidden twin y-axis starting at ``drugBarPosition``.

    Args:
        dataDf: data frame with a ``Time`` column, the ``feature`` column, and
            whatever columns ``utils.ExtractTreatmentFromDf`` reads.
        feature: name of the column to plot.
        titleStr: axes title.
        drugBarPosition: y-position (in twin-axis units) where the drug bar starts.
        xlim, ylim: upper limits of the main axes (lower limits are 0).
        y2lim: upper limit of the twin (drug) axis.
        decorateX, decorateY: if False, tick labels on that axis are blanked.
        decoratey2: accepted for signature compatibility; not used.
        markersize: size of the data markers.
        ax: axes to draw on; a new figure of size ``figsize`` is created if None.
        figsize: figure size used only when ``ax`` is None.
        outName: if given, the figure containing ``ax`` is saved to this path.
        **kwargs: accepted for signature compatibility; not used.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=figsize)
    # Work on the axes' own figure rather than pyplot's "current" figure so a
    # caller-supplied ax on a non-current figure is laid out and saved correctly.
    fig = ax.figure

    # Plot the data
    ax.plot(dataDf.Time, dataDf[feature],
            linestyle="None", marker="x", markersize=markersize,
            color="black", markeredgewidth=2)

    # Plot the drug concentration on a second axes that shares the x-axis
    ax2 = ax.twinx()
    drugConcentrationVec = utils.TreatmentListToTS(treatmentList=utils.ExtractTreatmentFromDf(dataDf),
                                                   tVec=dataDf['Time'])
    # Shift the schedule so "no drug" sits at drugBarPosition.
    # NOTE(review): a value c maps to c / (1 - p) + p, so for c >= (1 - p)**2 the
    # band exceeds y2lim=1 and is clipped to full height; if proportional
    # heights in [p, 1] were intended, c * (1 - p) + p would do that - confirm.
    drugConcentrationVec = drugConcentrationVec / (1 - drugBarPosition) + drugBarPosition
    ax2.fill_between(dataDf['Time'], drugBarPosition, drugConcentrationVec,
                     step="post", color="black", alpha=1., label="Drug Concentration")
    ax2.axis("off")

    # Format the plot
    ax.set_xlim(0, xlim)
    ax.set_ylim(0, ylim)
    ax2.set_ylim([0, y2lim])
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_title(titleStr)
    ax.tick_params(labelsize=28)
    ax2.tick_params(labelsize=28)
    # Drop any legend already on ax without building (and warning about) a new one.
    existingLegend = ax.get_legend()
    if existingLegend is not None:
        existingLegend.remove()
    if not decorateX:
        ax.set_xticklabels("")
    if not decorateY:
        ax.set_yticklabels("")
    fig.tight_layout()
    if outName is not None:
        fig.savefig(outName)
Ejemplo n.º 2
0
def FitModel(job):
    """Fit the on-lattice model to one patient's data and persist the result.

    If ``fitObj_patient_<patientId>_fit_<fitId>.p`` already exists in the
    patient's summary directory, nothing is done and 0 is returned.

    Otherwise the function:

    * fits ``params`` with ``minimize(residual, ...)``;
    * plots the best fit to ``patient_<patientId>_fit_<fitId>.png``;
    * pickles the fit object, annotated with ``patientId``, ``fitId``,
      ``seed``, ``eps_data`` and ``rSq``;
    * removes the per-fit working directory.

    Relies on module-level names defined elsewhere in the file: ``dataDir``,
    ``solver_kws``, ``optimiser_kws``, ``eps_data``, ``perturbICs``,
    ``LoadPatientData``, ``OnLatticeModel``, ``PerturbParams``,
    ``ComputeRSquared``, ``residual``, ``minimize``.

    Args:
        job: dict with at least 'patientId', 'fitId', 'params' and 'outDir'.
            The whole dict, with 'outDir' redirected to the per-fit
            directory, is forwarded to ``OnLatticeModel.SetParams``. The
            caller's dict is not modified.

    Returns:
        0 if the fit already existed; otherwise None. Errors raised while
        fitting, plotting or saving are logged and the job is skipped
        (best-effort). Errors while loading data or setting up propagate.
    """
    import logging

    patientId, fitId = job['patientId'], job['fitId']
    params, outDir = job['params'], job['outDir']
    summaryOutDir = os.path.join(outDir, "patient%d/" % (patientId))
    modelOutDir = os.path.join(summaryOutDir, "fitId%d/" % (fitId))
    fitObjPath = os.path.join(summaryOutDir,
                              "fitObj_patient_%d_fit_%d.p" % (patientId, fitId))
    if os.path.isfile(fitObjPath):
        return 0
    dataDf = LoadPatientData(patientId, dataDir)
    utils.mkdir(modelOutDir)
    # Seed numpy's global RNG from OS entropy; the seed is stored on the fit
    # object below so the run can be reproduced.
    seed = int.from_bytes(os.urandom(4), byteorder='little')
    np.random.seed(seed)
    tmpModel = OnLatticeModel()
    # Forward a copy of the job with outDir pointing at the per-fit directory.
    # Mutating `job` in place would make a second call with the same dict
    # build nested paths (".../patientX/fitIdY/patientX/fitIdY/").
    tmpModel.SetParams(**{**job, 'outDir': modelOutDir}, **solver_kws)
    if perturbICs:
        params = PerturbParams(params)
    try:
        fitObj = minimize(residual,
                          params,
                          args=(0, dataDf, eps_data, tmpModel, "PSA",
                                solver_kws),
                          **optimiser_kws)
        # Plot best fit
        myModel = OnLatticeModel()
        myModel.SetParams(**fitObj.params.valuesdict(), **solver_kws)
        myModel.SetParams(outDir=modelOutDir)
        myModel.Simulate(
            treatmentScheduleList=utils.ExtractTreatmentFromDf(dataDf),
            max_step=1,
            **solver_kws)
        fig, ax = plt.subplots(1, 1, figsize=(10, 6))
        try:
            ax.plot(dataDf.Time, dataDf.PSA, linestyle='none', marker='x')
            myModel.Plot(ylim=2., ax=ax)
            fig.savefig(
                os.path.join(summaryOutDir,
                             "patient_%d_fit_%d.png" % (patientId, fitId)))
        finally:
            # Always release the figure, even if plotting/saving failed.
            plt.close(fig)

        # Save fit
        fitObj.patientId = patientId
        fitObj.fitId = fitId
        fitObj.seed = seed
        fitObj.eps_data = eps_data
        fitObj.rSq = ComputeRSquared(fitObj, dataDf)
        with open(fitObjPath, "wb") as outFile:
            pickle.dump(fitObj, outFile)
        shutil.rmtree(modelOutDir)
    except Exception:
        # Best-effort batch job: record why this fit was skipped instead of
        # silently discarding the error (and no longer trapping Ctrl-C).
        logging.getLogger(__name__).exception(
            "Fit failed for patient %s, fit %s", patientId, fitId)
def residual(params, x, data, eps_data, model, feature="PSA", solver_kws={}):
    """Return the weighted residual vector used as the fitting objective.

    The model is re-parameterised with ``params`` and simulated under the
    treatment schedule extracted from ``data``. Its ``TumourSize`` trajectory
    is then interpolated with ``scipy.interpolate.interp1d`` (extrapolating
    outside the simulated range) onto the observation times ``data.Time``.

    Args:
        params: parameter container exposing ``valuesdict()``.
        x: unused; present so the positional ``args`` passed by the optimiser line up.
        data: data frame with ``Time``, the ``feature`` column, and the
            treatment columns read by ``utils.ExtractTreatmentFromDf``.
        eps_data: noise scale the residuals are divided by.
        model: model object exposing ``SetParams``, ``Simulate`` and ``resultsDf``.
        feature: observed column compared against the model's ``TumourSize``.
        solver_kws: extra keyword arguments forwarded to ``model.Simulate``.

    Returns:
        numpy array ``(data[feature] - prediction) / eps_data``.
    """
    model.SetParams(**params.valuesdict())
    schedule = utils.ExtractTreatmentFromDf(data)
    model.Simulate(treatmentScheduleList=schedule, **solver_kws)
    simulated = model.resultsDf
    toObservationTimes = scipy.interpolate.interp1d(simulated.Time,
                                                    simulated.TumourSize,
                                                    fill_value="extrapolate")
    prediction = toObservationTimes(data.Time)
    observed = data[feature].values
    return (observed - prediction) / eps_data
def SimulateFit(fitObj, dataDf, dt=1, solver_kws={}):
    """Re-simulate a fitted Lotka-Volterra model and resample it on a regular grid.

    Args:
        fitObj: fit result whose ``params.valuesdict()`` parameterises the model.
        dataDf: data frame the treatment schedule is extracted from.
        dt: spacing of the output time grid ``np.arange(0, max Time, dt)``.
        solver_kws: extra keyword arguments forwarded to ``Simulate``.

    Returns:
        The simulated model. Its ``resultsDf`` is replaced by a frame with
        columns Time, S, R, TumourSize and DrugConcentration, each
        interpolated onto the new grid.
    """
    model = lvm.LotkaVolterraModel()
    model.SetParams(**fitObj.params.valuesdict())
    schedule = utils.ExtractTreatmentFromDf(dataDf)
    model.Simulate(treatmentScheduleList=schedule, **solver_kws)
    rawDf = model.resultsDf
    timeGrid = np.arange(0, rawDf.Time.max(), dt)
    resampled = {'Time': timeGrid}
    resampled.update(
        (column, scipy.interpolate.interp1d(rawDf.Time, rawDf[column])(timeGrid))
        for column in ('S', 'R', 'TumourSize', 'DrugConcentration'))
    model.resultsDf = pd.DataFrame(resampled)
    return model
def residual(params, x, data, eps_data, model, feature="TumourSize", solver_kws={},
             max_attempts=50):
    """Return weighted residuals, retrying the solver with smaller steps on failure.

    The model is re-parameterised with ``params`` and simulated under the
    treatment schedule extracted from ``data``. While ``model.successB`` is
    falsy, the simulation is repeated with a reduced ``max_step``: an
    unbounded step is first capped at 100, then each retry shrinks it by 25%.
    The model's ``TumourSize`` trajectory is interpolated (with extrapolation)
    onto ``data.Time``.

    Args:
        params: parameter container exposing ``valuesdict()``.
        x: unused; present so the optimiser's positional ``args`` line up.
        data: data frame with ``Time``, the ``feature`` column, and the
            treatment columns read by ``utils.ExtractTreatmentFromDf``.
        eps_data: noise scale the residuals are divided by.
        model: model exposing ``SetParams``, ``Simulate``, ``successB`` and
            ``resultsDf``.
        feature: observed column compared against the model's ``TumourSize``.
        solver_kws: keyword arguments forwarded to ``Simulate``. This dict is
            not modified; a copy is used for the retries.
        max_attempts: maximum number of simulations before giving up. Pass
            None to retry indefinitely (the previous behaviour).

    Returns:
        ``(data[feature] - prediction) / eps_data``.

    Raises:
        RuntimeError: if the model has not converged after ``max_attempts``
            simulations.
    """
    import itertools

    model.SetParams(**params.valuesdict())
    max_step = solver_kws.get('max_step', np.inf)
    currSolver_kws = solver_kws.copy()
    attempts = itertools.count() if max_attempts is None else range(max_attempts)
    for _ in attempts:
        model.Simulate(treatmentScheduleList=utils.ExtractTreatmentFromDf(data), **currSolver_kws)
        if model.successB:
            break
        # Cap an unbounded step at 100 on the first retry, then shrink by 25%.
        max_step = 0.75 * max_step if max_step < np.inf else 100
        currSolver_kws['max_step'] = max_step
    else:
        # Previously this looped forever with max_step -> 0.
        raise RuntimeError(
            "Model did not converge after %d attempts (last max_step=%g)"
            % (max_attempts, max_step))
    # Interpolate to the data time grid
    t_eval = data.Time
    f = scipy.interpolate.interp1d(model.resultsDf.Time, model.resultsDf.TumourSize,
                                   fill_value="extrapolate")
    modelPrediction = f(t_eval)
    return (data[feature] - modelPrediction) / eps_data
def SimulateFit(fitObj,
                dataDf,
                trim=True,
                dt=1,
                saveFiles=False,
                solver_kws={}):
    """Re-simulate a fitted on-lattice model and optionally resample each replicate.

    Args:
        fitObj: fit result whose ``params.valuesdict()`` parameterises the
            model. It must carry ``patientId``/``fitId`` attributes when
            ``solver_kws`` has no 'outDir'.
        dataDf: data frame the treatment schedule is extracted from.
        trim: if True, each replicate is linearly interpolated onto
            ``np.arange(0, max Time, dt)``. The per-replicate frames are
            concatenated, so the resulting index repeats across replicates.
        dt: spacing of the output time grid.
        saveFiles: if False, the simulation output directory
            (``solver_kws['outDir']``, default ``./tmp/patient<id>/fit<id>/``)
            is deleted afterwards, including when an error is raised.
        solver_kws: keyword arguments forwarded to ``SetParams`` and
            ``Simulate``. This dict is not modified.

    Returns:
        The model, with ``resultsDf`` loaded from the simulation output.
        When ``trim`` is True, it holds columns Time, ReplicateId, S, R,
        TumourSize and DrugConcentration.
    """
    myModel = OnLatticeModel()
    solver_kws = solver_kws.copy()  # never mutate the caller's (or the default) dict
    if 'outDir' not in solver_kws:
        # Only build the default path when needed, so fitObj needs
        # patientId/fitId only in that case.
        solver_kws['outDir'] = "./tmp/patient%d/fit%d/" % (fitObj.patientId, fitObj.fitId)
    try:
        myModel.SetParams(**fitObj.params.valuesdict(), **solver_kws)
        myModel.Simulate(
            treatmentScheduleList=utils.ExtractTreatmentFromDf(dataDf),
            **solver_kws)
        myModel.resultsDf = myModel.LoadSimulations()
        # NOTE(review): return value ignored - presumably normalises in place; confirm.
        myModel.NormaliseToInitialSize(myModel.resultsDf)
        # Interpolate to the desired time grid
        if trim:
            resultsDf = myModel.resultsDf
            t_eval = np.arange(0, resultsDf.Time.max(), dt)
            tmpDfList = []
            for replicateId in resultsDf.ReplicateId.unique():
                # Select this replicate's rows once rather than re-masking per variable.
                replicateDf = resultsDf[resultsDf.ReplicateId == replicateId]
                trimmedResultsDic = {
                    'Time': t_eval,
                    'ReplicateId': replicateId * np.ones_like(t_eval)
                }
                for variable in ['S', 'R', 'TumourSize', 'DrugConcentration']:
                    f = scipy.interpolate.interp1d(replicateDf.Time,
                                                   replicateDf[variable])
                    trimmedResultsDic[variable] = f(t_eval)
                tmpDfList.append(pd.DataFrame(trimmedResultsDic))
            myModel.resultsDf = pd.concat(tmpDfList)
    finally:
        if not saveFiles:
            # ignore_errors: if the simulation failed before creating the
            # directory, don't mask the original exception with a cleanup error.
            shutil.rmtree(solver_kws['outDir'], ignore_errors=True)
    return myModel
    def Plot(self, decoratey2=True, ax=None, **kwargs):
        """Plot the simulated tumour trajectory with the drug schedule overlaid.

        Reads ``self.resultsDf``: Time, TumourSize, DrugConcentration, S and R
        (the last two only with ``plotPops``), plus whatever
        ``utils.ExtractTreatmentFromDf`` reads.

        Args:
            decoratey2: if True, label the secondary (drug) y-axis.
            ax: axes to draw on; a new figure/axes is created if None.
            **kwargs: optional switches, all read with defaults below:
                plotAreaB, linewidthA, colorA, linestyleA, markerA, labelA
                (tumour-size line); plotPops, linewidth, linestyleS, colorS,
                linestyleR, colorR (S/R split); xlim, yMin, ylim, y2lim, title;
                plotLegendB, legendLoc; saveFigB, outName; returnAx.

        Returns:
            ``ax`` if ``returnAx`` is truthy, otherwise None.
        """
        if ax is None:
            _, ax = plt.subplots(1, 1)
        # Figure-level operations go through ax's own figure instead of
        # pyplot's "current" figure, so a caller-supplied ax on a non-current
        # figure is laid out, saved and closed correctly.
        fig = ax.figure
        lnslist = []
        # Plot the total tumour size
        if kwargs.get('plotAreaB', True):
            lnslist += ax.plot(self.resultsDf['Time'],
                               self.resultsDf['TumourSize'],
                               lw=kwargs.get('linewidthA', 4),
                               color=kwargs.get('colorA', 'b'),
                               linestyle=kwargs.get('linestyleA', '-'),
                               marker=kwargs.get('markerA', None),
                               label=kwargs.get('labelA', 'Model Prediction'))

        # Plot the individual populations, splitting TumourSize by the S/(S+R) fraction
        if kwargs.get('plotPops', False):
            propS = self.resultsDf['S'].values / (self.resultsDf['S'].values +
                                                  self.resultsDf['R'].values)
            lnslist += ax.plot(self.resultsDf['Time'],
                               propS * self.resultsDf['TumourSize'],
                               lw=kwargs.get('linewidth', 4),
                               linestyle=kwargs.get('linestyleS', '--'),
                               color=kwargs.get('colorS', 'g'),
                               label='S')
            lnslist += ax.plot(self.resultsDf['Time'],
                               (1 - propS) * self.resultsDf['TumourSize'],
                               lw=kwargs.get('linewidth', 4),
                               linestyle=kwargs.get('linestyleR', '--'),
                               color=kwargs.get('colorR', 'r'),
                               label='R')

        # Plot the drug concentration on a second axes sharing the x-axis
        ax2 = ax.twinx()
        drugConcentrationVec = utils.TreatmentListToTS(
            treatmentList=utils.ExtractTreatmentFromDf(self.resultsDf),
            tVec=self.resultsDf['Time'])
        ax2.fill_between(self.resultsDf['Time'],
                         0,
                         drugConcentrationVec,
                         color="#8f59e0",
                         alpha=0.2,
                         label="Drug Concentration")
        # Format the plot
        ax.set_xlim(
            [0, kwargs.get('xlim', 1.1 * self.resultsDf['Time'].max())])
        ax.set_ylim([
            kwargs.get('yMin',
                       -1.1 * np.abs(self.resultsDf['TumourSize'].min())),
            kwargs.get('ylim', 1.1 * self.resultsDf['TumourSize'].max())
        ])
        ax2.set_ylim([
            0,
            kwargs.get('y2lim', self.resultsDf['DrugConcentration'].max() + .1)
        ])
        ax.set_xlabel("Time")
        ax.set_ylabel("Tumour Size")
        ax2.set_ylabel(r"Drug Concentration in $\mu M$" if decoratey2 else "")
        ax.set_title(kwargs.get('title', ''))
        if kwargs.get('plotLegendB', True):
            labsList = [line.get_label() for line in lnslist]
            # Attach to the twin axes: that is what plt.legend() resolved to
            # before, because twinx() makes the new twin the current axes.
            ax2.legend(lnslist,
                       labsList,
                       loc=kwargs.get('legendLoc', "upper right"))
        fig.tight_layout()
        if kwargs.get('saveFigB', False):
            fig.savefig(kwargs.get('outName', 'modelPrediction.png'),
                        orientation='portrait',
                        format='png')
            plt.close(fig)
        if kwargs.get('returnAx', False):
            return ax