コード例 #1
0
def loadAll(target='confirmed',
            subGroup='casesGlobal',
            confIntFilename=PREDICTION_CI_JSON_FILENAME_WORLD,
            **kwargs):
    """Load observed cases together with every saved region prediction.

    Parameters
    ----------
    target: Passed to parseCSSE as the series to parse (e.g. 'confirmed')
    subGroup: Key of the parseCSSE output returned as the observed cases
    confIntFilename: Saved confidence-interval file used both to count the
        trained regions and to load each of them
    kwargs: Forwarded to parseCSSE; 'siteData' (default SITE_DATA) is also
        used to locate the saved predictions

    Returns
    -------
    confirmedCasesAll: parseCSSE(target, **kwargs)[subGroup]
    meanPredictionTSAll: DataFrame with one mean-prediction column per region
        name
    percentilesTSAll: The per-region percentile frames pivoted on
        'regionName', so columns are (percentile-frame column, regionName)

    Raises
    ------
    ValueError: From pandas.concat when no saved regions are found.
    """
    confirmedCasesAll = parseCSSE(target, **kwargs)[subGroup]
    siteData = kwargs.get('siteData', SITE_DATA)
    nTrainedRegions = len(
        getSavedShortCountryNames(
            siteData=siteData,
            confIntFilename=confIntFilename,
        ))

    meanPredictionFrames = []
    percentilesFrames = []
    for i in range(nTrainedRegions):
        meanPredictionTS, percentilesTS, regionName = load(
            i,
            siteData=siteData,
            confIntFilename=confIntFilename,
        )
        meanPredictionTS.name = 'meanPrediction'
        meanPredictionTS = meanPredictionTS.to_frame()
        # Tag every row with its region so the frames can be pivoted below.
        meanPredictionTS['regionName'] = regionName
        percentilesTS['regionName'] = regionName
        percentilesFrames.append(percentilesTS)
        meanPredictionFrames.append(meanPredictionTS)

    # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; one
    # concat is also linear rather than quadratic in the number of regions.
    percentilesTSAll = pd.concat(percentilesFrames).pivot(
        columns='regionName')
    meanPredictionTSAll = pd.concat(meanPredictionFrames).pivot(
        columns='regionName')
    # Drop the constant 'meanPrediction' level so columns are region names.
    meanPredictionTSAll.columns = meanPredictionTSAll.columns.droplevel(
        level=0)
    return confirmedCasesAll, meanPredictionTSAll, percentilesTSAll
コード例 #2
0
ファイル: test_vuregions.py プロジェクト: maaand/COVIDvu
def test_COUNTRIES_REGIONS_table():
    """Check COUNTRIES_REGIONS against the countries in the CSSE report.

    Regenerates the 'confirmed' report with parseCSSE when the file is
    missing. Then asserts that at least one COUNTRIES_REGIONS key is absent
    from the report, and that 'Other Region' is among the absent keys.
    """
    # :o - this uses the actual current list!
    officialCountriesFileName = resolveReportFileName(SITE_DATA, 'confirmed',
                                                      '')

    if not os.path.exists(officialCountriesFileName):
        parseCSSE('confirmed')

    # Context manager so the report file handle is closed deterministically.
    with open(officialCountriesFileName, 'r') as reportFile:
        countriesCSSE = json.load(reportFile).keys()

    countriesCheck = [
        country for country in COUNTRIES_REGIONS
        if country not in countriesCSSE
    ]

    assert countriesCheck
    assert 'Other Region' in countriesCheck
コード例 #3
0
ファイル: test_vujson.py プロジェクト: VirusTrack/COVIDvu
def checkParsing(target):
    """Parse the test CSSE fixtures for *target* and cross-check the output.

    The five case tables produced by parseCSSE are passed, in a fixed order,
    to assertDataCompatibility.
    """
    parsed = parseCSSE(
        target,
        siteData=TEST_SITE_DATA,
        jhCSSEFileConfirmed=TEST_JH_CSSE_FILE_CONFIRMED,
        jhCSSEFileDeaths=TEST_JH_CSSE_FILE_DEATHS,
        jhCSSEFileConfirmedUS=TEST_JH_CSSE_FILE_CONFIRMED_US,
        jhCSSEFileDeathsUS=TEST_JH_CSSE_FILE_DEATHS_US,
    )
    tableKeys = (
        'casesGlobal',
        'casesUSStates',
        'casesUSRegions',
        'casesBoats',
        'casesUSCounties',
    )
    tables = [parsed[key] for key in tableKeys]
    assertDataCompatibility(*tables)
コード例 #4
0
ファイル: test_vujson.py プロジェクト: maaand/COVIDvu
def test_parseCSSE():
    """Parse the test CSSE fixtures and validate the JSON files produced.

    _purge(TEST_SITE_DATA, '.json') is always called afterwards, whether the
    parse and assertions succeed or fail.
    """
    try:
        output = parseCSSE(
            'confirmed',
            siteData=TEST_SITE_DATA,
            jhCSSEFileConfirmedDeprecated=
            TEST_JH_CSSE_FILE_CONFIRMED_DEPRECATED,
            jhCSSEFileConfirmed=TEST_JH_CSSE_FILE_CONFIRMED,
            jsCSSEReportPath=TEST_JH_CSSE_REPORT_PATH,
        )
        casesGlobal = output['casesGlobal']
        casesUSStates = output['casesUSStates']
        casesUSRegions = output['casesUSRegions']
        casesBoats = output['casesBoats']
        assertDataCompatibility(casesGlobal, casesUSStates, casesUSRegions,
                                casesBoats)
        assertValidJSON('confirmed-US-Regions.json')
        assertValidJSON('confirmed-US.json')
        assertValidJSON('confirmed-boats.json')
        assertValidJSON('confirmed.json')
    finally:
        # Cleanup runs on success and failure alike; failures still propagate
        # unchanged (no need for an except-and-reraise clause).
        _purge(TEST_SITE_DATA, '.json')
コード例 #5
0
def _dumpPredictionForSubGroup(prediction, siteData, predictionsPercentiles,
                               subGroup):
    """Save one region's prediction using the output files tied to subGroup.

    'casesGlobal' uses _dumpRegionPrediction's default filenames, and
    'casesUSStates' uses PREDICTION_MEAN_JSON_FILENAME_US /
    PREDICTION_CI_JSON_FILENAME_US. Any other subGroup raises
    NotImplementedError.
    """
    if subGroup == 'casesGlobal':
        _dumpRegionPrediction(prediction, siteData, predictionsPercentiles)
    elif subGroup == 'casesUSStates':
        _dumpRegionPrediction(
            prediction,
            siteData,
            predictionsPercentiles,
            meanFilename=PREDICTION_MEAN_JSON_FILENAME_US,
            confIntFilename=PREDICTION_CI_JSON_FILENAME_US,
        )
    else:
        raise NotImplementedError(f'Unsupported subGroup: {subGroup!r}')


def predictRegions(
        regionTrainIndex,
        target='confirmed',
        predictionsPercentiles=PREDICTIONS_PERCENTILES,
        siteData=SITE_DATA,
        subGroup='casesGlobal',
        jhCSSEFileConfirmed=JH_CSSE_FILE_CONFIRMED,
        jhCSSEFileDeaths=JH_CSSE_FILE_DEATHS,
        jhCSSEFileConfirmedDeprecated=JH_CSSE_FILE_CONFIRMED_DEPRECATED,
        jhCSSEFileDeathsDeprecated=JH_CSSE_FILE_DEATHS_DEPRECATED,
        jsCSSEReportPath=JH_CSSE_REPORT_PATH,
        priorLogCarryingCapacity=PRIOR_LOG_CARRYING_CAPACITY,
        priorMidPoint=PRIOR_MID_POINT,
        priorGrowthRate=PRIOR_GROWTH_RATE,
        priorSigma=PRIOR_SIGMA,
        logRegModel=None,
        **kwargs):
    """Generate forecasts for regions and save them

    Parameters
    ----------
    regionTrainIndex: If a non-negative integer (or a string of digits), trains the region ranked i+1 in
        order of total number of cases. If 'all', predicts every region whose name does not start with '!'
    target: A string in ['confirmed', 'deaths', 'recovered']
    predictionsPercentiles: The posterior percentiles to compute
    siteData: The directory for output data
    subGroup: Key of the parseCSSE output to train on; only 'casesGlobal' and 'casesUSStates' can be saved
    jhCSSEFileConfirmed, jhCSSEFileDeaths, jhCSSEFileConfirmedDeprecated, jhCSSEFileDeathsDeprecated,
    jsCSSEReportPath: Input data locations forwarded to parseCSSE / predictLogisticGrowth
    priorLogCarryingCapacity, priorMidPoint, priorGrowthRate, priorSigma: Priors passed to
        buildLogisticModel; only used when logRegModel is None
    logRegModel: A pre-built StanModel to reuse; built from the priors when None
    kwargs: Optional named arguments for covidvu.predictLogisticGrowth

    Returns
    -------
    None. Each prediction is written with _dumpRegionPrediction; regions for which
    predictLogisticGrowth returns no prediction are skipped.

    Raises
    ------
    NotImplementedError: If regionTrainIndex is neither a non-negative integer nor 'all', or if
        subGroup cannot be saved. Both are checked before any model is built or sampled.
    """
    # Validate the cheap arguments first so bad input does not cost a model
    # compilation and MCMC run before failing.
    trainSingleRegion = re.search(r'^\d+$', str(regionTrainIndex)) is not None
    if not trainSingleRegion and regionTrainIndex != 'all':
        raise NotImplementedError(
            f'Unsupported regionTrainIndex: {regionTrainIndex!r}')
    if subGroup not in ('casesGlobal', 'casesUSStates'):
        raise NotImplementedError(f'Unsupported subGroup: {subGroup!r}')

    if logRegModel is None:
        print('Building model. This may take a few moments...')
        logRegModel = buildLogisticModel(
            priorLogCarryingCapacity=priorLogCarryingCapacity,
            priorMidPoint=priorMidPoint,
            priorGrowthRate=priorGrowthRate,
            priorSigma=priorSigma,
        )
        print('Done.')
    else:
        assert isinstance(logRegModel, StanModel)

    if trainSingleRegion:
        print(f'Training index {regionTrainIndex}')
        prediction = predictLogisticGrowth(
            logRegModel,
            regionTrainIndex=regionTrainIndex,
            # Forward subGroup so the region is trained on the same data set
            # whose output files it is saved to below.
            subGroup=subGroup,
            predictionsPercentiles=predictionsPercentiles,
            target=target,
            siteData=siteData,
            jhCSSEFileConfirmed=jhCSSEFileConfirmed,
            jhCSSEFileDeaths=jhCSSEFileDeaths,
            jhCSSEFileConfirmedDeprecated=jhCSSEFileConfirmedDeprecated,
            jhCSSEFileDeathsDeprecated=jhCSSEFileDeathsDeprecated,
            jsCSSEReportPath=jsCSSEReportPath,
            **kwargs)
        if not prediction:
            # predictLogisticGrowth returns None when the region has too few
            # qualifying days; there is nothing to save.
            print('Skipped.')
        else:
            _dumpPredictionForSubGroup(prediction, siteData,
                                       predictionsPercentiles, subGroup)
        print('Done.')
        return

    # regionTrainIndex == 'all': parse once and reuse the data for every region.
    confirmedCases = parseCSSE(
        target,
        siteData=siteData,
        jhCSSEFileConfirmed=jhCSSEFileConfirmed,
        jhCSSEFileDeaths=jhCSSEFileDeaths,
        jhCSSEFileConfirmedDeprecated=jhCSSEFileConfirmedDeprecated,
        jhCSSEFileDeathsDeprecated=jhCSSEFileDeathsDeprecated,
        jsCSSEReportPath=jsCSSEReportPath,
    )[subGroup]
    # Columns whose name starts with '!' are excluded from training.
    regionsAll = confirmedCases.columns[confirmedCases.columns.map(
        lambda c: c[0] != '!')]
    for regionName in regionsAll:
        print(f'Training {regionName}...')

        prediction = predictLogisticGrowth(
            logRegModel,
            regionName=regionName,
            confirmedCases=confirmedCases,
            subGroup=subGroup,
            predictionsPercentiles=predictionsPercentiles,
            target=target,
            siteData=siteData,
            jhCSSEFileConfirmed=jhCSSEFileConfirmed,
            jhCSSEFileDeaths=jhCSSEFileDeaths,
            jhCSSEFileConfirmedDeprecated=jhCSSEFileConfirmedDeprecated,
            jhCSSEFileDeathsDeprecated=jhCSSEFileDeathsDeprecated,
            jsCSSEReportPath=jsCSSEReportPath,
            **kwargs,
        )
        if not prediction:
            print('Skipped.')
            continue
        _dumpPredictionForSubGroup(prediction, siteData,
                                   predictionsPercentiles, subGroup)
        print('Saved.')
    print('Done.')
コード例 #6
0
def predictLogisticGrowth(logGrowthModel: StanModel,
                          regionTrainIndex: int = None,
                          regionName: str = None,
                          confirmedCases=None,
                          target='confirmed',
                          subGroup='casesGlobal',
                          nSamples=N_SAMPLES,
                          nChains=N_CHAINS,
                          nDaysPredict=N_DAYS_PREDICT,
                          minCasesFilter=MIN_CASES_FILTER,
                          minNumberDaysWithCases=MIN_NUMBER_DAYS_WITH_CASES,
                          predictionsPercentiles=PREDICTIONS_PERCENTILES,
                          randomSeed=2020,
                          **kwargs):
    """Fit the logistic growth model to one region and predict ahead

    Parameters
    ----------
    logGrowthModel: A compiled pystan model
    regionTrainIndex: Order countries from highest to lowest, and train the ith region
    regionName: Overrides regionTrainIndex as the region to train
    confirmedCases: A dataframe of regions as columns, and total number of cases as a time series
        (see covidvu.vujson.parseCSSE). Parsed from disk when None. Not modified.
    target: string in ['confirmed', 'deaths', 'recovered']
    subGroup: A key in the output of covidvu.pipeline.vujson.parseCSSE; only used when
        confirmedCases is None
    nSamples: Number of samples per chain of MCMC
    nChains: Number of independent chains MCMC
    nDaysPredict: Number of days ahead to predict
    minCasesFilter: Only days with strictly more than this many cases are used for training
    minNumberDaysWithCases: Minimum number of days with more than minCasesFilter cases
    predictionsPercentiles: Bayesian confidence intervals to evaluate
    randomSeed: Seed for stan sampler
    kwargs: Optional named arguments. 'maxTreedepth' sets the sampler's max_treedepth
        (default MAX_TREEDEPTH); 'siteData', 'jhCSSEFileConfirmed', 'jhCSSEFileDeaths',
        'jhCSSEFileConfirmedDeprecated', 'jhCSSEFileDeathsDeprecated' and 'jsCSSEReportPath'
        are passed to covidvu.pipeline.vujson.parseCSSE when confirmedCases is None

    Returns
    -------
    None if the region has fewer than minNumberDaysWithCases qualifying days. Otherwise a dict with:
    regionTS: All data for the queried region (datetime index)
    predictionsMeanTS: Posterior mean prediction
    predictionsPercentilesTS: Posterior percentiles
    trace: The stan fit's samples as a dataframe (fit.to_dataframe())
    regionTSClean: Data used for training
    regionName: The trained region
    t: Day offsets (0..n-1) of the training data

    Raises
    ------
    TypeError: If neither regionName nor regionTrainIndex is given
    """
    # Fail before any (potentially expensive) parsing if no region is selected.
    if regionName is None and regionTrainIndex is None:
        raise TypeError('Either regionTrainIndex or regionName must be given')

    maxTreeDepth = kwargs.get('maxTreedepth', MAX_TREEDEPTH)

    if confirmedCases is None:
        confirmedCases = parseCSSE(
            target,
            siteData=kwargs.get('siteData', SITE_DATA),
            jhCSSEFileConfirmed=kwargs.get('jhCSSEFileConfirmed',
                                           JH_CSSE_FILE_CONFIRMED),
            jhCSSEFileDeaths=kwargs.get('jhCSSEFileDeaths',
                                        JH_CSSE_FILE_DEATHS),
            jhCSSEFileConfirmedDeprecated=kwargs.get(
                'jhCSSEFileConfirmedDeprecated',
                JH_CSSE_FILE_CONFIRMED_DEPRECATED),
            jhCSSEFileDeathsDeprecated=kwargs.get(
                'jhCSSEFileDeathsDeprecated', JH_CSSE_FILE_DEATHS_DEPRECATED),
            jsCSSEReportPath=kwargs.get('jsCSSEReportPath',
                                        JH_CSSE_REPORT_PATH),
        )[subGroup]

    if regionName is None:
        regionName = _getCountryToTrain(int(regionTrainIndex), confirmedCases)
    else:
        assert isinstance(regionName, str)

    # Copy the column: its index is replaced below, and column access can
    # return a cached/shared Series that would otherwise leak the new index
    # back into the caller's confirmedCases.
    regionTS = confirmedCases[regionName].copy()
    regionTSClean = regionTS[regionTS > minCasesFilter]
    if regionTSClean.shape[0] < minNumberDaysWithCases:
        return None

    regionTSClean.index = pd.to_datetime(regionTSClean.index)

    t = np.arange(regionTSClean.shape[0])
    # +1 keeps log finite for zero counts.
    regionTSCleanLog = np.log(regionTSClean.values + 1)

    logisticGrowthData = {
        'nDays': regionTSClean.shape[0],
        't': list(t),
        'casesLog': list(regionTSCleanLog)
    }

    fit = logGrowthModel.sampling(data=logisticGrowthData,
                                  iter=nSamples,
                                  chains=nChains,
                                  seed=randomSeed,
                                  control={'max_treedepth': maxTreeDepth})

    trace = fit.to_dataframe()

    predictionsMean, predictionsPercentilesTS = _getPredictionsFromPosteriorSamples(
        t,
        trace,
        nDaysPredict,
        predictionsPercentiles,
    )

    predictionsMeanTS, predictionsPercentilesTS = _castPredictionsAsTS(
        regionTSClean,
        nDaysPredict,
        predictionsMean,
        predictionsPercentilesTS,
    )

    regionTS.index = pd.to_datetime(regionTS.index)
    prediction = {
        'regionTS': regionTS,
        'predictionsMeanTS': predictionsMeanTS,
        'predictionsPercentilesTS': predictionsPercentilesTS,
        'trace': trace,
        'regionTSClean': regionTSClean,
        'regionName': regionName,
        't': t,
    }

    return prediction