예제 #1
0
def PlotTimeShift(sysargv):
    """
    Fit an EDO model for a range of time shifts and study how the fitted
    parameters evolve with the shift.

    For every shift in ``range(shift_mini, shift_maxi)`` the selected fitting
    routine (``fitProcessSEIR1R2`` or ``fitProcessSEIR1R2D``) is run once on all
    the requested places. Then, for each place:
      * (plot only) the parameters of every period are plotted versus the shift;
      * (plot only) every parameter is plotted across periods versus the shift,
        and the values are dumped to ``Plot_TS.txt``;
      * one line is appended to ``R0Moyen_<min>_<max>.csv`` with the median R0
        of each period over the shifts, the date of the first case and the
        number of infected at the end of each period;
      * (plot only) the fitting MSE is plotted versus the shift.
    When plotting is enabled and several places are processed, the distribution
    of every parameter over the places is plotted as well.

    :Example:

    For countries (European database)
    >> python PlotTimeShift.py
    >> python PlotTimeShift.py France      0 SEIR1R2  0 18,19 1 1
    >> python PlotTimeShift.py Italy,Spain 1 SEIR1R2D 1 18,19 0 1 # Italy and Spain, with UKF filtering

    For French areas (French database)
    >> python PlotTimeShift.py FRANCE,D69         0 SEIR1R2  0 18,19 0 1 # INSEE dpt 69 (Rhône)
    >> python PlotTimeShift.py FRANCE,R84         0 SEIR1R2  0 18,19 0 1 # Every dpt of region 84, processed one by one
    >> python PlotTimeShift.py FRANCE,R32+        0 SEIR1R2  0 18,19 0 1 # Sum of all the dpts of region 32 (Hauts-de-France)
    >> python PlotTimeShift.py FRANCE,MetropoleD  0 SEIR1R2  0 18,19 0 1 # Every dpt of metropolitan France
    >> python PlotTimeShift.py FRANCE,MetropoleD+ 0 SEIR1R2  0 18,19 0 1 # Metropolitan France (sum of its dpts)
    >> python PlotTimeShift.py FRANCE,MetropoleR+ 0 SEIR1R2  0 18,19 0 1 # Every French region (sum of its dpts)
    Any combination of places is possible, e.g. FRANCE,R32+,D05,R84

    argv[1] : List of countries (ex. France,Germany,Italy), or see above.  Default: France
    argv[2] : Sex (male:1, female:2, male+female:0). Only for french database     Default: 0
    argv[3] : EDO model (SEIR1R2 or SEIR1R2D).                             Default: SEIR1R2
    argv[4] : UKF filtering of data (0/1).                                 Default: 0
    argv[5] : min shift, max shift (max excluded), ex. 2,10.               Default: 18,19
    argv[6] : Verbose level (debug: 3, ..., almost mute: 0).               Default: 1
    argv[7] : Plot graphique (0/1).                                        Default: 1
    argv[8] : stopDate.                                                    Default: None

    :param sysargv: command-line style list; ``sysargv[0]`` is the program name.
    :returns: the list of CSV lines (without trailing newline) written to the
        R0 summary file, one per processed place.
    :raises SystemExit: (code 1) on too many arguments, an unknown model, or an
        empty shift range.
    """

    # Country lists used in previous runs (the second one, without Lithuania, has 18 countries):
    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine

    # Interpretation of arguments
    ######################################################@
    if len(sysargv) > 9:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Default value for parameters
    places                 = 'France'
    liste                  = places.split(',')  # used when no place is given on the command line
    sexe                   = 0
    modeleString           = 'SEIR1R2'
    UKF_filt               = False
    shift_mini, shift_maxi = 18, 19
    verbose                = 1
    plot                   = True
    readStopDateStr        = "None"  # forwarded as a string to the fit

    # Parameters from argv
    if len(sysargv) > 1:
        places, liste = sysargv[1], sysargv[1].split(',')
    if len(sysargv) > 2:
        sexe = int(sysargv[2])
    if len(sysargv) > 3:
        modeleString = sysargv[3]
    if len(sysargv) > 4 and int(sysargv[4]) == 1:
        UKF_filt = True
    if len(sysargv) > 5:
        shift_mini, shift_maxi = map(int, sysargv[5].split(','))
    if len(sysargv) > 6:
        verbose = int(sysargv[6])
    if len(sysargv) > 7 and int(sysargv[7]) == 0:
        plot = False
    if len(sysargv) > 8:
        readStopDateStr = sysargv[8]

    # Validate everything before touching global state (matplotlib rc, files).
    if modeleString not in ('SEIR1R2', 'SEIR1R2D'):
        print('Wrong EDO model, only SEIR1R2 or SEIR1R2D available!')
        exit(1)
    if shift_maxi <= shift_mini:
        # An empty range would run no fit at all and crash later when indexing the results.
        print('  CAUTION : max shift must be greater than min shift - see help')
        exit(1)
    nbshifts = shift_maxi - shift_mini
    if nbshifts == 1:
        plot = False  # No plot possible with a single data point
    if sexe not in (0, 1, 2):
        sexe = 0  # undifferentiated sex
    sexestr = {0: 'male+female', 1: 'male', 2: 'female'}[sexe]

    # Font sizes of the figures
    SMALL_SIZE  = 16
    BIGGER_SIZE = 24
    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
    plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title

    # List of places to process
    listplaces = []
    listnames  = []
    if liste[0] == 'FRANCE':
        FrDatabase = True
        for el in liste[1:]:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                # l and n are walked in parallel: each l1 is spliced in, each n1 kept whole
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.append(n1)
            else:
                listplaces.extend(l)
                listnames.extend(n)
        places = 'FRANCE,' + ','.join(name[0] for name in listnames)
    else:
        FrDatabase = False
        listplaces = liste[:]

    # The model to process (SEIR1R2 or SEIR1R2D); R0 is always the last parameter
    if modeleString == 'SEIR1R2':
        fit           = fitProcessSEIR1R2
        modeleString2 = 'SEIR\N{SUPERSCRIPT ONE}R\N{SUPERSCRIPT TWO}'
        labelsparam   = [r'$a$', r'$b$', r'$c$', r'$f$', r'$R_0$']
    else:  # 'SEIR1R2D', validated above
        fit           = fitProcessSEIR1R2D
        modeleString2 = 'SEIR\N{SUPERSCRIPT ONE}R\N{SUPERSCRIPT TWO}D'
        labelsparam   = [r'$a$', r'$b$', r'$c$', r'$f$', r'$\mu$', r'$\xi$', r'$R_0$']

    if verbose > 0:
        print('  Full command line : '+sysargv[0]+' '+places+' '+str(sexe)+' '+modeleString+' '+str(UKF_filt)+' '+str(shift_mini)+','+str(shift_maxi)+' '+str(verbose)+' '+str(plot), flush=True)

    # Multi-period fit (3 periods per the original comment) for every time shift
    ##################################
    nbperiodes = -1  # period argument of the fit; -1 selects the multi-period mode

    TAB_decalage    = []
    TAB_param_model = []  # TAB_param_model[shift index][place][period][param]
    TAB_IEnd        = []
    TAB_ListeEQM    = []
    TAB_ListeDateI0 = []
    for decalage in range(shift_mini, shift_maxi):
        if verbose > 0:
            print('TIME-SHIFT', str(decalage), 'OVER', str(shift_maxi))

        _, _, _, _, _, tabParamModel, tabIEnd, ListeEQM, ListeDateI0 = fit(
            [places, sexe, nbperiodes, decalage, UKF_filt, 0, 0, readStopDateStr])

        TAB_decalage.append(float(decalage))
        TAB_param_model.append(tabParamModel)
        TAB_IEnd.append(tabIEnd)
        TAB_ListeEQM.append(ListeEQM)
        TAB_ListeDateI0.append(ListeDateI0)

    # The median R0 of every period is saved in a CSV file (used to draw a map)
    rep = getRepertoire(UKF_filt, './figures/'+modeleString+'_UKFilt/TimeShift/', './figures/'+modeleString+'/TimeShift/')
    fileR0moyen = rep + '/R0Moyen_' + str(shift_mini) + '_' + str(shift_maxi) + '.csv'
    with open(fileR0moyen, 'w') as text_file:  # 'w' discards the output of a previous run
        text_file.write('Place,R0MoyenP0,R0MoyenP1,R0MoyenP2,R0MoyenP3,DateFirstCase,IEndP0,IEndP1,IEndP2,IEndP3\n')

    ListeChaines = []
    for indexplace, place in enumerate(listplaces):

        # Full name of the place to process
        if FrDatabase:
            if 'MetropoleD+' in listnames[indexplace][0]:
                placefull = 'France'
            else:
                placefull = 'France-' + listnames[indexplace][0]
        else:
            placefull = place

        # Figures directory
        if plot:
            ch1 = placefull + '/sexe_' + str(sexe) + '_delay_' + str(shift_mini) + '_' + str(shift_maxi)
            repertoire = getRepertoire(UKF_filt, './figures/'+modeleString+'_UKFilt/'+ch1, './figures/'+modeleString+'/'+ch1)
            prefFig = repertoire + '/'

        # Common beginning of every figure title of this place
        if sexe == 0:
            titlePrefix = placefull + ' - ' + modeleString2
        else:
            titlePrefix = placefull + ' - Sex=' + sexestr + ', ' + modeleString2

        nbperiodes = len(TAB_param_model[0][indexplace])
        labelsperiod = ['Period ' + str(p) for p in range(nbperiodes)]

        # Plot of the parameters of each period, versus the time shift
        ##########################################@
        Y1 = np.zeros(shape=(nbshifts, len(labelsparam)))
        for period in range(nbperiodes):
            for ishift in range(nbshifts):
                try:
                    Y1[ishift, :] = TAB_param_model[ishift][indexplace][period][:]
                except IndexError:
                    Y1[ishift, :] = 0.  # this period does not exist for this shift

            if plot:
                titre = titlePrefix + ' parameters evolution for ' + labelsperiod[period]
                texte = [s.replace('$', '').replace('\\', '').replace('_', '') for s in labelsparam[:-1]]
                filename = prefFig + 'Plot_TS_Period' + str(period) + '_' + ''.join(texte) + '.png'
                plotData(TAB_decalage, Y1[:, :-1], titre, filename, labelsparam[:-1])
                filename = prefFig + 'Plot_TS_Period' + str(period) + '_R0.png'
                plotData(TAB_decalage, Y1[:, -1].reshape(nbshifts, 1), titre, filename, [labelsparam[-1]])

        # Plot of each parameter for all the periods, versus the time shift
        ##########################################
        if plot and os.path.exists(prefFig + 'Plot_TS.txt'):
            os.remove(prefFig + 'Plot_TS.txt')

        Y2 = np.zeros(shape=(nbshifts, nbperiodes))
        for param, labelparam in enumerate(labelsparam):
            for ishift in range(nbshifts):
                for period in range(nbperiodes):
                    try:
                        Y2[ishift, period] = np.round(TAB_param_model[ishift][indexplace][period][param], 3)
                    except IndexError:
                        Y2[ishift, period] = 0.

            if plot:
                titre = titlePrefix + ' periods evolution for param ' + labelparam
                filename = prefFig + 'Plot_TS_Param' + labelparam.replace('$', '') + '.png'
                plotData(TAB_decalage, Y2, titre, filename, labelsperiod)

                # Write parameters in a file
                with open(prefFig + 'Plot_TS.txt', 'a') as text_file:
                    text_file.write('\n\nParam: %s' % labelparam.replace('$', '').replace('\\', ''))
                    for period in range(nbperiodes):
                        np.savetxt(text_file, Y2[:, period], delimiter=', ', newline=', ', fmt='%.4f', header='\n  -->' + labelsperiod[period] + ': ')

        # Summary line of the R0 file. After the loop above, Y2 holds the
        # values of the last parameter, i.e. R0.
        ##########################################
        Lieu = placefull
        if placefull[0] in ('D', 'R'):
            Lieu = Lieu[1:]
        if placefull[-1] == '+':
            Lieu = Lieu[:-1]
        if Lieu[-1] == 'D':
            Lieu = Lieu[:-1]
        chaine = Lieu + ','

        # Median R0 (lower median for an even count) of each period over the shifts
        Indice = None
        for period in range(nbperiodes):
            if -1. in Y2[:, period]:
                R0Est = -1.  # at least one shift returned the -1 marker
            else:
                R0Est = sorted(Y2[:, period])[(nbshifts - 1) // 2]
                if period == 0:
                    # first shift whose period-0 R0 equals the median
                    Indice = np.where(Y2[:, period] == R0Est)[0][0]
            chaine += str(R0Est) + ','
        if nbperiodes == 3:
            chaine += ','  # empty column for the missing 4th period

        # Date of the first infected, taken at the median shift; no date when
        # period 0 contains the -1 marker.
        if -1. in Y2[:, 0]:
            chaine += 'Invalid'
        else:
            chaine += TAB_ListeDateI0[Indice][indexplace]

        # Infected at the end of each period.
        # NOTE(review): these are read at the *last* time shift (the original
        # code used a loop variable leaked from an inner loop), not at the
        # median shift Indice used for the date above — confirm the intent.
        lastShift = nbshifts - 1
        for period in range(nbperiodes):
            chaine += ',' + str(TAB_IEnd[lastShift][indexplace][period])
        if nbperiodes == 3:
            chaine += ','

        with open(fileR0moyen, 'a') as text_file:
            text_file.write(chaine + '\n')
        if verbose > 0:
            print('chaine=', chaine)
        ListeChaines.append(chaine)

        # Plot of the fitting MSE versus the time shift
        ##########################################
        if plot:
            fig = plt.figure(facecolor='w', figsize=figsize)
            ax  = fig.add_subplot(111, facecolor='#dddddd', axisbelow=True)

            Y3 = np.zeros(shape=(nbshifts,))
            for k in range(nbshifts):
                try:
                    Y3[k] = TAB_ListeEQM[k][indexplace]
                except IndexError:
                    Y3[k] = 0.
            ax.plot(TAB_decalage, Y3, alpha=1.0, lw=2, label='MSE for R\N{SUPERSCRIPT ONE}')

            ax.set_xlabel('Delay (delta) in days')
            ax.yaxis.set_tick_params(length=0)
            ax.xaxis.set_tick_params(length=0)
            # 'visible' is passed positionally: the old 'b=' keyword was
            # deprecated in Matplotlib 3.5 and later removed.
            ax.grid(True, which='major', c='w', lw=1, ls='-')
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))

            legend = ax.legend()
            legend.get_frame().set_alpha(0.5)
            for spine in ('top', 'right', 'bottom', 'left'):
                ax.spines[spine].set_visible(False)

            plt.xlim([TAB_decalage[0], TAB_decalage[-1]])

            plt.tight_layout(rect=(0, 0.03, 1., 0.95))
            plt.title(titlePrefix)
            plt.savefig(prefFig + 'Plot_TS_EQM_cumul.png', dpi=dpi)
            plt.close()

    # Plot of the distribution of the parameters over the places
    # (nbperiodes and labelsperiod are those of the last processed place)
    ##########################################
    if plot and len(listplaces) > 1:
        prefDistrib = rep + '/Distrib_'

        for i, param in enumerate(labelsparam):
            xlim = i != len(labelsparam) - 1  # disabled for the last parameter (R0)
            for ishift in range(nbshifts):
                Y3 = np.zeros(shape=(nbperiodes, len(listplaces)))
                for period in range(nbperiodes):
                    for iplace in range(len(listplaces)):
                        try:
                            Y3[period, iplace] = TAB_param_model[ishift][iplace][period][i]
                        except IndexError:
                            Y3[period, iplace] = 0.  # same fallback as the other loops

                titre = 'Distribution of parameter ' + param + ' for France departements'
                filename = prefDistrib + str(ishift + shift_mini) + '_param' + param
                PlotDistribParam(Y3, titre, filename, labelsperiod, xlim)

    return ListeChaines
예제 #2
0
def main(sysargv):
    """
    Plot the raw and UKF-filtered Covid data of one or several places.

    For every requested place the data are read (French or European database)
    and their day-to-day differences are computed. A series of PNG figures is
    written to ``./figures/data/<place>/sexe_<sex>/``:
      * the three raw series (HeadData[0..2]) and their differences;
      * the series smoothed by an unscented Kalman filter: HeadData[2]
        ("R1+D"), HeadData[1] ("F"), and HeadData[0] and HeadData[1] jointly
        ("R1 and F"). These names come from the original comments.

    :Example:

    For country all around the world (European database)
    >> python PlotDataCovid.py
    >> python PlotDataCovid.py United_Kingdom
    >> python PlotDataCovid.py Italy 2 1        # Only Italian women
    >> python PlotDataCovid.py France,Germany 1 # Shortcut for processing the two countries successively
    >> python PlotDataCovid.py France,Spain,Italy,United_Kingdom,Germany,Belgium 0

    For French geographical areas (French database)
    >> python PlotDataCovid.py FRANCE,D69         # Department 69 (Rhône), INSEE numbering
    >> python PlotDataCovid.py FRANCE,R84         # Process successively all the dpts of region #84 ('Auvergne-Rhone-Alpes')
    >> python PlotDataCovid.py FRANCE,R32+        # Sum of all dpts of Région 32 ('Hauts de France')
    >> python PlotDataCovid.py FRANCE,MetropoleD  # All the dpts of the metropolitan France
    >> python PlotDataCovid.py FRANCE,MetropoleD+ # France (by summing dpts)
    >> python PlotDataCovid.py FRANCE,MetropoleR+ # All the regions by summing their dpts
    Every combination is possible, e.g.: FRANCE,R32+,D05,R84

    argv[1] : List of countries (ex. France,Germany,Italy), or see above for France. Default: France
    argv[2] : Sex (male:1, female:2, male+female:0). Only for french database        Default: 0
    argv[3] : Verbose level (debug: 3, ..., almost mute: 0).                         Default: 1

    :param sysargv: command-line style list; ``sysargv[0]`` is the program name.
    :raises SystemExit: (code 1) when more than 3 arguments are given.
    """

    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine

    print('Command line : ', sysargv, flush=True)
    if len(sysargv) > 4:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Constants (identical for every place)
    ######################################################@
    readStartDateStr = "2020-03-01"
    readStopDateStr  = None  # handed to the readers unchanged; presumably "no upper bound"
    France = 'France'

    # Interpretation of arguments
    ######################################################@
    liste   = ['France']  # default place list
    sexe    = 0
    verbose = 1
    if len(sysargv) > 1:
        liste = sysargv[1].split(',')
    if len(sysargv) > 2:
        sexe = int(sysargv[2])
    if len(sysargv) > 3:
        verbose = int(sysargv[3])
    if sexe not in (0, 1, 2):
        sexe = 0  # undifferentiated sex
    sexestr = {0: 'male+female', 1: 'male', 2: 'female'}[sexe]

    # List of places to process
    listplaces = []
    listnames  = []
    if liste[0] == 'FRANCE':
        FrDatabase = True
        for el in liste[1:]:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                # l and n are walked in parallel: each l1 is spliced in, each n1 kept whole
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.append(n1)
            else:
                listplaces.extend(l)
                listnames.extend(n)
    else:
        FrDatabase = False
        listplaces = liste[:]

    # Loop for all places
    ############################################################@
    for indexplace, place in enumerate(listplaces):

        # Full name of the place to process, and the special dates for that place
        if FrDatabase:
            placefull = 'France-' + listnames[indexplace][0]
            DatesString = readDates(France, verbose)
        else:
            placefull = place
            DatesString = readDates(place, verbose)

        # Figures repository
        repertoire = getRepertoire(True, './figures/data/' + placefull + '/sexe_' + str(sexe))
        prefFig = repertoire + '/'

        # Data reading. The readers return the effective start/stop date
        # strings; they are kept in loop-local names so that they do not
        # replace the constant request dates used for the next place.
        ##########################################################
        if FrDatabase:
            pd_exerpt, HeadData, _, startDateStr, stopDateStr, dateFirstNonZeroStr = readDataFrance(
                place, readStartDateStr, readStopDateStr, fileLocalCopy, sexe, verbose=0)
        else:
            pd_exerpt, HeadData, _, startDateStr, stopDateStr, dateFirstNonZeroStr = readDataEurope(
                place, readStartDateStr, readStopDateStr, fileLocalCopy, verbose=0)

        startDate = datetime.strptime(startDateStr, "%Y-%m-%d")
        stopDate  = datetime.strptime(stopDateStr, "%Y-%m-%d")
        if verbose > 0:
            print('readStartDateStr=', startDateStr, ', readStopDateStr=', stopDateStr)
            print('readStartDate   =', startDate, ', readStopDate   =', stopDate)
            print('dateFirstNonZeroStr=', dateFirstNonZeroStr)

        H0, H1, H2 = HeadData[0], HeadData[1], HeadData[2]

        # Adding the gradient (day-to-day differences)
        for col in (H0, H1, H2):
            pd_exerpt['Diff ' + col] = pd_exerpt[col].diff()

        if sexe == 0:
            titre = placefull
        else:
            titre = placefull + ' - Sex=' + sexestr

        # Raw data plots: (file name suffix, column(s), color(s))
        rawPlots = [
            (H0 + '.png',                    H0,                                         'red'),
            (H1 + '.png',                    H1,                                         'black'),
            (H2 + '.png',                    H2,                                         'black'),
            (H0 + H1 + '.png',               [H0, H1],                                   ['red', 'black']),
            ('Diff' + H0 + '.png',           ['Diff ' + H0],                             'red'),
            ('Diff' + H1 + '.png',           ['Diff ' + H1],                             'black'),
            ('Diff' + H0 + H1 + '.png',      ['Diff ' + H0, 'Diff ' + H1],               ['red', 'black']),
            ('Diff' + H0 + H1 + H2 + '.png', ['Diff ' + H0, 'Diff ' + H1, 'Diff ' + H2], ['red', 'black', 'blue']),
        ]
        for suffix, y, color in rawPlots:
            PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + suffix, y=y, color=color, Dates=DatesString)

        # Data filtering and plot
        ######################################################

        # R1+D filtering by UKF
        pd_exerpt[H2 + ' filt'] = _ukf_batch_filter(
            pd_exerpt[H2].tolist(), fR1, hR1, q=[30.], r=[170.], kappa=0., verbose=verbose)
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'filt' + H2 + '.png',
                 y=[H2, H2 + ' filt'], color=['red', 'darkred'], Dates=DatesString)
        pd_exerpt['Diff ' + H2 + ' filt'] = pd_exerpt[H2 + ' filt'].diff()
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'diff_filt' + H2 + '.png',
                 y=['Diff ' + H2, 'Diff ' + H2 + ' filt'], color=['red', 'darkred'], Dates=DatesString)
        # (Filtering the differences directly was tried: it gives the same plot as above.)

        # F filtering by UKF
        #############################################################################
        pd_exerpt[H1 + ' filt'] = _ukf_batch_filter(
            pd_exerpt[H1].tolist(), fF, hF, q=[15.], r=[100.], kappa=0., verbose=verbose)
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'filt' + H1 + '.png',
                 y=[H1, H1 + ' filt'], color=['gray', 'black'], Dates=DatesString)
        pd_exerpt['Diff ' + H1 + ' filt'] = pd_exerpt[H1 + ' filt'].diff()
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'diff_filt' + H1 + '.png',
                 y=['Diff ' + H1, 'Diff ' + H1 + ' filt'], color=['gray', 'black'], Dates=DatesString)

        # R1 and F simultaneous filtering by UKF
        #############################################################################
        R1Ffilt = _ukf_batch_filter(
            pd_exerpt[[H0, H1]].to_numpy(copy=True), fR1F, hR1F,
            q=[30., 15.], r=[170., 100.], kappa=1., verbose=verbose)
        pd_exerpt[H0 + ' filtboth'] = R1Ffilt[:, 0]
        pd_exerpt[H1 + ' filtboth'] = R1Ffilt[:, 1]
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'filtboth' + H0 + H1 + '.png',
                 y=[H0, H0 + ' filtboth', H1, H1 + ' filtboth'],
                 color=['red', 'darkred', 'gray', 'black'], Dates=DatesString)

        pd_exerpt['Diff ' + H0 + ' filtboth'] = pd_exerpt[H0 + ' filtboth'].diff()
        pd_exerpt['Diff ' + H1 + ' filtboth'] = pd_exerpt[H1 + ' filtboth'].diff()
        PlotData(pd_exerpt, titre=titre, filenameFig=prefFig + 'diff_filt' + H0 + H1 + '.png',
                 y=['Diff ' + H0, 'Diff ' + H0 + ' filtboth', 'Diff ' + H1, 'Diff ' + H1 + ' filtboth'],
                 color=['red', 'darkred', 'gray', 'black'], Dates=DatesString)


def _ukf_batch_filter(data, fx, hx, q, r, kappa, verbose):
    """Smooth *data* with an unscented Kalman filter in batch mode; return the filtered states.

    The state and measurement dimension is ``len(q)``. The filter uses Merwe
    scaled sigma points (alpha=.5, beta=2., the given *kappa*) and a time step
    of 1. It is initialised on the first sample, with the process and
    measurement noise covariances ``diag(q)`` and ``diag(r)``.

    :param data: a 1-D sequence when ``len(q) == 1``, otherwise a 2-D array
        with one row per time step.
    :returns: the means returned by ``ukf.batch_filter`` (the covariances are
        dropped).
    """
    dim = len(q)
    sigmas = MerweScaledSigmaPoints(n=dim, alpha=.5, beta=2., kappa=kappa)
    ukf = UKF(dim_x=dim, dim_z=dim, fx=fx, hx=hx, dt=1, points=sigmas)
    # Filter init
    if dim == 1:
        ukf.x[0] = data[0]
    else:
        ukf.x[:] = data[0, :]
    ukf.Q = np.diag(q)
    ukf.R = np.diag(r)
    if verbose > 1:
        if dim == 1:
            print('ukf.x[0]=', ukf.x[0])
        else:
            print('ukf.x[:]=', ukf.x[:])
        print('ukf.R   =', ukf.R)
        print('ukf.Q   =', ukf.Q)

    # UKF filtering and smoothing, batch mode
    filtered, _ = ukf.batch_filter(data)
    return filtered
예제 #3
0
def _buildPeriodTitle(placefull, i, nbPeriods, fitStartDateStr,
                      fitStopDateStr, sexe, sexestr, decalage):
    """Build the figure title for period ``i`` of the fit of one place.

    The title contains the place name, the period counter ``i`` out of
    ``nbPeriods - 1``, the fitted date range (shown up to
    ``fitStopDateStr`` minus one day), the sex (only when ``sexe != 0``)
    and the delay ``decalage``.
    """
    commontitre = placefull + '- Period ' + str(i) + '\\' + str(
        nbPeriods - 1) + ' - [' + fitStartDateStr + '\u2192' + addDaystoStrDate(
            fitStopDateStr, -1)
    if sexe == 0:
        return commontitre + '] (Delay (delta)=' + str(decalage) + ')'
    return commontitre + '] (Sex=' + sexestr + ', Delay (delta)=' + str(
        decalage) + ')'


def fit(sysargv):
    """
    Fit a (possibly piecewise) SEIR1R2D model to Covid data, place by place.

    ``sysargv`` holds the command-line arguments WITHOUT the program name:
    ``argv[k]`` in the list below is ``sysargv[k-1]``.

    :Example:

    For countries (European database)
    >> python ProcessSEIR1R2D.py
    >> python ProcessSEIR1R2D.py France 0 1 0 0 1 1
    >> python ProcessSEIR1R2D.py France 2 3 8 0 1 1          # 3 periods for women in France with an 8-day delay
    >> python ProcessSEIR1R2D.py France,Germany 1 1 0 0 1 1  # 1 period for French men and German men

    For French regions (French database)
    >> python ProcessSEIR1R2D.py FRANCE,D69         0 -1 13 0 1 1 # Insee code of department 69 (Rhône)
    >> python ProcessSEIR1R2D.py FRANCE,R84         0 -1 13 0 1 1 # Every department of the region with Insee code 84
    >> python ProcessSEIR1R2D.py FRANCE,R32+        0 -1 13 0 1 1 # Sum of all departments of region 32 (Hauts-de-France)
    >> python ProcessSEIR1R2D.py FRANCE,MetropoleD  0 -1 13 0 1 1 # Every department of metropolitan France
    >> python ProcessSEIR1R2D.py FRANCE,MetropoleD+ 0 -1 13 0 1 1 # Whole metropolitan France (summing the departments)
    >> python ProcessSEIR1R2D.py FRANCE,MetropoleR+ 0 -1 13 0 1 1 # Sum of the departments of every French region
    Any combination of places is possible, e.g. FRANCE,R32+,D05,R84

    argv[1] : List of countries (ex. France,Germany,Italy), or see above.          Default: France
    argv[2] : Sex (male:1, female:2, male+female:0). Only for french database      Default: 0
    argv[3] : Periods ('1' -> 1 period ('all-in-one'), '!=1' -> several periods).  Default: -1
    argv[4] : Delay (in days).                                                     Default: 13
    argv[5] : UKF filtering of data (0/1).                                         Default: 0
    argv[6] : Verbose level (debug: 3, ..., almost mute: 0).                       Default: 1
    argv[7] : Plot graphique (0/1).                                                Default: 1

    Exits with status 1 when more than 7 arguments are given.

    Returns a tuple:
        modelSEIR1R2D      : array (nbPlaces, dataLength, 6) of the solver
                             states over the fitted periods
        ListeTextParam     : per place, the list of per-period parameter texts
        Listepd            : per place, the pandas DataFrame of the data read
        data_deriv         : array (nbPlaces, dataLength, 2), numerical
                             derivative of the observed data
        modelR1_deriv      : array (nbPlaces, dataLength, 2), numerical
                             derivative of the model on the observed columns
        ListetabParamModel : per place, per period [a, b, c, f, mu, xi, R0]
                             (R0 = -1 when not significant, all -1 when the
                             period is degenerate)
        ListeEQM           : per place, mean squared error between the data
                             and the model (on the data, not on derivatives)
        ListeDateI0        : per place, the estimated date of the first
                             infected (None if period 0 was skipped)
    """

    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    # There are 18 countries

    if len(sysargv) > 7:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Constants
    ######################################################@
    # Passed to readData*: presumably True = read a local copy of the data
    # file, False = download the latest data from the url (confirm there).
    fileLocalCopy = True
    # e.g. March 8th could be used instead, to include many European countries
    # whose first-case date came after March 1st.
    readStartDateStr = "2020-03-01"
    readStopDateStr = None
    recouvrement = -1  # overlap between consecutive periods (see GetPairListDates)
    dt = 1
    France = 'France'
    # Below this per-capita mean daily growth of the first data column over a
    # period, the R0 of that period is flagged as not significant.
    thresholdSignif = 1.5E-6

    # Interpretation of arguments
    ######################################################@

    # Default value for parameters
    liste = ['France']
    sexe, sexestr = 0, 'male+female'
    nbperiodes = -1
    decalage = 13
    UKF_filt = False
    verbose = 1
    plot = True

    # Parameters from argv
    if len(sysargv) > 0:
        liste = list(sysargv[0].split(','))
    if len(sysargv) > 1:
        sexe = int(sysargv[1])
    if len(sysargv) > 2:
        nbperiodes = int(sysargv[2])
    if len(sysargv) > 3:
        decalage = int(sysargv[3])
    if len(sysargv) > 4 and int(sysargv[4]) == 1:
        UKF_filt = True
    if len(sysargv) > 5:
        verbose = int(sysargv[5])
    if len(sysargv) > 6 and int(sysargv[6]) == 0:
        plot = False
    if nbperiodes == 1:
        decalage = 0  # a single period necessarily has no delay (the overlap is compensated)
    if sexe not in [0, 1, 2]:
        sexe, sexestr = 0, 'male+female'  # undifferentiated sex
    if sexe == 1:
        sexestr = 'male'
    if sexe == 2:
        sexestr = 'female'

    # Expand the requested places ('FRANCE,...' selects the French database)
    listplaces = []
    listnames = []
    if liste[0] == 'FRANCE':
        FrDatabase = True
        for el in liste[1:]:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.append(n1)
            else:
                listplaces.extend(l)
                listnames.extend(n)
    else:
        listplaces = liste[:]
        FrDatabase = False

    if verbose > 0:
        placesArg = sysargv[0] if len(sysargv) > 0 else ','.join(liste)
        print('  Full command line : ' + placesArg + ' ' + str(nbperiodes) +
              ' ' + str(decalage) + ' ' + str(UKF_filt) + ' ' + str(verbose) +
              ' ' + str(plot),
              flush=True)

    # Data reading to get first and last date available in the data set
    ######################################################@
    if FrDatabase:
        pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, _ = readDataFrance(
            ['D69'],
            readStartDateStr,
            readStopDateStr,
            fileLocalCopy,
            sexe,
            verbose=0)
    else:
        pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, _ = readDataEurope(
            France,
            readStartDateStr,
            readStopDateStr,
            fileLocalCopy,
            verbose=0)

    readStartDate = datetime.strptime(readStartDateStr, strDate)
    if readStartDate < pd_exerpt.index[0]:
        readStartDate = pd_exerpt.index[0]
        readStartDateStr = pd_exerpt.index[0].strftime(strDate)
    readStopDate = datetime.strptime(readStopDateStr, strDate)
    # NOTE(review): this extends (rather than clips) the stop date to the last
    # available index; confirm that '<' (and not '>') is intended here.
    if readStopDate < pd_exerpt.index[-1]:
        readStopDate = pd_exerpt.index[-1]
        readStopDateStr = pd_exerpt.index[-1].strftime(strDate)

    dataLength = pd_exerpt.shape[0]
    if verbose > 1:
        print('readStartDateStr=', readStartDateStr, ', readStopDateStr=',
              readStopDateStr)
        print('readStartDate   =', readStartDate, ', readStopDate   =',
              readStopDate)
        print('dataLength      =', dataLength)

    # Collections of data returned by this function
    modelSEIR1R2D = np.zeros(shape=(len(listplaces), dataLength, 6))
    data_deriv = np.zeros(shape=(len(listplaces), dataLength, 2))
    modelR1_deriv = np.zeros(shape=(len(listplaces), dataLength, 2))
    data_all = np.zeros(shape=(len(listplaces), dataLength, 2))
    modelR1_all = np.zeros(shape=(len(listplaces), dataLength, 2))
    Listepd = []
    ListetabParamModel = []
    ListeEQM = []  # one EQM per place (accumulated across all places)

    # observed data (reused for every place, reset below)
    data = np.zeros(shape=(dataLength, 2))

    # Parameters as character strings
    ListeTextParam = []
    ListeDateI0 = []

    # Loop on the places to process
    for indexplace, place in enumerate(listplaces):

        # Get the full name of the place to process, and the special dates corresponding to the place
        if FrDatabase:
            placefull = 'France-' + listnames[indexplace][0]
            DatesString = readDates(France, verbose)
        else:
            placefull = place
            DatesString = readDates(place, verbose)

        print('PROCESSING of', placefull, 'in', listnames)

        # Date of the first infected; stays None if period 0 is skipped
        dateI0 = None

        # data reading of the observations
        #############################################################################
        if FrDatabase:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataFrance(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                sexe,
                verbose=0)
        else:
            pd_exerpt, HeadData, N, readStartDateStr, readStopDateStr, dateFirstNonZeroStr = readDataEurope(
                place,
                readStartDateStr,
                readStopDateStr,
                fileLocalCopy,
                verbose=0)

        shift_0value = getNbDaysBetweenDateFromString(readStartDateStr,
                                                      dateFirstNonZeroStr)

        # UKF Filtering ?
        if UKF_filt:
            data2Filt = pd_exerpt[[HeadData[0],
                                   HeadData[1]]].to_numpy(copy=True)
            sigmas = MerweScaledSigmaPoints(n=2, alpha=.5, beta=2.,
                                            kappa=1.)  #1-3.)
            ukf = UKF(dim_x=2, dim_z=2, fx=fR1D, hx=hR1D, dt=dt, points=sigmas)
            # Filter init
            ukf.x[:] = data2Filt[0, :]
            ukf.Q = np.diag([30., 15.])
            ukf.R = np.diag([170., 100.])
            if verbose > 1:
                print('ukf.x[:]=', ukf.x[:])
                print('ukf.R   =', ukf.R)
                print('ukf.Q   =', ukf.Q)

            # UKF filtering, batch mode; the filtered columns replace the
            # raw ones for the rest of the processing
            R1Ffilt, _ = ukf.batch_filter(data2Filt)
            HeadData[0] += ' filt'
            HeadData[1] += ' filt'
            pd_exerpt[HeadData[0]] = R1Ffilt[:, 0]
            pd_exerpt[HeadData[1]] = R1Ffilt[:, 1]

        # Get the list of dates to process
        ListDates, ListDatesStr = GetPairListDates(readStartDate, readStopDate,
                                                   DatesString, decalage,
                                                   nbperiodes, recouvrement)
        if verbose > 1:
            print('ListDatesStr=', ListDatesStr)

        # ODE solver
        solveur = SolveEDO_SEIR1R2D(N, dt, verbose)
        indexdata = solveur.indexdata
        E0, I0, R10, R20, D0 = 0, 1, 0, 0, 0

        # Figures directory
        if plot:
            repertoire = getRepertoire(
                UKF_filt, './figures/SEIR1R2D_UKFilt/' + placefull + '/sexe_' +
                str(sexe) + '_delay_' + str(decalage), './figures/SEIR1R2D/' +
                placefull + '/sexe_' + str(sexe) + '_delay_' + str(decalage))
            prefFig = repertoire + '/Process_'

        # Reset the data
        data.fill(0.)

        # Loop processing the successive windows (periods)
        ###############################################################

        ListeTextParamPlace = []
        ListetabParamModelPlace = []

        DEGENERATE_CASE = False

        for i in range(len(ListDatesStr)):

            # dates of the current period
            fitStartDate, fitStopDate = ListDates[i]
            fitStartDateStr, fitStopDateStr = ListDatesStr[i]

            # Degenerate case? At least 5 data points are needed to fit
            if getNbDaysBetweenDateFromString(dateFirstNonZeroStr,
                                              fitStopDateStr) < 5:
                DEGENERATE_CASE = True

            if i > 0:
                DatesString.addOtherDates(fitStartDateStr)

            # Copy the observed data of the period into `data`
            # (fitStopDateStr excluded, .loc slicing is end-inclusive)
            dataLengthPeriod = 0
            indMinPeriod = (fitStartDate - readStartDate).days

            for j, z in enumerate(pd_exerpt.loc[
                    fitStartDateStr:addDaystoStrDate(fitStopDateStr, -1),
                    HeadData[0]]):
                data[indMinPeriod + j, 0] = z
                dataLengthPeriod += 1
            for j, z in enumerate(pd_exerpt.loc[
                    fitStartDateStr:addDaystoStrDate(fitStopDateStr, -1),
                    HeadData[1]]):
                data[indMinPeriod + j, 1] = z
            slicedata = slice(indMinPeriod, indMinPeriod + dataLengthPeriod)
            slicedataderiv = slice(slicedata.start + 1, slicedata.stop)
            if verbose > 0:
                print('  dataLength      =', dataLength)
                print('  indMinPeriod    =', indMinPeriod)
                print('  dataLengthPeriod=', dataLengthPeriod)
                print('  fitStartDateStr =', fitStartDateStr)
                print('  fitStopDateStr  =', fitStopDateStr)

            # Set initialisation data for the solver
            ############################################################################

            # initial parameters to optimize
            if i == 0:
                datelegend = fitStartDateStr
                # First approximation: the estimated first-case date of the
                # country is used (even if wrong for regions and departments)
                ts = getNbDaysBetweenDateFromString(
                    DatesString.listFirstCaseDates[0], dateFirstNonZeroStr)
                if ts < 0:
                    # NOTE(review): intended as "skip to the next country", but
                    # this only skips period 0; the following periods of the
                    # same place are still processed (dateI0 stays None).
                    continue
                if nbperiodes != 1:  # several periods
                    a0, b0, c0, f0, mu0, xi0 = 0.60, 0.55, 0.30, 0.50, 0.0005, 0.0001
                    T = 150
                else:  # a single period
                    a0, b0, c0, f0, mu0, xi0 = 0.70, 0.25, 0.05, 0.003, 0.0005, 0.0001
                    T = 350

            if i in (1, 2):
                datelegend = None

                _, a0, b0, c0, f0, mu0, xi0 = solveur.modele.getParam()
                R10 = int(data[indMinPeriod,
                               0])  # reset R1 to the observed value
                # NOTE(review): F0 is never used; setParamInit below receives
                # D0 from the previous period - confirm whether D0 was meant.
                F0 = int(data[indMinPeriod, 1])
                if i == 1:
                    a0 /= 4.  # lockdown drastically reduces a (helps the optimization)
                T = 120
                ts = 0

            time = np.linspace(0, T - 1, T)

            solveur.modele.setParam(N=N,
                                    a=a0,
                                    b=b0,
                                    c=c0,
                                    f=f0,
                                    mu=mu0,
                                    xi=xi0)
            solveur.setParamInit(N=N, E0=E0, I0=I0, R10=R10, R20=R20, D0=D0)

            # Before optimization
            ###############################

            # Solve the ODE before optimization
            sol_ode = solveur.solveEDO(time)
            # initial time shift (ts) with respect to data
            if i == 0:
                ts = solveur.compute_tsfromEQM(data[slicedata, :], T,
                                               indexdata)
            else:
                solveur.TS = ts = 0
            sliceedo = slice(ts, min(ts + dataLengthPeriod, T))
            if verbose > 0:
                print(solveur)
                print('  ts=' + str(ts))

            # plot
            if plot and not DEGENERATE_CASE:
                titre = _buildPeriodTitle(placefull, i, len(ListDatesStr),
                                          fitStartDateStr, fitStopDateStr,
                                          sexe, sexestr, decalage)

                listePlot = indexdata
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Init.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          Period=i))

            # Parameters optimization (ts computed automatically)
            ############################################################################

            solveur.paramOptimization(data[slicedata, :], time)
            _, a1, b1, c1, f1, mu1, xi1 = solveur.modele.getParam()
            R0 = solveur.modele.getR0()
            if verbose > 0:
                print("Solver's state after optimization=", solveur)
                print('  Reproductivité après: ', R0)

            # After optimization
            ###############################

            # Solve the ODE after optimization
            sol_ode = solveur.solveEDO(time)
            # time shift with respect to data
            if i == 0:
                ts = solveur.compute_tsfromEQM(data[slicedata, :], T,
                                               indexdata)
            else:
                solveur.TS = ts = 0
            sliceedo = slice(ts, min(ts + dataLengthPeriod, T))
            sliceedoderiv = slice(sliceedo.start + 1, sliceedo.stop)
            if verbose > 0:
                print(solveur)
                print('  ts=' + str(ts))
            if i == 0:  # remember the date of the first infected
                dateI0 = addDaystoStrDate(fitStartDateStr, -ts + shift_0value)
                if verbose > 2:
                    print('dateI0=', dateI0)
                    input('attente')

            # save the parameters (array and text)
            seuil = (data[slicedata.stop - 1, 0] - data[slicedata.start, 0]
                     ) / getNbDaysBetweenDateFromString(
                         fitStartDateStr, fitStopDateStr) / N
            if DEGENERATE_CASE:
                ROsignificatif = False
                ListetabParamModelPlace.append(
                    [-1., -1., -1., -1., -1., -1., -1.])
            elif seuil < thresholdSignif:
                ROsignificatif = False
                ListetabParamModelPlace.append(
                    [a1, b1, c1, f1, mu1, xi1, -1.])
            else:
                ROsignificatif = True
                ListetabParamModelPlace.append(
                    [a1, b1, c1, f1, mu1, xi1, R0])

            ListeTextParamPlace.append(
                solveur.getTextParamWeak(datelegend, ROsignificatif, Period=i))

            # first-order finite differences of the data and of the model
            data_deriv_period = (
                data[slicedataderiv, :] -
                data[slicedataderiv.start - 1:slicedataderiv.stop - 1, :]) / dt
            modelR1_deriv_period = (
                sol_ode[sliceedoderiv, indexdata] -
                sol_ode[sliceedoderiv.start - 1:sliceedoderiv.stop - 1,
                        indexdata]) / dt
            data_all_period = data[slicedataderiv, :]
            modelR1_all_period = sol_ode[sliceedoderiv, indexdata]

            if plot and not DEGENERATE_CASE:
                titre = _buildPeriodTitle(placefull, i, len(ListDatesStr),
                                          fitStartDateStr, fitStopDateStr,
                                          sexe, sexestr, decalage)

                listePlot = [1, 2, 3, 5]
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Final.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          ROsignificatif,
                                                          Period=i))
                listePlot = indexdata
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Final.png'
                solveur.plotEDO(filename,
                                titre,
                                sliceedo,
                                slicedata,
                                plot=listePlot,
                                data=data,
                                text=solveur.getTextParam(datelegend,
                                                          ROsignificatif,
                                                          Period=i))

                # numerical derivative of the observed columns
                filename = prefFig + str(decalage) + '_Period' + str(
                    i) + '_' + ''.join(map(str, listePlot)) + 'Deriv.png'
                solveur.plotEDO_deriv(filename,
                                      titre,
                                      sliceedoderiv,
                                      slicedataderiv,
                                      data_deriv_period,
                                      indexdata,
                                      text=solveur.getTextParam(datelegend,
                                                                ROsignificatif,
                                                                Period=i))

            # store the derived data
            data_all[indexplace, slicedataderiv, :] = data_all_period
            modelR1_all[indexplace, slicedataderiv, :] = modelR1_all_period
            data_deriv[indexplace, slicedataderiv, :] = data_deriv_period
            modelR1_deriv[indexplace, slicedataderiv, :] = modelR1_deriv_period

            # store the SEIR1R2D states
            modelSEIR1R2D[indexplace,
                          slicedata.start:slicedata.stop, :] = sol_ode[
                              ts:ts + sliceedo.stop - sliceedo.start, :]

            # initial state of the next period = model state at the end of
            # this one (shifted by the overlap)
            _, E0, I0, R10, R20, D0 = map(
                int, sol_ode[ts + dataLengthPeriod + recouvrement, :])

            if verbose > 1:
                input('next step')

        Listepd.append(pd_exerpt)
        ListeDateI0.append(dateI0)

        # EQM computed on the data (not on the derivatives of the data)
        EQM = mean_squared_error(data_all[indexplace, :],
                                 modelR1_all[indexplace, :])
        ListeEQM.append(EQM)

        # update the lists to return
        ListeTextParam.append(ListeTextParamPlace)
        ListetabParamModel.append(ListetabParamModelPlace)

    return modelSEIR1R2D, ListeTextParam, Listepd, data_deriv, modelR1_deriv, ListetabParamModel, ListeEQM, ListeDateI0
예제 #4
0
def main(sysargv):
    """
		:Example:

		For countries (European database)
		>> python Fit.py 
		>> python Fit.py France      0 SEIR1R2 18 0 1 1
		>> python Fit.py Italy,Spain 1 SEIR1R2D 18 1 1 1 # Italy and Spain, with UKF filtering

		For French Region (French database)
		>> python Fit.py FRANCE,D69         0 SEIR1R2  18 0 1 1 # Code Insee Dpt 69 (Rhône)
		>> python Fit.py FRANCE,R84         0 SEIR1R2  18 0 1 1 # Tous les dpts de la Région dont le code Insee est 
		>> python Fit.py FRANCE,R32+        0 SEIR1R2  18 0 1 1 # Somme de tous les dpts de la Région 32 (Hauts de F
		>> python Fit.py FRANCE,MetropoleD  0 SEIR1R2  18 0 1 1 # Tous les départements de la France métropolitaine
		>> python Fit.py FRANCE,MetropoleD+ 0 SEIR1R2  18 0 1 1 # Toute la France métropolitaine (en sommant les dpts)
		>> python Fit.py FRANCE,MetropoleR+ 0 SEIR1R2  18 0 1 1 # Somme des dpts de toutes les régions françaises
		Toute combinaison de lieux est possible : exemple FRANCE,R32+,D05,R84
		
		argv[1] : List of countries (ex. France,Germany,Italy), or see above.  Default: France 
		argv[2] : Sex (male:1, female:2, male+female:0). Only for french database     Default: 0 
		argv[3] : EDO model (SEIR1R2 or SEIR1R2D).                             Default: SEIR2R2         
		argv[4] : Delay (in days).                                             Default: 18
		argv[5] : UKF filtering of data (0/1).                                 Default: 0
		argv[6] : Verbose level (debug: 3, ..., almost mute: 0).               Default: 1
		argv[7] : Plot graphique (0/1).                                        Default: 1
	"""

    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Lithuania,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    #Austria,Belgium,Croatia,Czechia,Finland,France,Germany,Greece,Hungary,Ireland,Italy,Poland,Portugal,Romania,Serbia,Spain,Switzerland,Ukraine
    # Il y a 18 pays

    # Interpetation of arguments - reparation
    ######################################################@

    if len(sysargv) > 8:
        print('  CAUTION : bad number of arguments - see help')
        exit(1)

    # Default value for parameters
    places = 'France'
    sexe, sexestr = 0, 'male+female'
    listplaces = list(places.split(','))
    modeleString = 'SEIR1R2'
    decalage3P = 18
    UKF_filt, UKF_filt01 = False, 0
    verbose = 1
    plot = True
    France = 'France'

    # Parameters from argv
    if len(sysargv) > 1:
        places, liste = sysargv[1], list(sysargv[1].split(','))
    if len(sysargv) > 2: sexe = int(sysargv[2])
    if len(sysargv) > 3: modeleString = sysargv[3]
    if len(sysargv) > 4: decalage3P = int(sysargv[4])
    if len(sysargv) > 5 and int(sysargv[5]) == 1:
        UKF_filt, UKF_filt01 = True, 1
    if len(sysargv) > 6: verbose = int(sysargv[6])
    if len(sysargv) > 7 and int(sysargv[7]) == 0: plot = False
    if sexe not in [0, 1, 2]:
        sexe, sexestr = 0, 'male+female'  # sexe indiférencié
    if sexe == 1: sexestr = 'male'
    if sexe == 2: sexestr = 'female'

    listplaces = []
    listnames = []
    if liste[0] == 'FRANCE':
        FrDatabase = True
        liste = liste[1:]
        for el in liste:
            l, n = getPlace(el)
            if el == 'MetropoleR+':
                for l1, n1 in zip(l, n):
                    listplaces.extend(l1)
                    listnames.extend([n1])
            else:
                listplaces.extend(l)
                listnames.extend(n)
        places = [el[0] for el in listnames]
        places = 'FRANCE,' + ','.join(places)
    else:
        listplaces = liste[:]
        FrDatabase = False

    # le modèle à traiter (SEIR1R2 or SEIR1R2D)
    if modeleString == 'SEIR1R2':
        fit = fitProcessSEIR1R2
    elif modeleString == 'SEIR1R2D':
        fit = fitProcessSEIR1R2D
    else:
        print('Wrong EDO model, only SEIR1R2 or SEIR1R2D available!')
        exit(1)

    if verbose > 0:
        print('  Full command line : ' + sysargv[0] + ' ' + places + ' ' +
              str(sexe) + ' ' + modeleString + ' ' + str(decalage3P) + ' ' +
              str(UKF_filt) + ' ' + str(verbose) + ' ' + str(plot),
              flush=True)

    # fit avec 3 périodes + décalage
    ##################################
    nbperiodes = -1

    model_piecewise, ListeTextParamPlace_piecewise, liste_pd_piecewise, data_deriv_piecewise, model_deriv_piecewise, _, _, _, ListeDateI0 = \
      fit([places, sexe, nbperiodes, decalage3P, UKF_filt01, 0, 0])

    ListeTestPlace = []
    for indexplace in range(len(listplaces)):
        texteplace = ''
        for texte in ListeTextParamPlace_piecewise[indexplace]:
            texteplace += '\n' + texte
        ListeTestPlace.append(texteplace)

    # Plot the multi-period strategy
    ##################################
    for indexplace, place in enumerate(listplaces):

        # Get the full name of the place to process, and the special dates corresponding to the place
        if FrDatabase == True:
            if 'MetropoleD+' in listnames[indexplace][0]:
                placefull = 'France'
            else:
                placefull = 'France-' + listnames[indexplace][0]
        else:
            placefull = place

        # Repertoire des figures
        repertoire = getRepertoire(
            UKF_filt, './figures/' + modeleString + '_UKFilt/' + placefull +
            '/sexe_' + str(sexe) + '_delay_' + str(decalage3P),
            './figures/' + modeleString + '/' + placefull + '/sexe_' +
            str(sexe) + '_delay_' + str(decalage3P))
        prefFig = repertoire + '/Fit_'

        # Preparation plot pandas
        listheader = list(liste_pd_piecewise[indexplace])

        if FrDatabase == True:
            DatesString = readDates(France, verbose)
        else:
            DatesString = readDates(place, verbose)

        #####################################################@
        # DERIVEES
        # on ajoute les dérivées numériques des cas et des morts
        liste_pd_piecewise[indexplace]['dcases'] = liste_pd_piecewise[
            indexplace][listheader[0]].diff()
        liste_pd_piecewise[indexplace]['ddeaths'] = liste_pd_piecewise[
            indexplace][listheader[1]].diff()
        liste_pd_piecewise[indexplace][
            'dcasesplusdeaths'] = liste_pd_piecewise[indexplace][
                listheader[2]].diff()
        longueur = len(liste_pd_piecewise[indexplace].loc[:, ('dcases')])

        # on ajoute les dérivées numériques des cas et des morts
        liste_pd_piecewise[indexplace].loc[:, (
            'mc_piecewise')] = model_deriv_piecewise[indexplace, 0:longueur, 0]
        liste_pd_piecewise[indexplace].loc[:, (
            'mc_piecewise_residual')] = liste_pd_piecewise[indexplace].loc[:, (
                'mc_piecewise')] - liste_pd_piecewise[indexplace].loc[:, (
                    'dcasesplusdeaths')]
        if modeleString == 'SEIR1R2D':
            liste_pd_piecewise[indexplace].loc[:, (
                'md_piecewise')] = model_deriv_piecewise[indexplace,
                                                         0:longueur, 1]
            liste_pd_piecewise[indexplace].loc[:, (
                'md_piecewise_residual'
            )] = liste_pd_piecewise[indexplace].loc[:, (
                'md_piecewise')] - liste_pd_piecewise[indexplace].loc[:, (
                    'ddeaths')]

        # Dessin des dérivées
        filename = prefFig + str(decalage3P) + '_Diff_Piecewise.png'
        if sexe == 0:
            title = placefull + ' - Delay (delta)=' + str(
                decalage3P) + ' day(s)'
        else:
            title = placefull + ' - Sex=' + sexestr + ', Delay (delta)=' + str(
                decalage3P) + ' day(s)'
        if modeleString == 'SEIR1R2':
            listPlots = ['dcasesplusdeaths', 'mc_piecewise']
        if modeleString == 'SEIR1R2D':
            listPlots = ['dcases', 'mc_piecewise', 'ddeaths', 'md_piecewise']
        PlotFitPiecewise(modeleString,
                         liste_pd_piecewise[indexplace],
                         title,
                         filename,
                         y=listPlots,
                         Dates=DatesString,
                         textannotation=ListeTestPlace[indexplace])

        # Dessin des résidus des dérivées
        filename = prefFig + str(decalage3P) + '_Diff_PiecewiseResiduals.png'
        if sexe == 0:
            title = placefull + ' - Delay (delta)=' + str(
                decalage3P) + ' day(s)'
        else:
            title = placefull + ' - Sex=' + sexestr + ', Delay (delta)=' + str(
                decalage3P) + ' day(s)'

        if modeleString == 'SEIR1R2':
            listPlots = ['mc_piecewise_residual']
        if modeleString == 'SEIR1R2D':
            listPlots = ['mc_piecewise_residual', 'md_piecewise_residual']
        PlotFitPiecewiseResidual(modeleString,
                                 liste_pd_piecewise[indexplace],
                                 title,
                                 filename,
                                 y=listPlots,
                                 Dates=DatesString,
                                 textannotation=ListeTestPlace[indexplace])

    # Plot the three SEIR1R2
    ##################################

    for indexplace, place in enumerate(listplaces):

        # Get the full name of the place to process, and the special dates corresponding to the place
        if FrDatabase == True:
            if 'MetropoleD+' in listnames[indexplace][0]:
                placefull = 'France'
            else:
                placefull = 'France-' + listnames[indexplace][0]
        else:
            placefull = place

        # Repertoire des figures
        repertoire = getRepertoire(
            UKF_filt, './figures/' + modeleString + '_UKFilt/' + placefull +
            '/sexe_' + str(sexe) + '_delay_' + str(decalage3P),
            './figures/' + modeleString + '/' + placefull + '/sexe_' +
            str(sexe) + '_delay_' + str(decalage3P))
        prefFig = repertoire + '/Fit_'

        # Preparation plot pandas
        listheader = list(liste_pd_piecewise[indexplace])
        longueur = len(liste_pd_piecewise[indexplace].loc[:, (listheader[0])])

        if FrDatabase == True:
            DatesString = readDates(France, verbose)
        else:
            DatesString = readDates(place, verbose)

        liste_pd_piecewise[indexplace].loc[:, (
            'S(t)')] = model_piecewise[indexplace, :, 0]
        liste_pd_piecewise[indexplace].loc[:, (
            'E(t)')] = model_piecewise[indexplace, :, 1]
        liste_pd_piecewise[indexplace].loc[:, (
            'I(t)')] = model_piecewise[indexplace, :, 2]
        liste_pd_piecewise[indexplace].loc[:, (
            'R1(t)')] = model_piecewise[indexplace, :, 3]
        liste_pd_piecewise[indexplace].loc[:, (
            'R2(t)')] = model_piecewise[indexplace, :, 4]
        if modeleString == 'SEIR1R2D':
            liste_pd_piecewise[indexplace].loc[:, (
                'D(t)')] = model_piecewise[indexplace, :, 5]

        if sexe == 0:
            titre = placefull + ' - Delay (delta)=' + str(
                decalage3P) + ' day(s)'
        else:
            titre = placefull + ' - Sex=' + sexestr + ', Delay (delta)=' + str(
                decalage3P) + ' day(s)'

        listePlot = ['E(t)', 'I(t)', 'R1(t)']
        if modeleString == 'SEIR1R2D':
            listePlot.append('D(t)')
        filename = prefFig + str(decalage3P) + '_' + ''.join(
            map(str, listePlot)) + '_piecewise.png'
        PlotModel(modeleString,
                  liste_pd_piecewise[indexplace],
                  titre,
                  filename,
                  y=listePlot,
                  Dates=DatesString,
                  textannotation=ListeTestPlace[indexplace])

        listePlot = ['R1(t)']
        if modeleString == 'SEIR1R2D':
            listePlot.append('D(t)')
        filename = prefFig + str(decalage3P) + '_' + ''.join(
            map(str, listePlot)) + '_piecewise.png'

        PlotModel(modeleString,
                  liste_pd_piecewise[indexplace],
                  titre,
                  filename,
                  y=listePlot,
                  Dates=DatesString,
                  textannotation=ListeTestPlace[indexplace])