def synthetic_control(before_interv_data, after_interv_data,
                      after_interv_dates, before_interv_dates, selected_ixp,
                      other_ixps, singvals=4):
    """Fit a robust synthetic control on pre-intervention data and plot it.

    A ``RobustSyntheticControl`` SVD model is fitted on the pre-intervention
    frame. The learned donor weights are printed and then applied to the
    donor columns of both periods. The post-intervention projection is drawn
    in red ("counterfactual") and the pre-intervention projection in green
    ("fitted model") with ``plt.plot``. Nothing is returned and ``plt.show()``
    is not called here.

    Args:
        before_interv_data: Data accepted by ``pd.DataFrame``; the resulting
            frame is indexed by the ``other_ixps`` columns and handed to the
            model for fitting.
        after_interv_data: Object indexable by the list ``other_ixps``
            (presumably a DataFrame -- confirm against the caller).
        after_interv_dates: X values for the counterfactual curve.
        before_interv_dates: X values for the fitted-model curve.
        selected_ixp: Key of the series being modelled; passed to the model
            as its target key.
        other_ixps: List of donor column labels.
        singvals: Forwarded to ``RobustSyntheticControl`` as its second
            positional argument (the singular-value setting). Defaults to 4,
            the value that used to be hard-coded.
    """
    trainDF = pd.DataFrame(data=before_interv_data)
    testDF = pd.DataFrame(data=after_interv_data[other_ixps])

    rscModel = RobustSyntheticControl(selected_ixp,
                                      singvals,
                                      len(trainDF),
                                      probObservation=1.0,
                                      modelType='svd',
                                      svdMethod='numpy',
                                      otherSeriesKeysArray=other_ixps)
    rscModel.fit(trainDF)

    weights = rscModel.model.weights
    print(weights)

    # Apply the learned donor weights to the raw donor series of each period.
    predictions = np.dot(testDF[other_ixps], weights)
    model_fit = np.dot(trainDF[other_ixps], weights)

    plt.plot(after_interv_dates,
             predictions,
             color='red',
             label="counterfactual",
             linewidth=1)
    plt.plot(before_interv_dates,
             model_fit,
             color='green',
             label="fitted model",
             linewidth=1)
trainDF = pd.DataFrame(data=trainDataDict)
testDF = pd.DataFrame(data=testDataDict)

# model
rscModel = RobustSyntheticControl(
    basqueKey,
    singvals,
    len(trainDF),
    probObservation=1.0,
    modelType="svd",
    svdMethod="numpy",
    otherSeriesKeysArray=otherStates,
)

# fit the model
rscModel.fit(trainDF)

# save the denoised training data
denoisedDF = rscModel.model.denoisedDF()

# predict - all at once
predictions = rscModel.predict(testDF)

# plot
yearsToPlot = range(yearStart, yearTestEnd, 1)
interventionYear = yearTrainEnd - 1
plt.plot(
    yearsToPlot,
    np.append(trainMasterDF[basqueKey], testDF[basqueKey], axis=0),
    color="red",
    label="observations",
Exemple #3
0
def runAnalysis(N, T, TrainingEnd, rowRank, colRank):
    """Compare single-metric RSC with multi-metric RSC on generated data.

    Two metric matrices are produced by ``generateOneMetricMatrix`` from the
    same normalized random weights and the same row/column parameters, using
    ``simpleFunctionOne`` and ``simpleFunctionTwo`` respectively. A separate
    ``RobustSyntheticControl`` model is fitted per metric, then one
    ``MultiRobustSyntheticControl`` model is fitted on both metrics together.

    Args:
        N, T, TrainingEnd, rowRank, colRank: Forwarded to
            ``generateOneMetricMatrix``. ``N`` also sets the number of donor
            labels ('1'..'N'), and ``rowRank``/``colRank`` set how many
            distinct row/column parameters are drawn.

    Returns:
        dict with keys ``"rsc1"``, ``"rsc2"``, ``"mrsc1"``, ``"mrsc2"``: the
        RMSE of each model's predictions against the mean test series of
        key '0' for metric 1 and metric 2.
    """

    def _rmse(predicted, reference):
        return np.sqrt(np.mean((predicted - reference)**2))

    genFunctionOne = simpleFunctionOne
    genFunctionTwo = simpleFunctionTwo

    # The random draws below are kept in this exact order so that a seeded
    # run produces the same data.
    rawWeights = np.random.uniform(0.0, 1.0, N)
    trueWeights = rawWeights / np.sum(rawWeights)

    thetaPool = np.random.uniform(0.0, 1.0, rowRank)
    rhoPool = np.random.uniform(0.0, 1.0, colRank)

    rowParams = np.random.choice(thetaPool, N)
    colParams = np.random.choice(rhoPool, T)

    # Only the train/test frames and the mean test dict are needed later.
    (_, _, trainDF1, testDF1, _,
     meanTestDict1) = generateOneMetricMatrix(N, T, TrainingEnd, rowRank,
                                              colRank, genFunctionOne,
                                              trueWeights, rowParams,
                                              colParams)
    (_, _, trainDF2, testDF2, _,
     meanTestDict2) = generateOneMetricMatrix(N, T, TrainingEnd, rowRank,
                                              colRank, genFunctionTwo,
                                              trueWeights, rowParams,
                                              colParams)

    keySeriesLabel = '0'
    otherSeriesLabels = [str(ind) for ind in range(1, N + 1)]
    singvals = 8

    # One independent RSC model per metric, fitted in metric order.
    singleMetricRMSE = []
    for trainFrame, testFrame, meanTest in ((trainDF1, testDF1,
                                             meanTestDict1),
                                            (trainDF2, testDF2,
                                             meanTestDict2)):
        singleModel = RobustSyntheticControl(
            keySeriesLabel,
            singvals,
            len(trainFrame),
            probObservation=1.0,
            svdMethod='numpy',
            otherSeriesKeysArray=otherSeriesLabels)
        singleModel.fit(trainFrame)
        singleMetricRMSE.append(
            _rmse(singleModel.predict(testFrame), meanTest[keySeriesLabel]))
    rscRMSE1, rscRMSE2 = singleMetricRMSE

    # Combined model over both metrics, equally weighted.
    relative_weights = [1.0, 1.0]
    mrscmodel = MultiRobustSyntheticControl(
        2,
        relative_weights,
        keySeriesLabel,
        singvals,
        len(trainDF1),
        probObservation=1.0,
        svdMethod='numpy',
        otherSeriesKeysArray=otherSeriesLabels)
    mrscmodel.fit([trainDF1, trainDF2])

    # The combined prediction holds one entry per metric, in input order.
    combinedPredictionsArray = mrscmodel.predict(
        [testDF1[otherSeriesLabels], testDF2[otherSeriesLabels]])
    mrscRMSE1 = _rmse(combinedPredictionsArray[0],
                      meanTestDict1[keySeriesLabel])
    mrscRMSE2 = _rmse(combinedPredictionsArray[1],
                      meanTestDict2[keySeriesLabel])

    return {
        "rsc1": rscRMSE1,
        "rsc2": rscRMSE2,
        "mrsc1": mrscRMSE1,
        "mrsc2": mrscRMSE2
    }
def basque(filename):
    """Run the Basque Country synthetic-control study and show the plot.

    Reads the CSV at ``filename``, pivots ``gdpcap`` into one row per
    ``regionname`` and one column per ``year``, and drops the
    'Spain (Espana)' row. Years 1955-1970 are used for training and
    1971-1997 for testing. Donor training series are passed through
    ``tsUtils.randomlyHideValues`` with ``p = 0.8``; the Basque series itself
    is left untouched. An ALS ``RobustSyntheticControl`` model is fitted, and
    the observations are plotted against the denoised fit plus predictions.

    Args:
        filename: Path (or anything ``pd.read_csv`` accepts) of the dataset.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    basqueKey = 'Basque Country (Pais Vasco)'
    yearStart, yearTrainEnd, yearTestEnd = 1955, 1971, 1998
    singvals = 1
    p = 0.8

    # BASQUE COUNTRY STUDY: regions as rows, years as columns.
    rawFrame = pd.read_csv(filename)
    regionByYear = rawFrame.pivot_table(values='gdpcap',
                                        index='regionname',
                                        columns='year')
    regionByYear = regionByYear.drop('Spain (Espana)')
    dfBasque = pd.DataFrame(regionByYear.to_records())

    # Column bookkeeping from the original script; not used further below.
    columnLabels = dfBasque.columns.values
    yearLabels = np.delete(columnLabels, [0])

    otherStates = list(np.unique(dfBasque['regionname']))
    otherStates.remove(basqueKey)

    # After to_records() the year columns are addressed by string labels.
    trainingYears = [str(year) for year in range(yearStart, yearTrainEnd, 1)]
    testYears = [str(year) for year in range(yearTrainEnd, yearTestEnd, 1)]

    def _region_values(region, yearColumns):
        rows = dfBasque[dfBasque['regionname'] == region]
        return rows[yearColumns].values[0]

    trainDataMasterDict = {}
    trainDataDict = {}
    testDataDict = {}
    for region in otherStates:
        trainDataMasterDict[region] = _region_values(region, trainingYears)

        # Randomly hide donor training data; the master copy stays complete.
        hiddenSeries, _ = tsUtils.randomlyHideValues(
            copy.deepcopy(trainDataMasterDict[region]), p)
        trainDataDict[region] = hiddenSeries
        testDataDict[region] = _region_values(region, testYears)

    # The treated unit is added last and is never masked.
    trainDataMasterDict[basqueKey] = _region_values(basqueKey, trainingYears)
    trainDataDict[basqueKey] = _region_values(basqueKey, trainingYears)
    testDataDict[basqueKey] = _region_values(basqueKey, testYears)

    trainMasterDF = pd.DataFrame(data=trainDataMasterDict)
    trainDF = pd.DataFrame(data=trainDataDict)
    testDF = pd.DataFrame(data=testDataDict)

    rscModel = RobustSyntheticControl(basqueKey,
                                      singvals,
                                      len(trainDF),
                                      probObservation=1.0,
                                      modelType='als',
                                      otherSeriesKeysArray=otherStates)
    rscModel.fit(trainDF)

    # Denoised training data followed by the out-of-sample predictions.
    denoisedDF = rscModel.model.denoisedDF()
    predictions = rscModel.predict(testDF)

    yearsToPlot = range(yearStart, yearTestEnd, 1)
    interventionYear = yearTrainEnd - 1
    observedSeries = np.append(trainMasterDF[basqueKey],
                               testDF[basqueKey],
                               axis=0)
    modelSeries = np.append(denoisedDF[basqueKey], predictions, axis=0)

    plt.plot(yearsToPlot, observedSeries, color='red', label='observations')
    plt.plot(yearsToPlot, modelSeries, color='blue', label='predictions')
    plt.axvline(x=interventionYear,
                linewidth=1,
                color='black',
                label='Intervention')
    plt.legend(loc='upper right', shadow=True)
    plt.title('Abadie et al. Basque Country Case Study - $p = %.2f$' % p)
    plt.show()
def prop99(filename):
    """Run the California Prop 99 synthetic-control study and show the plot.

    Reads the CSV at ``filename`` and keeps the rows whose ``SubMeasureDesc``
    is 'Cigarette Consumption (Pack Sales Per Capita)'. ``Data_Value`` is
    pivoted into one row per ``LocationDesc`` and one column per ``Year``.
    Years 1970-1988 are used for training and 1989-2014 for testing. Donor
    training series go through ``tsUtils.randomlyHideValues`` with
    ``p = 1.0``; California's own series is used as-is. An ALS
    ``RobustSyntheticControl`` model is fitted, and the observations are
    plotted against the denoised fit plus predictions.

    Args:
        filename: Path (or anything ``pd.read_csv`` accepts) of the dataset.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    caStateKey = 'California'
    yearStart, yearTrainEnd, yearTestEnd = 1970, 1989, 2015
    singvals = 2
    p = 1.0

    # CALIFORNIA PROP 99 STUDY: states as rows, years as columns.
    rawFrame = pd.read_csv(filename)
    consumption = rawFrame[rawFrame['SubMeasureDesc'] ==
                           'Cigarette Consumption (Pack Sales Per Capita)']
    stateByYear = consumption.pivot_table(values='Data_Value',
                                          index='LocationDesc',
                                          columns=['Year'])
    dfProp99 = pd.DataFrame(stateByYear.to_records())

    # Column bookkeeping from the original script; not used further below.
    columnLabels = dfProp99.columns.values
    yearLabels = np.delete(columnLabels, [0])

    otherStates = list(np.unique(dfProp99['LocationDesc']))
    otherStates.remove(caStateKey)

    # After to_records() the year columns are addressed by string labels.
    trainingYears = [str(year) for year in range(yearStart, yearTrainEnd, 1)]
    testYears = [str(year) for year in range(yearTrainEnd, yearTestEnd, 1)]

    def _state_values(location, yearColumns):
        rows = dfProp99[dfProp99['LocationDesc'] == location]
        return rows[yearColumns].values[0]

    trainDataMasterDict = {}
    trainDataDict = {}
    testDataDict = {}
    for location in otherStates:
        trainDataMasterDict[location] = _state_values(location, trainingYears)

        # Randomly hide donor training data; the master copy stays complete.
        hiddenSeries, _ = tsUtils.randomlyHideValues(
            copy.deepcopy(trainDataMasterDict[location]), p)
        trainDataDict[location] = hiddenSeries
        testDataDict[location] = _state_values(location, testYears)

    # The treated unit is added last and is never masked.
    trainDataMasterDict[caStateKey] = _state_values(caStateKey, trainingYears)
    trainDataDict[caStateKey] = _state_values(caStateKey, trainingYears)
    testDataDict[caStateKey] = _state_values(caStateKey, testYears)

    trainMasterDF = pd.DataFrame(data=trainDataMasterDict)
    trainDF = pd.DataFrame(data=trainDataDict)
    testDF = pd.DataFrame(data=testDataDict)

    rscModel = RobustSyntheticControl(caStateKey,
                                      singvals,
                                      len(trainDF),
                                      probObservation=1.0,
                                      modelType='als',
                                      otherSeriesKeysArray=otherStates)
    rscModel.fit(trainDF)

    # Denoised training data followed by the out-of-sample predictions.
    denoisedDF = rscModel.model.denoisedDF()
    predictions = rscModel.predict(testDF)

    yearsToPlot = range(yearStart, yearTestEnd, 1)
    interventionYear = yearTrainEnd - 1
    observedSeries = np.append(trainMasterDF[caStateKey],
                               testDF[caStateKey],
                               axis=0)
    modelSeries = np.append(denoisedDF[caStateKey], predictions, axis=0)

    plt.plot(yearsToPlot, observedSeries, color='red', label='observations')
    plt.plot(yearsToPlot, modelSeries, color='blue', label='predictions')
    plt.axvline(x=interventionYear,
                linewidth=1,
                color='black',
                label='Intervention')
    plt.legend(loc='lower left', shadow=True)
    plt.title('Abadie et al. Prop 99 Case Study (CA) - $p = %.2f$' % p)
    plt.show()
Exemple #6
0
def synth_control_predictions(list_of_dfs,
                              threshold,
                              low_thresh,
                              title_text,
                              singVals=2,
                              savePlots=False,
                              ylimit=[],
                              logy=False,
                              exclude=[],
                              svdSpectrum=False,
                              showDonors=True,
                              do_only=[],
                              showstates=4,
                              animation=[],
                              figure=None,
                              axes=None,
                              donorPool=[],
                              silent=True,
                              showPlots=True,
                              mRSC=False,
                              lambdas=[1],
                              error_thresh=1,
                              yaxis='Cases',
                              FONTSIZE=20,
                              tick_spacing=30,
                              random_distribution=None):
    """Fit and plot a robust synthetic control for each unit to predict.

    For every unit in the prediction set a ``RobustSyntheticControl`` SVD
    model is fitted on the first ``low_thresh`` rows, using the donor columns
    as predictors. The learned donor weights are applied to the donor columns
    of the training window ("Fitted model") and of the test window
    ("Predictions"), and both are plotted together with the unit's actual
    series.

    The list-valued defaults in the signature are only read, never mutated.

    Args:
        list_of_dfs: List of DataFrames whose columns are unit labels and
            whose rows are time steps. The first frame is the primary metric;
            the others are only used when ``mRSC`` is true.
        threshold: Exclusive end row of the test window. When ``donorPool``
            is empty, donors are the columns whose last valid index is
            greater than this value.
        low_thresh: Number of leading rows used for training. When
            ``do_only`` is empty, the units to predict are the columns whose
            last valid index lies in ``(low_thresh, threshold]``.
        title_text: Prefix of the plot title.
        singVals: Forwarded to ``RobustSyntheticControl`` as its second
            positional argument.
        savePlots: Save each figure to ``../Figures/COVID/<state>.png``
            (only evaluated when ``showPlots`` is true).
        ylimit: Optional y-limits for the main axes.
        logy: Use a logarithmic y-scale on the main axes.
        exclude: Labels removed from the donor pool.
        svdSpectrum: Also plot the squared singular values of the
            column-centered training data in a separate figure.
        showDonors: If true, ``axes[0]`` receives a bar chart of normalized
            donor weights and ``axes[-1]`` the main plot; otherwise ``axes``
            itself is used as the single main axes.
        do_only: Explicit list of units to predict; they are also removed
            from the donor pool.
        showstates: Accepted for interface compatibility; it has no effect
            on the output.
        animation: Optional object with a ``snap()`` method; when truthy,
            ``snap()`` is called instead of ``plt.show()``.
        figure: Accepted for interface compatibility; not used here.
        axes: Matplotlib axes (see ``showDonors``).
        donorPool: Explicit donor labels; overrides the ``threshold`` rule.
        silent: Suppress the progress prints.
        showPlots: Draw the actual/prediction/fit curves.
        mRSC: Train on the row-wise concatenation of
            ``lambdas[i] * list_of_dfs[i]`` instead of the first frame only.
        lambdas: Per-frame scale factors for ``mRSC``.
        error_thresh: Fit-error bound that decides the return value.
        yaxis: Y-axis label.
        FONTSIZE: Font size of titles, labels and legend.
        tick_spacing: Major tick spacing on the x-axis.
        random_distribution: Optional callable ``f(shape)`` whose result is
            added to the primary frame before the training window is taken.

    Returns:
        dict: For the LAST unit processed only. Maps each donor label to its
        weight when that unit's training error (``mse`` of actual vs. fitted)
        is below ``error_thresh``; otherwise the unit and error are printed
        and every donor maps to ``-50.0``.

    Raises:
        ValueError: If there is no unit to predict.
    """
    df = list_of_dfs[0]

    def _series_sizes():
        # Last valid index label per column (0 for all-NaN columns). Computed
        # on demand so callers that pass both donorPool and do_only never
        # need an integer-like index.
        return df.apply(pd.Series.last_valid_index).fillna(0).astype(int)

    sizes = None
    if donorPool:
        otherStates = list(donorPool)
    else:
        sizes = _series_sizes()
        otherStates = list(sizes[sizes > threshold].index)

    # Excluded units and the units being predicted cannot act as donors.
    blocked = list(exclude) + list(do_only)
    otherStates = [member for member in otherStates if member not in blocked]

    if not silent:
        print(otherStates)

    if do_only:
        prediction_states = do_only
        if not silent:
            print(prediction_states)
    else:
        # The original only computed ``sizes`` when donorPool was empty, so
        # passing donorPool without do_only crashed on an unbound name.
        if sizes is None:
            sizes = _series_sizes()
        prediction_states = list(sizes[(sizes > low_thresh)
                                       & (sizes <= threshold)].index)

    if not prediction_states:
        raise ValueError(
            "No units to predict: pass do_only or adjust "
            "low_thresh/threshold.")

    for state in prediction_states:
        state_label = str(state).replace("-None", "")

        if not mRSC:
            if random_distribution:
                noisy = df + random_distribution(df.shape)
                trainDF = noisy.iloc[:low_thresh, :]
            else:
                trainDF = df.iloc[:low_thresh, :]
        else:
            # Stack the scaled training windows of every metric row-wise.
            trainDF = pd.concat([
                scale * list_of_dfs[i].iloc[:low_thresh, :]
                for i, scale in enumerate(lambdas)
            ],
                                axis=0)
        if not silent:
            print(trainDF.shape)

        # NOTE(review): the test window starts at low_thresh + 1 while
        # x_predictions below starts at low_thresh, so row low_thresh is
        # skipped and the curve is plotted one step early. Looks like an
        # off-by-one; left unchanged pending confirmation.
        testDF = df.iloc[low_thresh + 1:threshold, :]

        rscModel = RobustSyntheticControl(state,
                                          singVals,
                                          len(trainDF),
                                          probObservation=1.0,
                                          modelType='svd',
                                          svdMethod='numpy',
                                          otherSeriesKeysArray=otherStates)
        rscModel.fit(trainDF)
        weights = rscModel.model.weights

        predictions = np.dot(testDF[otherStates].values, weights)
        actual = df[state]
        x_actual = actual.index

        if svdSpectrum:
            panel = trainDF[otherStates + [state]]
            # Column means are requested explicitly: np.mean(frame) forwards
            # axis=None, which newer pandas treats as a reduction over both
            # axes.
            centered = panel - panel.mean(axis=0)
            s = np.linalg.svd(centered.values, compute_uv=False)
            plt.figure(figsize=(8, 6))
            plt.plot(np.power(s, 2))
            plt.grid()
            plt.xlabel("Ordered Singular Values", fontsize=FONTSIZE)
            plt.ylabel("Energy", fontsize=FONTSIZE)
            plt.title("Singular Value Spectrum", fontsize=FONTSIZE)
            plt.show()

        x_predictions = df.index[low_thresh:low_thresh + len(predictions)]
        model_fit = np.dot(trainDF[otherStates].values, weights)
        error = mse(actual[:low_thresh], model_fit)
        if not silent:
            print(state, error)

        if showDonors:
            axes[0].barh(otherStates,
                         weights / np.max(weights),
                         color=list('rgbkymc'))
            axes[0].set_title("Normalized weights for " + state_label,
                              fontsize=FONTSIZE)
        ax = axes[-1] if showDonors else axes
        if ylimit:
            ax.set_ylim(ylimit)
        if logy:
            ax.set_yscale('log')

        if showPlots:
            ax.plot(x_actual,
                    actual,
                    label='Actuals',
                    color='k',
                    linestyle='-')
            ax.plot(x_predictions,
                    predictions,
                    label='Predictions',
                    color='r',
                    linestyle='--')
            ax.plot(df.index[:low_thresh],
                    model_fit,
                    label='Fitted model',
                    color='g',
                    linestyle=':')
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

            # Thick dashed line on the last training row.
            ax.axvline(x=df.index[low_thresh - 1],
                       color='k',
                       linestyle='--',
                       linewidth=4)
            ax.grid()
            if not showDonors:
                ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
            ax.set_title(title_text + " for " + state_label,
                         fontsize=FONTSIZE)
            ax.set_xlabel("Days since Intervention", fontsize=FONTSIZE)
            ax.set_ylabel(yaxis, fontsize=FONTSIZE)
            ax.legend(['Actuals', 'Predictions', 'Fitted Model'],
                      fontsize=FONTSIZE)

            if savePlots:
                plt.savefig("../Figures/COVID/" + state + ".png")
            if animation:
                animation.snap()
            else:
                plt.show()

    # Only the last processed unit decides the return value.
    if error < error_thresh:
        return dict(zip(otherStates, weights))
    print(state, error)
    return dict(zip(otherStates, -50 * np.ones(len(weights))))
def synth_control_predictions(list_of_dfs,
                              threshold,
                              low_thresh,
                              title_text,
                              singVals=2,
                              savePlots=False,
                              ylimit=[],
                              xlimit=[],
                              logy=False,
                              exclude=[],
                              svdSpectrum=False,
                              showDonors=True,
                              do_only=[],
                              showstates=4,
                              animation=[],
                              figure=None,
                              axes=None,
                              donorPool=[],
                              silent=True,
                              showPlots=True,
                              mRSC=False,
                              lambdas=[1],
                              error_thresh=1,
                              yaxis='Cases',
                              FONTSIZE=20,
                              tick_spacing=30,
                              random_distribution=None,
                              check_nan=0,
                              return_permutation_distribution=False,
                              intervention_date_x_ticks=None):
    """Fit and plot a robust synthetic control for each unit to predict.

    For every unit in the prediction set a ``RobustSyntheticControl`` SVD
    model is fitted on the first ``low_thresh`` rows, using the donor columns
    as predictors. The learned donor weights are applied to the NaN-filled
    donor columns of the training window ("Fitted model") and of the test
    window ("Predictions"); negative values are clipped to 0. Both are
    plotted together with the unit's actual series.

    The list-valued defaults in the signature are only read, never mutated.

    Args:
        list_of_dfs: List of DataFrames whose columns are unit labels and
            whose rows are time steps. The first frame is the primary metric;
            the others are only used when ``mRSC`` is true.
        threshold: Exclusive end row of the test window. When ``donorPool``
            is empty, donors are the columns whose last valid index is
            greater than this value.
        low_thresh: Number of leading rows used for training. When
            ``do_only`` is empty, the units to predict are the columns whose
            last valid index lies in ``(low_thresh, threshold]``.
        title_text: Prefix of the plot title; no title is set when falsy.
        singVals: Forwarded to ``RobustSyntheticControl`` as its second
            positional argument.
        savePlots: Save each figure to ``../Figures/COVID/<state>.pdf``
            (only evaluated when ``showPlots`` is true).
        ylimit, xlimit: Optional axis limits for the main axes.
        logy: Use a logarithmic y-scale on the main axes.
        exclude: Labels removed from the donor pool.
        svdSpectrum: Also plot the squared singular values of the
            column-centered training data in a separate figure.
        showDonors: If true, ``axes[0]`` receives a bar chart of normalized
            donor weights and ``axes[-1]`` the main plot; otherwise ``axes``
            itself is used as the single main axes.
        do_only: Explicit list of units to predict; they are also removed
            from the donor pool.
        showstates: Accepted for interface compatibility; it has no effect
            on the output.
        animation: Optional object with a ``snap()`` method; when truthy,
            ``snap()`` is called instead of ``plt.show()``.
        figure: Figure whose canvas is drawn before tick labels are read.
            When ``None``, the figure owning the main axes is used.
        axes: Matplotlib axes (see ``showDonors``).
        donorPool: Explicit donor labels; overrides the ``threshold`` rule.
        silent: Suppress the progress prints (the 'final donorpool' print of
            ``check_nan`` is always emitted).
        showPlots: Draw the actual/prediction/fit curves.
        mRSC: Train on the row-wise concatenation of
            ``lambdas[i] * list_of_dfs[i]`` instead of the first frame only.
        lambdas: Per-frame scale factors for ``mRSC``.
        error_thresh: Fit-error bound that decides the default return value.
        yaxis: Y-axis label.
        FONTSIZE: Font size of titles, labels, ticks and legend.
        tick_spacing: Major tick spacing on the x-axis.
        random_distribution: Optional callable ``f(shape)`` whose result is
            added to the primary frame before the training window is taken.
        check_nan: When non-zero, drop leading rows so that training starts
            at the latest first-valid row among the units to predict, while
            keeping at most ``check_nan`` training rows. Donors that are not
            valid from the new first row are dropped.
        return_permutation_distribution: Return post/pre MSE ratios for the
            treated unit and for every donor used as a placebo (see Returns).
        intervention_date_x_ticks: Optional mapping from unit label to the
            intervention date; when given, x tick labels (day offsets) are
            replaced by calendar dates.

    Returns:
        dict: If ``return_permutation_distribution`` is true, maps the last
        processed unit and every donor to
        ``MSE(post-period) / MSE(pre-period)``. Otherwise, for the LAST unit
        processed only: donor label -> weight when that unit's training
        error (``mse`` of actual vs. fitted) is below ``error_thresh``, else
        the unit and error are printed and every donor maps to ``-50.0``.

    Raises:
        ValueError: If there is no unit to predict.
    """
    df = list_of_dfs[0]

    def _series_sizes():
        # Last valid index label per column (0 for all-NaN columns). Computed
        # on demand so callers that pass both donorPool and do_only never
        # need an integer-like index.
        return df.apply(pd.Series.last_valid_index).fillna(0).astype(int)

    sizes = None
    if donorPool:
        otherStates = list(donorPool)
    else:
        sizes = _series_sizes()
        otherStates = list(sizes[sizes > threshold].index)

    # Excluded units and the units being predicted cannot act as donors.
    blocked = list(exclude) + list(do_only)
    otherStates = [member for member in otherStates if member not in blocked]

    if not silent:
        print(otherStates)

    if do_only:
        prediction_states = do_only
        if not silent:
            print(prediction_states)
    else:
        # The original only computed ``sizes`` when donorPool was empty, so
        # passing donorPool without do_only crashed on an unbound name.
        if sizes is None:
            sizes = _series_sizes()
        prediction_states = list(sizes[(sizes > low_thresh)
                                       & (sizes <= threshold)].index)

    if not prediction_states:
        raise ValueError(
            "No units to predict: pass do_only or adjust "
            "low_thresh/threshold.")

    if check_nan:
        # Assumes first_valid_index() labels are row positions (default
        # integer index) -- TODO confirm against callers.
        start = max(df[unit].first_valid_index()
                    for unit in prediction_states)
        if low_thresh - start > check_nan:
            start = low_thresh - check_nan
        df = df.iloc[start:].reset_index(drop=True)
        list_of_dfs = [
            frame.iloc[start:].reset_index(drop=True)
            for frame in list_of_dfs
        ]
        # NOTE(review): low_thresh is shifted by the dropped rows but
        # threshold is not, so the test window grows by ``start`` rows --
        # confirm this is intended.
        low_thresh -= start
        otherStates = [
            member for member in otherStates
            if df[member].first_valid_index() == 0
        ]
        print('final donorpool: ', otherStates)

    for state in prediction_states:
        state_label = str(state).replace("-None", "")

        if not mRSC:
            if random_distribution:
                noisy = df + random_distribution(df.shape)
                trainDF = noisy.iloc[:low_thresh, :]
            else:
                trainDF = df.iloc[:low_thresh, :]
        else:
            # Stack the scaled training windows of every metric row-wise.
            trainDF = pd.concat([
                scale * list_of_dfs[i].iloc[:low_thresh, :]
                for i, scale in enumerate(lambdas)
            ],
                                axis=0)
        if not silent:
            print(trainDF.shape)

        # NOTE(review): the test window starts at low_thresh + 1 while
        # x_predictions and find_ri below treat predictions as starting at
        # low_thresh. Looks like an off-by-one; left unchanged pending
        # confirmation.
        testDF = df.iloc[low_thresh + 1:threshold, :]

        rscModel = RobustSyntheticControl(state,
                                          singVals,
                                          len(trainDF),
                                          probObservation=1.0,
                                          modelType='svd',
                                          svdMethod='numpy',
                                          otherSeriesKeysArray=otherStates)
        rscModel.fit(trainDF)
        weights = rscModel.model.weights

        predictions = np.dot(testDF[otherStates].fillna(0).values, weights)
        predictions[predictions < 0] = 0
        actual = df[state]
        x_actual = actual.index

        if svdSpectrum:
            panel = trainDF[otherStates + [state]]
            # Column means are requested explicitly: np.mean(frame) forwards
            # axis=None, which newer pandas treats as a reduction over both
            # axes.
            centered = panel - panel.mean(axis=0)
            s = np.linalg.svd(centered.values, compute_uv=False)
            plt.figure(figsize=(8, 6))
            plt.plot(np.power(s, 2))
            plt.grid()
            plt.xlabel("Ordered Singular Values", fontsize=FONTSIZE)
            plt.ylabel("Energy", fontsize=FONTSIZE)
            plt.title("Singular Value Spectrum", fontsize=FONTSIZE)
            plt.show()

        x_predictions = df.index[low_thresh:low_thresh + len(predictions)]
        model_fit = np.dot(trainDF[otherStates].fillna(0).values, weights)
        model_fit[model_fit < 0] = 0
        error = mse(actual[:low_thresh], model_fit)
        if not silent:
            print(state, error)

        if showDonors:
            axes[0].barh(otherStates,
                         weights / np.max(weights),
                         color=list('rgbkymc'))
            axes[0].set_title("Normalized weights for " + state_label,
                              fontsize=FONTSIZE)
            axes[0].tick_params(axis='both', which='major', labelsize=FONTSIZE)
        ax = axes[-1] if showDonors else axes
        if ylimit:
            ax.set_ylim(ylimit)
        if xlimit:
            ax.set_xlim(xlimit)
        if logy:
            ax.set_yscale('log')

        if showPlots:
            ax.plot(x_actual,
                    actual,
                    label='Actuals',
                    color='k',
                    linestyle='-')
            ax.plot(x_predictions,
                    predictions,
                    label='Predictions',
                    color='r',
                    linestyle='--')
            ax.plot(df.index[:low_thresh],
                    model_fit,
                    label='Fitted model',
                    color='g',
                    linestyle=':')
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))

            # Thick dashed line on the last training row.
            ax.axvline(x=df.index[low_thresh - 1],
                       color='k',
                       linestyle='--',
                       linewidth=4)
            ax.grid()
            ax.tick_params(axis='both', which='major', labelsize=FONTSIZE)
            if title_text:
                ax.set_title(title_text + " for " + state_label,
                             fontsize=FONTSIZE)
            ax.set_xlabel("Days since intervention", fontsize=FONTSIZE)
            ax.set_ylabel(yaxis, fontsize=FONTSIZE)
            ax.legend(['Actuals', 'Predictions', 'Fitted Model'],
                      fontsize=FONTSIZE)

            # Tick label texts are only populated after a draw. Fall back to
            # the axes' own figure so the default figure=None does not crash.
            canvas_owner = figure if figure is not None else ax.figure
            canvas_owner.canvas.draw()

            if intervention_date_x_ticks:
                ts = pd.to_datetime(intervention_date_x_ticks[state])
                x_labels = []
                for tick_label in ax.get_xticklabels():
                    # Matplotlib may render negative ticks with U+2212,
                    # which int() rejects.
                    day_offset = int(tick_label.get_text().replace(
                        '\u2212', '-'))
                    tick_date = ts + datetime.timedelta(days=day_offset)
                    x_labels.append(tick_date.strftime('%Y-%m-%d'))
                ax.set_xlabel("Date", fontsize=FONTSIZE)
                ax.set_xticklabels(x_labels, rotation=45)

            if savePlots:
                plt.savefig("../Figures/COVID/" + state + '.pdf',
                            bbox_inches='tight')
            if animation:
                animation.snap()
            else:
                plt.show()

    if return_permutation_distribution:
        # Uses mean_squared_error (per the original note, sklearn's) rather
        # than the ``mse`` helper used for the fit error above.
        def find_ri(actual, model_fit, predictions):
            n_fit = len(model_fit)
            post_mse = mean_squared_error(
                actual[n_fit:n_fit + len(predictions)], predictions)
            pre_mse = mean_squared_error(actual[:n_fit], model_fit)
            return post_mse / pre_mse

        # The treated-unit entry uses the values left over from the last
        # loop iteration, so it is only meaningful when prediction_states
        # has length 1.
        out_dict = {state: find_ri(actual, model_fit, predictions)}

        # Placebo runs: each donor in turn is treated as the target and
        # fitted from the remaining donors, reusing the last trainDF/testDF.
        for placebo in otherStates:
            placebo_donors = otherStates.copy()
            placebo_donors.remove(placebo)
            placebo_model = RobustSyntheticControl(
                placebo,
                singVals,
                len(trainDF),
                probObservation=1.0,
                modelType='svd',
                svdMethod='numpy',
                otherSeriesKeysArray=placebo_donors)
            placebo_model.fit(trainDF)
            placebo_weights = placebo_model.model.weights

            placebo_actual = df[placebo].fillna(0)
            placebo_fit = np.dot(trainDF[placebo_donors].fillna(0),
                                 placebo_weights)
            placebo_pred = np.dot(testDF[placebo_donors].fillna(0),
                                  placebo_weights)

            out_dict[placebo] = find_ri(placebo_actual, placebo_fit,
                                        placebo_pred)

        return out_dict

    # Only the last processed unit decides the return value.
    if error < error_thresh:
        return dict(zip(otherStates, weights))
    print(state, error)
    return dict(zip(otherStates, -50 * np.ones(len(weights))))