def grangerMatrix(matrix, mem=5):
    """Build a pairwise Granger-test matrix from a spike train matrix.

    Parameters
    ----------
    matrix : array_like
        2-D array whose columns are the individual series (presumably one
        column per neuron, rows are time bins -- confirm with callers).
    mem : int
        Passed to ``granger`` as ``maxlag``.

    Returns
    -------
    numpy.ndarray
        ``(n, n)`` array, ``n`` = number of columns. Entry ``[i, j]`` is the
        largest value of ``result[0]['ssr_ftest'][0]`` over all lag entries
        returned by ``granger`` for the column pair ``(i, j)``. In
        statsmodels' convention this is the SSR F statistic for "column j
        Granger-causes column i" (assumed; ``granger`` is not defined here).
        The diagonal is left at zero.
    """
    n_series = np.shape(matrix)[1]
    gMat = np.zeros((n_series, n_series))
    for row in range(n_series):
        for col in range(n_series):
            if row == col:
                # Self-causality is not tested; keep the diagonal at 0.
                continue
            pair = matrix[:, (row, col)]
            per_lag = granger(pair, maxlag=mem, verbose=False)
            gMat[row, col] = max(entry[0]['ssr_ftest'][0]
                                 for entry in per_lag.values())
    return gMat
def grangertests(v1, v2, maxlag=3):
    """Run ``granger`` verbosely on two traces, print and return the result.

    ``[v1, v2] / mV`` rescales both traces (presumably Brian2 voltage
    quantities, giving plain numbers in millivolts -- confirm). The two
    scaled traces are then zipped into ``(v1_k, v2_k)`` sample pairs, which
    are passed to ``granger`` with ``verbose=True``. The returned object is
    also printed before being returned.
    """
    scaled = [v1, v2] / mV
    samples = list(zip(*scaled))
    result = granger(samples, maxlag, verbose=True)
    print(result)
    return result
def _pair_score(compare_test, src_trace, tgt_trace):
    """Compare one source trace with one target trace.

    Returns ``(score, threshold, colour_value, label)``. The pair counts as
    related when ``score > threshold`` and unrelated when
    ``score < threshold``. ``colour_value`` is fed to the [-1, 1] colormap,
    and ``label`` is the text shown in the cell's upper-left corner.
    """
    if compare_test == 'correlation':
        r = np.corrcoef(src_trace.flatten(), tgt_trace.flatten())[1, 0]
        return r**2, 0.75, r, '\\textbf{R = %.2f}' % r
    # 'granger': columns are (target, source). In statsmodels' convention
    # this tests whether the source helps predict the target (assumed;
    # ``granger`` is imported elsewhere). Keep the largest F over all lags.
    pair = np.stack((tgt_trace, src_trace), 1) / mV
    results = granger(pair, maxlag=20, verbose=False)
    f_max = np.max(
        np.array([res[0]['ssr_ftest'][0] for res in results.values()]))
    # Ad-hoc scaling of F into the colormap range, kept from the original.
    scaled = f_max / 300
    return scaled, 0.035, scaled, '\\textbf{F = %.1f}' % f_max


def _shade_triangle(ax, vertices, rgba, zorder):
    """Fill a triangle (given in axes coordinates) behind the cell contents."""
    ax.add_artist(
        mpl.patches.Polygon(np.array(vertices),
                            transform=ax.transAxes,
                            facecolor=rgba,
                            edgecolor='none',
                            alpha=0.6,
                            zorder=zorder))


def _corner_label(ax, text, loc):
    """Add a small, frameless ``AnchoredText`` label at corner ``loc``."""
    ax.add_artist(
        AnchoredText(text,
                     loc=loc,
                     frameon=False,
                     pad=0.09,
                     prop=dict(color='k', fontsize='xx-small')))


def _population_bar(ax, xy, width, height, text_xy, label, highlighted):
    """Draw one labelled population bar in axes coordinates outside ``ax``.

    ``highlighted`` selects a golden-yellow bar with black text; otherwise
    the bar is black with white text. Clipping is off so the bar may lie
    outside the axes area.
    """
    facecolor, text_color = (('xkcd:golden yellow', 'k') if highlighted
                             else ('k', 'w'))
    bar = mpl.patches.Rectangle(xy,
                                width=width,
                                height=height,
                                transform=ax.transAxes,
                                facecolor=facecolor,
                                edgecolor='k',
                                clip_path=None,
                                clip_on=False)
    # The label is placed in data coordinates. Assumes the empty overlay
    # axes keeps 0..1 limits so these match the bar's axes coordinates.
    ax.text(text_xy[0],
            text_xy[1],
            label,
            color=text_color,
            fontsize='large',
            verticalalignment='center',
            horizontalalignment='center')
    ax.add_artist(bar)


def plot_correlations(nc,
                      src,
                      target,
                      s,
                      compare_test='correlation',
                      connectivity=None):
    """Draw an nc x nc grid comparing synaptic weights with trace statistics.

    Cell (target ``i``, source ``j``) is split along its diagonal:

    * The lower-right triangle is coloured by ``w``, the synaptic weight
      from source ``j`` to target ``i`` divided by 3.
    * The upper-left triangle is coloured by a statistic comparing
      ``src[j]`` with ``target[i]``.

    Both colours come from an RdBu colormap normalised to [-1, 1]. Target 0
    is drawn in the bottom row and source 0 in the left column. A blank
    grid row/column separates the first and second halves of the indices.
    Bars along the bottom and left edges label those halves ``E`` and
    ``I``.

    Parameters
    ----------
    nc : int
        Number of neurons; the grid has ``nc * nc`` cells.
    src, target : array-like
        Per-neuron traces indexed as ``src[j]`` / ``target[i]``; they must
        support ``.flatten()``. The Granger path divides them by ``mV``
        (presumably Brian2 voltage traces -- confirm with callers).
    s : object
        Exposes parallel arrays ``s.i`` (source index), ``s.j`` (target
        index) and ``s.w`` (weight); presumably a Brian2 ``Synapses``.
    compare_test : {'correlation', 'granger'}
        ``'correlation'``: Pearson R. A pair is related when
        ``R**2 > 0.75``, and a scatter of the two traces is drawn.
        ``'granger'``: the largest SSR F statistic reported by ``granger``
        with ``maxlag=20``, divided by 300 for colouring. A pair is related
        when that scaled value is ``> 0.035``.
    connectivity : {'e_to_e', 'e_to_i', 'i_to_e', 'i_to_i', 'all'} or None
        Selects which E/I edge bars are highlighted.

    Returns
    -------
    int
        Number of cells where statistic and weight agree: a related pair
        with ``|w| > 0.1``, or an unrelated pair with ``|w| < 0.1``.

    Raises
    ------
    ValueError
        If ``compare_test`` is not a supported value (previously this
        surfaced later as a NameError).
    """
    if compare_test not in ('correlation', 'granger'):
        raise ValueError("compare_test must be 'correlation' or 'granger', "
                         "got %r" % (compare_test,))
    a, b = src, target

    # Shared limits for every cell: the data range padded by 7% per side.
    all_vals = np.concatenate((a.flatten(), b.flatten()))
    lo, hi = np.min(all_vals), np.max(all_vals)
    pad = 0.07 * (hi - lo)
    lims = [lo - pad, hi + pad]

    fig = plt.figure(figsize=(10, 10))
    # Each cell spans gs_factor x gs_factor slots. The extra row/column is
    # the gap between the two halves; nothing is ever drawn there.
    gs_factor = 3
    gs = plt.GridSpec(nc * gs_factor + 1,
                      nc * gs_factor + 1,
                      figure=fig,
                      wspace=0,
                      hspace=0)
    normed_map = mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-1, vmax=1), cmap=mpl.cm.RdBu)

    match_count = 0
    for i, i_row in enumerate(reversed(range(nc))):
        # Cells in the second half are shifted one slot to skip the gap.
        row_start = i_row * gs_factor + int(i_row >= nc / 2)
        for j in range(nc):
            print('%s Calculator Cycle # %d / %d' %
                  (compare_test, i * nc + j + 1, nc**2))
            col_start = j * gs_factor + int(j >= nc / 2)
            ax = fig.add_subplot(gs[row_start:row_start + gs_factor,
                                    col_start:col_start + gs_factor])
            try:
                # Division by 3 kept from the original (presumably to fit
                # the [-1, 1] colormap -- TODO confirm).
                w = float(s.w[np.logical_and(s.i == j, s.j == i)]) / 3
            except TypeError:
                # For numpy arrays, float() raises TypeError unless exactly
                # one weight matched (e.g. no synapse); treat as unconnected.
                w = 0

            score, threshold, colour_value, metric_label = _pair_score(
                compare_test, a[j], b[i])
            # Compare |w| in both modes so negative (presumably inhibitory)
            # weights count as connections. The correlation branch used to
            # compare the signed w, which misclassified negative weights.
            if ((score > threshold and abs(w) > 0.1)
                    or (score < threshold and abs(w) < 0.1)):
                match_count += 1

            _shade_triangle(ax, [[1, 1], [0, 0], [1, 0]],
                            normed_map.to_rgba(w), 0.5)
            _shade_triangle(ax, [[1, 1], [0, 0], [0, 1]],
                            normed_map.to_rgba(colour_value), 0.4)
            if compare_test == 'correlation':
                ax.scatter(x=a[j],
                           y=b[i],
                           s=3,
                           facecolors='none',
                           edgecolors=(0, 0, 0, 0.2),
                           marker='o',
                           linewidths=0.5)
            _corner_label(ax, '\\textbf{w = %.2f}' % w, 'lower right')
            _corner_label(ax, metric_label, 'upper left')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_xlim(*lims)
            ax.set_ylim(*lims)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

    # Transparent overlay spanning the whole grid, used for axis labels and
    # the E/I population bars.
    ax = fig.add_subplot(gs[:, :])
    ax.patch.set_alpha(0)
    despine_all(ax)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlabel('Source Neuron', fontsize='large', labelpad=30)
    ax.set_ylabel('Target Neuron', fontsize='large', labelpad=30)

    gap_size = 1 / (gs_factor * nc + 1)
    low_middle = 0.5 - gap_size / 2
    high_middle = 0.5 + gap_size / 2
    thickness = -0.03  # negative: bars extend away from the grid

    # Bottom edge (sources): E on the first half, I on the second.
    _population_bar(ax, (0, -0.01), low_middle, thickness,
                    (low_middle / 2, -0.01 + thickness / 2), '$E$',
                    connectivity in ('e_to_e', 'e_to_i', 'all'))
    _population_bar(ax, (high_middle, -0.01), low_middle, thickness,
                    (low_middle / 2 + high_middle, -0.01 + thickness / 2),
                    '$I$', connectivity in ('i_to_e', 'i_to_i', 'all'))
    # Left edge (targets): E on the lower half, I on the upper.
    _population_bar(ax, (-0.01, 0), thickness, low_middle,
                    (-0.01 + thickness / 2, low_middle / 2), '$E$',
                    connectivity in ('i_to_e', 'e_to_e', 'all'))
    _population_bar(ax, (-0.01, high_middle), thickness, low_middle,
                    (-0.01 + thickness / 2, low_middle / 2 + high_middle),
                    '$I$', connectivity in ('i_to_i', 'e_to_i', 'all'))
    return match_count
Example #4
0
    else:
        for date_range in date_range_tags:
            start_end_date_pair = date_range.split("-")
            start_date = datetime.strptime(start_end_date_pair[0], "%Y%m%d")
            end_date = datetime.strptime(start_end_date_pair[1], "%Y%m%d")
            date_ranges.append((start_date, end_date))

    # slicing by date range
    for date_range in date_ranges:
        start_date = date_range[0]
        end_date = date_range[1]
        partial_df = total_df[start_date:end_date]

        mdata = partial_df[["ret_y", "ret_x"]]

        results = granger(mdata, maxlag=lag)

        print("\n")
        print("\n")
        print("Date Range: ", start_date, " - ", end_date)
        print("###########################################################")

        for index in range(1, lag + 1):
            print("parameters ftest:     ", results[index][0]["params_ftest"])
            print("ssr chi2test:         ", results[index][0]["ssr_chi2test"])
            print("likelihood ratio test:", results[index][0]["lrtest"])
            print("ssr ftest:            ", results[index][0]["ssr_ftest"])
            print("--------------------------------------------------------------")
            print("restricted model:")
            reg = results[index][1][0]
            slopes = reg.params[1:]
Example #5
0
        print('AR:', p, 'MA:', q, 'AIC:', NHS_arima_model.aic)
# for first differenced NHS models searched, AR p=2 MA q=2 is best
NHS_arima_model_selected = ARIMA(NHS_data['NHS'], order=(2, 1, 2)).fit()
# fitted parameters of the selected model
print(NHS_arima_model_selected.params)
# look-ahead forecasts needed

# Which regressors have potential as leading indicators?
# Look for relationships across three of the time series,
# using the period of overlap for those series.

# Each test asks whether the series in the second column Granger-"causes"
# the series in the first column. test[3] is the lag-3 entry (the maxlag),
# assuming ``granger`` is statsmodels' grangercausalitytests, whose result
# dict is keyed by lag.
print('Granger Tests')
# R form of test: grangertest(ICS~ER, order = 3, data=modeling.mts)
ICS_from_ER = pd.DataFrame(modeling_mts, columns=['ICS', 'ER'])
test = granger(ICS_from_ER, maxlag=3, addconst=True, verbose=False)
print('ICS_from_ER:', test[3][0]['params_ftest'])

# R form of test: grangertest(ICS~DGO, order = 3, data=modeling.mts)
ICS_from_DGO = pd.DataFrame(modeling_mts, columns=['ICS', 'DGO'])
test = granger(ICS_from_DGO, maxlag=3, addconst=True, verbose=False)
print('ICS_from_DGO:', test[3][0]['params_ftest'])

# R form of test: grangertest(DGO~ER, order = 3, data=modeling.mts)
DGO_from_ER = pd.DataFrame(modeling_mts, columns=['DGO', 'ER'])
test = granger(DGO_from_ER, maxlag=3, addconst=True, verbose=False)
print('DGO_from_ER:', test[3][0]['params_ftest'])

# R form of test: grangertest(DGO~ICS, order = 3, data=modeling.mts)
DGO_from_ICS = pd.DataFrame(modeling_mts, columns=['DGO', 'ICS'])
test = granger(DGO_from_ICS, maxlag=3, addconst=True, verbose=False)
# This result was previously computed but never reported.
print('DGO_from_ICS:', test[3][0]['params_ftest'])
        print('AR:', p, 'MA:', q, 'AIC:', NHS_arima_model.aic)
# for first differenced NHS models searched, AR p=2 MA q=2 is best
NHS_arima_model_selected = ARIMA(NHS_data['NHS'], order=(2, 1, 2)).fit()
# fitted parameters of the selected model
print(NHS_arima_model_selected.params)
# look-ahead forecasts needed

# Which regressors have potential as leading indicators?
# Look for relationships across three of the time series,
# using the period of overlap for those series.

# Each test asks whether the series in the second column Granger-"causes"
# the series in the first column. test[3] is the lag-3 entry (the maxlag),
# assuming ``granger`` is statsmodels' grangercausalitytests, whose result
# dict is keyed by lag.
print('Granger Tests')
# R form of test: grangertest(ICS~ER, order = 3, data=modeling.mts)
ICS_from_ER = pd.DataFrame(modeling_mts, columns=['ICS', 'ER'])
test = granger(ICS_from_ER, maxlag=3, addconst=True, verbose=False)
print('ICS_from_ER:', test[3][0]['params_ftest'])

# R form of test: grangertest(ICS~DGO, order = 3, data=modeling.mts)
ICS_from_DGO = pd.DataFrame(modeling_mts, columns=['ICS', 'DGO'])
test = granger(ICS_from_DGO, maxlag=3, addconst=True, verbose=False)
print('ICS_from_DGO:', test[3][0]['params_ftest'])

# R form of test: grangertest(DGO~ER, order = 3, data=modeling.mts)
DGO_from_ER = pd.DataFrame(modeling_mts, columns=['DGO', 'ER'])
test = granger(DGO_from_ER, maxlag=3, addconst=True, verbose=False)
print('DGO_from_ER:', test[3][0]['params_ftest'])

# R form of test: grangertest(DGO~ICS, order = 3, data=modeling.mts)
DGO_from_ICS = pd.DataFrame(modeling_mts, columns=['DGO', 'ICS'])
test = granger(DGO_from_ICS, maxlag=3, addconst=True, verbose=False)
# This result was previously computed but never reported.
print('DGO_from_ICS:', test[3][0]['params_ftest'])
Example #7
0
    else:
        for date_range in date_range_tags:
            start_end_date_pair = date_range.split("-")
            start_date = datetime.strptime(start_end_date_pair[0], '%Y%m%d')
            end_date = datetime.strptime(start_end_date_pair[1], '%Y%m%d')
            date_ranges.append((start_date, end_date))
  
    # slicing by date range
    for date_range in date_ranges:
        start_date = date_range[0]
        end_date = date_range[1]
        partial_df = total_df[start_date:end_date]
    
        mdata = partial_df[['ret_y', 'ret_x']] 

        results = granger(mdata, maxlag=lag)

        print("\n")
        print("\n")
        print("Date Range: ", start_date, " - ", end_date)
        print("###########################################################")

        for index in range(1, lag+1): 
            print("parameters ftest:     ", results[index][0]["params_ftest"])
            print("ssr chi2test:         ", results[index][0]["ssr_chi2test"])
            print("likelihood ratio test:", results[index][0]["lrtest"])
            print("ssr ftest:            ", results[index][0]["ssr_ftest"])
            print("--------------------------------------------------------------")
            print("restricted model:")
            reg = results[index][1][0]
            slopes = reg.params[1:]