def loadAll(target='confirmed',
            subGroup='casesGlobal',
            confIntFilename=PREDICTION_CI_JSON_FILENAME_WORLD,
            **kwargs):
    """Load observed cases together with every saved regional prediction.

    Parameters
    ----------
    target: Passed to parseCSSE, e.g. 'confirmed'
    subGroup: Key of the parseCSSE output to return as observed cases
    confIntFilename: Confidence-interval JSON file the saved predictions are read from
    kwargs: Forwarded to parseCSSE; 'siteData' (default SITE_DATA) is also
        used to locate the saved predictions

    Returns
    -------
    confirmedCasesAll: parseCSSE(target, **kwargs)[subGroup]
    meanPredictionTSAll: Mean predictions pivoted to one column per region name
    percentilesTSAll: Percentile predictions pivoted on 'regionName'
        (columns are a (percentile column, regionName) MultiIndex)

    Raises
    ------
    ValueError: if no trained regions are found
    """
    confirmedCasesAll = parseCSSE(target, **kwargs)[subGroup]
    siteData = kwargs.get('siteData', SITE_DATA)

    nTrainedRegions = len(
        getSavedShortCountryNames(
            siteData=siteData,
            confIntFilename=confIntFilename,
        ))

    # Collect per-region frames and concatenate once: DataFrame.append was
    # removed in pandas 2.0 and appending in a loop is quadratic.
    meanFrames = []
    percentileFrames = []
    for i in range(nTrainedRegions):
        meanPredictionTS, percentilesTS, regionName = load(
            i,
            siteData=siteData,
            confIntFilename=confIntFilename,
        )
        meanFrame = meanPredictionTS.to_frame(name='meanPrediction')
        meanFrame['regionName'] = regionName
        meanFrames.append(meanFrame)
        percentileFrames.append(percentilesTS.assign(regionName=regionName))

    if not meanFrames:
        raise ValueError(
            f'No trained regions found for {confIntFilename!r} in {siteData!r}')

    percentilesTSAll = pd.concat(percentileFrames).pivot(columns='regionName')
    meanPredictionTSAll = pd.concat(meanFrames).pivot(columns='regionName')
    # Drop the constant 'meanPrediction' level so columns are just region names.
    meanPredictionTSAll.columns = meanPredictionTSAll.columns.droplevel(
        level=0)

    return confirmedCasesAll, meanPredictionTSAll, percentilesTSAll
def test_COUNTRIES_REGIONS_table():
    """Check COUNTRIES_REGIONS against the region names in the CSSE report.

    Reads the live report under SITE_DATA, generating it with
    parseCSSE('confirmed') first if the file does not exist yet.
    """
    # :o - this uses the actual current list!
    officialCountriesFileName = resolveReportFileName(SITE_DATA, 'confirmed', '')
    if not os.path.exists(officialCountriesFileName):
        parseCSSE('confirmed')

    # Use a context manager so the report file handle is always closed.
    with open(officialCountriesFileName, 'r') as reportFile:
        countriesCSSE = json.load(reportFile).keys()

    countriesCheck = [
        country for country in COUNTRIES_REGIONS
        if country not in countriesCSSE
    ]
    # The test asserts that some COUNTRIES_REGIONS keys, including
    # 'Other Region', are absent from the CSSE report.
    assert countriesCheck
    assert 'Other Region' in countriesCheck
def checkParsing(target):
    """Parse the test CSSE fixtures for *target* and check table compatibility.

    The global, US-state, US-region, boat and US-county tables produced by
    parseCSSE are handed, in that order, to assertDataCompatibility.
    """
    parsed = parseCSSE(
        target,
        siteData=TEST_SITE_DATA,
        jhCSSEFileConfirmed=TEST_JH_CSSE_FILE_CONFIRMED,
        jhCSSEFileDeaths=TEST_JH_CSSE_FILE_DEATHS,
        jhCSSEFileConfirmedUS=TEST_JH_CSSE_FILE_CONFIRMED_US,
        jhCSSEFileDeathsUS=TEST_JH_CSSE_FILE_DEATHS_US,
    )
    subGroupKeys = (
        'casesGlobal',
        'casesUSStates',
        'casesUSRegions',
        'casesBoats',
        'casesUSCounties',
    )
    assertDataCompatibility(*[parsed[key] for key in subGroupKeys])
def test_parseCSSE():
    """Parse the test 'confirmed' fixtures and validate the written JSON files.

    The generated .json files in TEST_SITE_DATA are purged afterwards,
    whether the test passes or fails.
    """
    try:
        output = parseCSSE(
            'confirmed',
            siteData=TEST_SITE_DATA,
            jhCSSEFileConfirmedDeprecated=TEST_JH_CSSE_FILE_CONFIRMED_DEPRECATED,
            jhCSSEFileConfirmed=TEST_JH_CSSE_FILE_CONFIRMED,
            jsCSSEReportPath=TEST_JH_CSSE_REPORT_PATH,
        )
        assertDataCompatibility(
            output['casesGlobal'],
            output['casesUSStates'],
            output['casesUSRegions'],
            output['casesBoats'],
        )
        for jsonFileName in (
                'confirmed-US-Regions.json',
                'confirmed-US.json',
                'confirmed-boats.json',
                'confirmed.json',
        ):
            assertValidJSON(jsonFileName)
    finally:
        # No except clause needed: failures propagate unchanged, and cleanup
        # always runs.
        _purge(TEST_SITE_DATA, '.json')
def predictRegions(
        regionTrainIndex,
        target='confirmed',
        predictionsPercentiles=PREDICTIONS_PERCENTILES,
        siteData=SITE_DATA,
        subGroup='casesGlobal',
        jhCSSEFileConfirmed=JH_CSSE_FILE_CONFIRMED,
        jhCSSEFileDeaths=JH_CSSE_FILE_DEATHS,
        jhCSSEFileConfirmedDeprecated=JH_CSSE_FILE_CONFIRMED_DEPRECATED,
        jhCSSEFileDeathsDeprecated=JH_CSSE_FILE_DEATHS_DEPRECATED,
        jsCSSEReportPath=JH_CSSE_REPORT_PATH,
        priorLogCarryingCapacity=PRIOR_LOG_CARRYING_CAPACITY,
        priorMidPoint=PRIOR_MID_POINT,
        priorGrowthRate=PRIOR_GROWTH_RATE,
        priorSigma=PRIOR_SIGMA,
        logRegModel=None,
        **kwargs):
    """Generate forecasts for regions and dump them to JSON.

    Parameters
    ----------
    regionTrainIndex: If an integer (or a string of digits), trains the region
        ranked i+1 in order of total number of cases. If 'all', predicts every
        region of subGroup whose name does not start with '!'
    target: A string in ['confirmed', 'deaths', 'recovered']
    predictionsPercentiles: The posterior percentiles to compute
    siteData: The directory for output data
    subGroup: Key of the parseCSSE output to train on; only 'casesGlobal' and
        'casesUSStates' are supported
    jhCSSEFileConfirmed, jhCSSEFileDeaths, jhCSSEFileConfirmedDeprecated,
    jhCSSEFileDeathsDeprecated, jsCSSEReportPath: Input data locations,
        forwarded to parseCSSE / predictLogisticGrowth
    priorLogCarryingCapacity, priorMidPoint, priorGrowthRate, priorSigma:
        Priors passed to buildLogisticModel when logRegModel is None
    logRegModel: A pre-built StanModel; built from the priors if None
    kwargs: Optional named arguments for covidvu.predictLogisticGrowth

    Returns
    -------
    None. Each successful prediction is written with _dumpRegionPrediction;
    regions with too little data are skipped.

    Raises
    ------
    NotImplementedError: if subGroup is unsupported, or regionTrainIndex is
        neither an integer nor 'all'
    """
    # Validate arguments before any expensive model building or training.
    if subGroup == 'casesGlobal':
        dumpKwargs = {}
    elif subGroup == 'casesUSStates':
        dumpKwargs = {
            'meanFilename': PREDICTION_MEAN_JSON_FILENAME_US,
            'confIntFilename': PREDICTION_CI_JSON_FILENAME_US,
        }
    else:
        raise NotImplementedError

    trainSingleIndex = re.search(r'^\d+$', str(regionTrainIndex)) is not None
    if not trainSingleIndex and regionTrainIndex != 'all':
        raise NotImplementedError

    if logRegModel is None:
        print('Building model. This may take a few moments...')
        logRegModel = buildLogisticModel(
            priorLogCarryingCapacity=priorLogCarryingCapacity,
            priorMidPoint=priorMidPoint,
            priorGrowthRate=priorGrowthRate,
            priorSigma=priorSigma,
        )
        print('Done.')
    else:
        assert isinstance(logRegModel, StanModel)

    csseKwargs = dict(
        siteData=siteData,
        jhCSSEFileConfirmed=jhCSSEFileConfirmed,
        jhCSSEFileDeaths=jhCSSEFileDeaths,
        jhCSSEFileConfirmedDeprecated=jhCSSEFileConfirmedDeprecated,
        jhCSSEFileDeathsDeprecated=jhCSSEFileDeathsDeprecated,
        jsCSSEReportPath=jsCSSEReportPath,
    )

    if trainSingleIndex:
        print(f'Training index {regionTrainIndex}')
        # subGroup must be forwarded so the index is ranked within the same
        # table the result is dumped for (e.g. US states, not global).
        prediction = predictLogisticGrowth(
            logRegModel,
            regionTrainIndex=regionTrainIndex,
            predictionsPercentiles=predictionsPercentiles,
            target=target,
            subGroup=subGroup,
            **csseKwargs,
            **kwargs,
        )
        # predictLogisticGrowth returns None when there is too little data.
        if prediction is None:
            print('Skipped.')
        else:
            _dumpRegionPrediction(prediction, siteData, predictionsPercentiles,
                                  **dumpKwargs)
        print('Done.')
        return

    confirmedCases = parseCSSE(target, **csseKwargs)[subGroup]
    # Column names starting with '!' are excluded from training.
    regionsAll = confirmedCases.columns[confirmedCases.columns.map(
        lambda c: c[0] != '!')]
    for regionName in regionsAll:
        print(f'Training {regionName}...')
        prediction = predictLogisticGrowth(
            logRegModel,
            regionName=regionName,
            confirmedCases=confirmedCases,
            predictionsPercentiles=predictionsPercentiles,
            target=target,
            subGroup=subGroup,
            **csseKwargs,
            **kwargs,
        )
        if prediction is None:
            print('Skipped.')
            continue
        _dumpRegionPrediction(prediction, siteData, predictionsPercentiles,
                              **dumpKwargs)
        print('Saved.')
    print('Done.')
def predictLogisticGrowth(logGrowthModel: StanModel,
                          regionTrainIndex: int = None,
                          regionName: str = None,
                          confirmedCases=None,
                          target='confirmed',
                          subGroup='casesGlobal',
                          nSamples=N_SAMPLES,
                          nChains=N_CHAINS,
                          nDaysPredict=N_DAYS_PREDICT,
                          minCasesFilter=MIN_CASES_FILTER,
                          minNumberDaysWithCases=MIN_NUMBER_DAYS_WITH_CASES,
                          predictionsPercentiles=PREDICTIONS_PERCENTILES,
                          randomSeed=2020,
                          **kwargs):
    """Predict the region with the nth highest number of cases

    Parameters
    ----------
    logGrowthModel: A compiled pystan model
    regionTrainIndex: Order countries from highest to lowest, and train the ith region
    regionName: Overwrites regionTrainIndex as the region to train
    confirmedCases: A dataframe of countries as columns, and total number of cases as a time series
        (see covidvu.vujson.parseCSSE); loaded with parseCSSE when None
    target: string in ['confirmed', 'deaths', 'recovered']
    subGroup: A key in the output of covidvu.pipeline.vujson.parseCSSE (only used when confirmedCases is None)
    nSamples: Number of samples per chain of MCMC
    nChains: Number of independent chains MCMC
    nDaysPredict: Number of days ahead to predict
    minCasesFilter: Only days with strictly more than this many cases are used for training
    minNumberDaysWithCases: Minimum number of days with more than minCasesFilter cases
    predictionsPercentiles: Bayesian confidence intervals to evaluate
    randomSeed: Seed for stan sampler
    kwargs: Optional named arguments. 'siteData', 'jhCSSEFileConfirmed', 'jhCSSEFileDeaths',
        'jhCSSEFileConfirmedDeprecated', 'jhCSSEFileDeathsDeprecated' and 'jsCSSEReportPath' are
        passed to covidvu.pipeline.vujson.parseCSSE; 'maxTreedepth' sets the sampler's max_treedepth
        (default MAX_TREEDEPTH). Other keys are ignored.

    Returns
    -------
    None if the region has fewer than minNumberDaysWithCases usable days; otherwise a dict with:
        regionTS: All data for the queried region (datetime index)
        predictionsMeanTS: Posterior mean prediction
        predictionsPercentilesTS: Posterior percentiles
        trace: Posterior samples as a DataFrame (pystan fit.to_dataframe())
        regionTSClean: Data used for training (datetime index)
        regionName: Name of the trained region
        t: Integer day offsets (0..n-1) of the training data
    """
    maxTreeDepth = kwargs.get('maxTreedepth', MAX_TREEDEPTH)

    if confirmedCases is None:
        confirmedCases = parseCSSE(
            target,
            siteData=kwargs.get('siteData', SITE_DATA),
            jhCSSEFileConfirmed=kwargs.get('jhCSSEFileConfirmed',
                                           JH_CSSE_FILE_CONFIRMED),
            jhCSSEFileDeaths=kwargs.get('jhCSSEFileDeaths',
                                        JH_CSSE_FILE_DEATHS),
            jhCSSEFileConfirmedDeprecated=kwargs.get(
                'jhCSSEFileConfirmedDeprecated',
                JH_CSSE_FILE_CONFIRMED_DEPRECATED),
            jhCSSEFileDeathsDeprecated=kwargs.get(
                'jhCSSEFileDeathsDeprecated', JH_CSSE_FILE_DEATHS_DEPRECATED),
            jsCSSEReportPath=kwargs.get('jsCSSEReportPath',
                                        JH_CSSE_REPORT_PATH),
        )[subGroup]

    if regionName is None:
        regionName = _getCountryToTrain(int(regionTrainIndex), confirmedCases)
    else:
        assert isinstance(regionName, str)

    # Copy so re-indexing below does not mutate the caller's column object
    # (confirmedCases may be shared across many calls, e.g. by predictRegions).
    regionTS = confirmedCases[regionName].copy()
    regionTSClean = regionTS[regionTS > minCasesFilter]

    if regionTSClean.shape[0] < minNumberDaysWithCases:
        return None

    regionTSClean.index = pd.to_datetime(regionTSClean.index)

    t = np.arange(regionTSClean.shape[0])
    # +1 keeps log finite for zero counts.
    regionTSCleanLog = np.log(regionTSClean.values + 1)

    logisticGrowthData = {
        'nDays': regionTSClean.shape[0],
        't': list(t),
        'casesLog': list(regionTSCleanLog),
    }

    fit = logGrowthModel.sampling(data=logisticGrowthData,
                                  iter=nSamples,
                                  chains=nChains,
                                  seed=randomSeed,
                                  control={'max_treedepth': maxTreeDepth})

    trace = fit.to_dataframe()

    predictionsMean, predictionsPercentilesRaw = _getPredictionsFromPosteriorSamples(
        t,
        trace,
        nDaysPredict,
        predictionsPercentiles,
    )

    predictionsMeanTS, predictionsPercentilesTS = _castPredictionsAsTS(
        regionTSClean,
        nDaysPredict,
        predictionsMean,
        predictionsPercentilesRaw,
    )

    regionTS.index = pd.to_datetime(regionTS.index)

    return {
        'regionTS': regionTS,
        'predictionsMeanTS': predictionsMeanTS,
        'predictionsPercentilesTS': predictionsPercentilesTS,
        'trace': trace,
        'regionTSClean': regionTSClean,
        'regionName': regionName,
        't': t,
    }