def nf_target(params, data, likelihood_min = 1e-10):
    """Summed log-likelihood of `data` under the analytic DDM fptd,
    with each per-trial log-likelihood floored at log(likelihood_min).

    data[:, 0] are rts, data[:, 1] are choices (+/-1); their signed
    product (negated) is the convention batch_fptd expects. params[1]
    is doubled before being passed on (boundary-separation convention).
    """
    signed_rts = data[:, 0] * data[:, 1] * (-1)
    log_liks = np.log(batch_fptd(signed_rts,
                                 params[0],
                                 params[1] * 2,
                                 params[2],
                                 params[3]))
    floor = np.log(likelihood_min)
    return np.sum(np.maximum(log_liks, floor))
def make_fptd_data(data=None,
                   params=None,
                   metadata=None,
                   keep_file=0,
                   n_kde=100,
                   n_unif_up=100,
                   n_unif_down=100,
                   idx=0):
    """Build a (n_kde + n_unif_up + n_unif_down, 3) training array of
    [rt, choice, log-likelihood] rows.

    Rows [0:n_kde] are KDE samples scored with the analytic batch_fptd;
    the next n_unif_up rows are uniform positive rts (capped at 100s)
    also scored analytically; the final n_unif_down rows are negative
    (impossible) rts assigned a fixed log-likelihood floor.

    Parameters
    ----------
    data : array-like, shape (n, 2)
        Column 0 rts, column 1 choices; used to fit the KDE.
    params : sequence
        Model parameters. len == 4 -> plain ddm (default sdv = 0);
        len == 5 -> ddm_sdv (params[4] forwarded as the sdv argument).
        Other lengths leave the log-likelihood column at 0.
    metadata : dict
        Must provide 'possible_choices' and 'max_t'.
    keep_file : int
        Unused; retained for interface compatibility.
    n_kde, n_unif_up, n_unif_down : int
        Row counts for the three sections of the output.
    idx : int
        Progress index; printed on every 10th call.

    Returns
    -------
    np.ndarray, shape (n_kde + n_unif_up + n_unif_down, 3), float64.
    """
    # Avoid the shared-mutable-default pitfall of the original signature.
    params = [] if params is None else params

    def _log_fptd(signed_rts):
        # params[1] is doubled per the boundary-separation convention;
        # *params[4:5] forwards sdv only when 5 parameters were supplied.
        return np.log(batch_fptd(signed_rts, params[0], params[1] * 2,
                                 params[2], params[3], *params[4:5]))

    out = np.zeros((n_kde + n_unif_up + n_unif_down, 3))
    tmp_kde = kde_class.logkde((data[:, 0], data[:, 1], metadata))

    # KDE part: sample (rt, choice) pairs and score them analytically.
    samples_kde = tmp_kde.kde_sample(n_samples=n_kde)
    out[:n_kde, 0] = samples_kde[0].ravel()
    out[:n_kde, 1] = samples_kde[1].ravel()
    if len(params) in (4, 5):
        out[:n_kde, 2] = _log_fptd(out[:n_kde, 0] * out[:n_kde, 1] * (-1))

    # Positive uniform part (rt draws are capped at 100s even when the
    # model's max_t is larger). RNG call order matches the original:
    # choices first, then rts.
    choice_tmp = np.random.choice(metadata['possible_choices'], size=n_unif_up)
    rt_tmp = np.random.uniform(low=0.0001,
                               high=min(metadata['max_t'], 100),
                               size=n_unif_up)

    # NOTE(review): the original also computed
    # tmp_kde.kde_eval(data=(rt_tmp, choice_tmp)) here but never used the
    # result; that dead evaluation has been dropped.

    up = slice(n_kde, n_kde + n_unif_up)
    out[up, 0] = rt_tmp
    out[up, 1] = choice_tmp
    if len(params) in (4, 5):
        out[up, 2] = _log_fptd(out[up, 0] * out[up, 1] * (-1))

    # Negative uniform part: impossible (negative) rts get a fixed
    # log-likelihood floor; -66.77497 ~= log(1e-29).
    down = slice(n_kde + n_unif_up, None)
    out[down, 1] = np.random.choice(metadata['possible_choices'],
                                    size=n_unif_down)
    out[down, 0] = np.random.uniform(low=-1.0, high=0.0001, size=n_unif_down)
    out[down, 2] = -66.77497

    if idx % 10 == 0:
        print(idx)

    # np.float was removed in NumPy 1.24; builtin float keeps the same
    # float64 dtype the original produced.
    return out.astype(float)
def _simulate_model(model, p, sim_kwargs):
    """Run the simulator matching `model` with parameter vector `p`.

    `p` is one row of the parameter matrix; `sim_kwargs` carries the
    settings shared by all simulators (s, delta_t, max_t, n_samples,
    print_info). Returns the simulator output (rts, choices, metadata).
    Raises ValueError for an unknown model name (the original silently
    fell through and crashed later with a NameError).
    """
    if model in ('ddm', 'ddm_analytic'):
        return cds.ddm_flexbound(v=p[0], a=p[1], w=p[2], ndt=p[3],
                                 boundary_fun=bf.constant,
                                 boundary_multiplicative=True,
                                 boundary_params={},
                                 **sim_kwargs)
    if model == 'ddm_sdv':
        return cds.ddm_sdv(v=p[0], a=p[1], w=p[2], ndt=p[3], sdv=p[4],
                           boundary_fun=bf.constant,
                           boundary_multiplicative=True,
                           boundary_params={},
                           **sim_kwargs)
    if model in ('full_ddm', 'full_ddm2'):
        return cds.full_ddm(v=p[0], a=p[1], w=p[2], ndt=p[3],
                            dw=p[4], sdv=p[5], dndt=p[6],
                            boundary_fun=bf.constant,
                            boundary_multiplicative=True,
                            boundary_params={},
                            **sim_kwargs)
    if model in ('angle', 'angle2'):
        return cds.ddm_flexbound(v=p[0], a=p[1], w=p[2], ndt=p[3],
                                 boundary_fun=bf.angle,
                                 boundary_multiplicative=False,
                                 boundary_params={'theta': p[4]},
                                 **sim_kwargs)
    if model in ('weibull_cdf', 'weibull_cdf2'):
        return cds.ddm_flexbound(v=p[0], a=p[1], w=p[2], ndt=p[3],
                                 boundary_fun=bf.weibull_cdf,
                                 boundary_multiplicative=True,
                                 boundary_params={'alpha': p[4],
                                                  'beta': p[5]},
                                 **sim_kwargs)
    if model == 'levy':
        return cds.levy_flexbound(v=p[0], a=p[1], w=p[2],
                                  alpha_diff=p[3], ndt=p[4],
                                  boundary_fun=bf.constant,
                                  boundary_multiplicative=True,
                                  boundary_params={},
                                  **sim_kwargs)
    if model == 'ornstein':
        return cds.ornstein_uhlenbeck(v=p[0], a=p[1], w=p[2],
                                      g=p[3], ndt=p[4],
                                      boundary_fun=bf.constant,
                                      boundary_multiplicative=True,
                                      boundary_params={},
                                      **sim_kwargs)
    raise ValueError('Unknown model: ' + model)


def kde_vs_mlp_likelihoods(ax_titles=None,
                           title='Likelihoods KDE - MLP',
                           network_dir='',
                           x_labels=None,
                           parameter_matrix=None,
                           cols=3,
                           model='angle',
                           data_signature='',
                           n_samples=10,
                           nreps=10,
                           save=True,
                           show=False,
                           machine='home',
                           method='mlp',
                           traindatanalytic=0,
                           plot_format='svg'):
    """Plot MLP-predicted likelihoods against ground truth, per parameter set.

    For each row of `parameter_matrix` one subplot is drawn over a fixed
    grid of signed rts in [-5, 5]: the trained keras MLP's likelihood plus
    either the analytic ground truth (truthy `traindatanalytic`) or `nreps`
    simulation-based KDE estimates.

    Parameters
    ----------
    ax_titles : list of str
        One subplot title per parameter set; its length drives the layout.
    title : str
        Figure suptitle (model name is appended).
    network_dir : str
        Directory containing the trained 'model_final.h5' keras model.
    x_labels : list
        Unused; retained for interface compatibility.
    parameter_matrix : np.ndarray, shape (n_plots, n_params)
        One model-parameter vector per subplot.
    cols : int
        Number of subplot columns.
    model : str
        Generative model name (see `_simulate_model`).
    data_signature : str
        Tag appended to saved analytic svg filenames.
    n_samples, nreps : int
        Simulator samples per repetition / KDE repetitions per subplot.
    save, show : bool
        Write the figure to disk / return plt.show() instead of closing.
    machine, method : str
        Used to build the output directory (saving requires 'home').
    traindatanalytic : int
        Truthy -> analytic ground truth; falsy -> simulation + KDE.
    plot_format : str
        'svg' or 'png'.

    Returns
    -------
    plt.show() result if `show`, else the string 'finished'.
    """
    # Avoid mutable default arguments.
    ax_titles = [] if ax_titles is None else ax_titles
    x_labels = [] if x_labels is None else x_labels  # unused, kept for compat

    mpl.rcParams['text.usetex'] = True
    mpl.rcParams['svg.fonttype'] = 'none'

    # Grid layout and styling.
    rows = int(np.ceil(len(ax_titles) / cols))
    sns.set(style="white", palette="muted", color_codes=True, font_scale=2)

    fig, ax = plt.subplots(rows, cols, figsize=(10, 10),
                           sharex=True, sharey=False)
    fig.suptitle(title + ': ' + model.upper().replace('_', '-'), fontsize=40)
    sns.despine(right=True)

    # Evaluation grid: rts 5.0 .. 0.0025 with choice -1, then 0.0025 .. 5.0
    # with choice +1, so signed rts sweep [-5, 5].
    plot_data = np.zeros((4000, 2))
    plot_data[:, 0] = np.concatenate(([i * 0.0025 for i in range(2000, 0, -1)],
                                      [i * 0.0025 for i in range(1, 2001, 1)]))
    plot_data[:, 1] = np.concatenate((np.repeat(-1, 2000), np.repeat(1, 2000)))

    # Load keras model; each input row is [parameters | rt | choice].
    keras_model = keras.models.load_model(network_dir + 'model_final.h5')
    n_params = parameter_matrix.shape[1]
    keras_input_batch = np.zeros((4000, n_params + 2))
    keras_input_batch[:, n_params:] = plot_data

    # Settings shared by all simulators.
    sim_kwargs = dict(s=1, delta_t=0.001, max_t=20,
                      n_samples=n_samples, print_info=False)

    for i in range(len(ax_titles)):
        print('Making Plot: ', i)

        row_tmp = i // cols
        col_tmp = i % cols
        panel = ax[row_tmp, col_tmp]

        # MLP predictions for this parameter set.
        keras_input_batch[:, :n_params] = parameter_matrix[i, :]
        ll_out_keras = keras_model.predict(keras_input_batch, batch_size=100)

        signed_rts = plot_data[:, 0] * plot_data[:, 1]

        if traindatanalytic:
            # Analytic ground truth.
            ll_out_gt = cdw.batch_fptd(signed_rts,
                                       v=parameter_matrix[i, 0],
                                       a=parameter_matrix[i, 1],
                                       w=parameter_matrix[i, 2],
                                       ndt=parameter_matrix[i, 3])
            # seaborn >= 0.12 rejects positional x/y; keywords work on
            # older versions too.
            sns.lineplot(x=signed_rts, y=ll_out_gt,
                         color='black', alpha=0.5, label='TRUE', ax=panel)
        else:
            # Simulation-based KDE ground truth, nreps repetitions.
            for j in range(nreps):
                out = _simulate_model(model, parameter_matrix[i], sim_kwargs)
                mykde = kdec.logkde((out[0], out[1], out[2]))
                ll_out_gt = mykde.kde_eval((plot_data[:, 0], plot_data[:, 1]))
                # Label only the first repetition to avoid duplicate
                # legend entries.
                sns.lineplot(x=signed_rts, y=np.exp(ll_out_gt),
                             color='black', alpha=0.5,
                             label='KDE' if j == 0 else None, ax=panel)

        # MLP overlay. The original drew this only in the KDE branch, so
        # the analytic comparison never actually showed the MLP.
        sns.lineplot(x=signed_rts, y=np.exp(ll_out_keras[:, 0]),
                     color='green', label='MLP', alpha=1, ax=panel)

        # Legend only on the first subplot.
        if row_tmp == 0 and col_tmp == 0:
            panel.legend(loc='upper left', fancybox=True, shadow=True,
                         fontsize=12)
        else:
            panel.legend().set_visible(False)

        # x-labels only on the bottom row; y-labels only on the left column.
        if row_tmp == rows - 1:
            panel.set_xlabel('rt', fontsize=24)
        else:
            panel.tick_params(color='white')

        if col_tmp == 0:
            panel.set_ylabel('likelihood', fontsize=24)

        panel.set_title(ax_titles[i], fontsize=20)
        panel.tick_params(axis='y', size=16)
        panel.tick_params(axis='x', size=16)

    # Hide any unused trailing subplots.
    for i in range(len(ax_titles), rows * cols, 1):
        ax[i // cols, i % cols].axis('off')

    if save:
        if machine != 'home':
            # fig_dir was previously undefined for other machines,
            # producing a NameError at savefig time; fail fast instead.
            raise ValueError("Saving is only configured for machine == 'home'")
        fig_dir = ("/users/afengler/OneDrive/git_repos/nn_likelihoods"
                   "/figures/" + method + "/likelihoods/")
        os.makedirs(fig_dir, exist_ok=True)

        figure_name = 'mlp_vs_kde_likelihood_'
        plt.subplots_adjust(top=0.9)
        plt.subplots_adjust(hspace=0.3, wspace=0.3)

        # NOTE: savefig(frameon=...) was removed in Matplotlib 3.3 and is
        # no longer passed.
        if traindatanalytic:
            if plot_format == 'svg':
                # The original referenced an undefined name
                # (train_data_type) here; use the '_analytic' tag as the
                # png branch does.
                plt.savefig(fig_dir + '/' + figure_name + model +
                            data_signature + '_analytic.svg',
                            format='svg',
                            transparent=True)
            if plot_format == 'png':
                plt.savefig(fig_dir + '/' + figure_name + model +
                            '_analytic.png',
                            dpi=300)
        else:
            if plot_format == 'svg':
                plt.savefig(fig_dir + '/' + figure_name + model +
                            '_kde.svg',
                            format='svg',
                            transparent=True)
            if plot_format == 'png':
                plt.savefig(fig_dir + '/' + figure_name + model +
                            '_kde.png',
                            dpi=300)

    if show:
        return plt.show()
    plt.close()
    return 'finished'