def update_line_specipic_points(nums, data, axes, to_do, font_size, axis_font):
    """Draw one information-plane snapshot per entry of ``nums``.

    For every position ``i`` the slice ``nums[i]`` of ``data`` is scattered
    onto ``axes[i]`` (one scatter call per layer) and the panel is then
    formatted through ``utils.adjustAxes``.

    Args:
        nums: Indices into the third axis of ``data``; one panel per entry.
        data: Array indexed as ``data[coord, :, num, layer]``; ``coord`` 0
            supplies the x values and ``coord`` 1 the y values.
        axes: Indexable collection of axes, one per entry of ``nums``.
        to_do: Per-panel sequences; item 0 is forwarded as ``set_xlabel``
            and item 1 as ``set_ylabel``.
        font_size: Forwarded to ``utils.adjustAxes`` as ``label_size``.
        axis_font: Forwarded to ``utils.adjustAxes`` as ``axis_font``.
    """
    palette = LAYERS_COLORS
    tick_positions = [0, 2, 4, 6, 8, 10]
    # Marker appearance shared by every scatter call.
    marker_style = dict(s=105, edgecolors='black', alpha=0.85)
    for panel_idx, snapshot in enumerate(nums):
        panel = axes[panel_idx]
        # One scatter call per layer so each layer gets its own colour.
        for layer in range(data.shape[3]):
            panel.scatter(data[0, :, snapshot, layer],
                          data[1, :, snapshot, layer],
                          color=palette[layer],
                          **marker_style)
        label_flags = to_do[panel_idx]
        utils.adjustAxes(panel,
                         axis_font=axis_font,
                         title_str='',
                         x_ticks=tick_positions,
                         y_ticks=[],
                         x_lim=None,
                         y_lim=None,
                         set_xlabel=label_flags[0],
                         set_ylabel=label_flags[1],
                         x_label='$I(X;T)$',
                         y_label='$I(T;Y)$',
                         set_xlim=True,
                         set_ylim=True,
                         set_ticks=True,
                         label_size=font_size)
def update_line_each_neuron(num,
                            print_loss,
                            Ix,
                            axes,
                            Iy,
                            train_data,
                            accuracy_test,
                            epochs_bins,
                            loss_train_data,
                            loss_test_data,
                            colors,
                            epochsInds,
                            font_size=18,
                            axis_font=16,
                            x_lim=[0, 12.2],
                            y_lim=[0, 1.08],
                            x_ticks=[],
                            y_ticks=[]):
    """Redraw one movie frame of the information plane.

    ``axes[0]`` is cleared and filled with one scatter point per
    (network, layer) pair taken at index ``num``.  When a second axis is
    present it is cleared as well and receives the curves described below.

    Args:
        num: Frame index; selects ``Ix[:, num, :]`` / ``Iy[:, num, :]`` and
            ``epochsInds[num]``.
        print_loss: When true, the mean of ``loss_test_data[:, :num]`` is
            drawn on ``axes[1]`` in yellow.
        Ix: Array indexed as ``Ix[net, num, layer]`` (x coordinates).
        axes: Sequence of one or two axes.
        Iy: Array indexed like ``Ix`` (y coordinates).
        train_data: Unused; kept for interface compatibility.
        accuracy_test: 2-D array; ``1 - mean(accuracy_test[:, :num], axis=0)``
            is drawn on ``axes[1]`` in green.
        epochs_bins: Sorted bin edges used to pick the x-limit of ``axes[1]``.
        loss_train_data: Unused; kept for interface compatibility.
        loss_test_data: 2-D array used only when ``print_loss`` is true.
        colors: Per-layer colours, indexed by layer number.
        epochsInds: Epoch number of every frame.
        font_size: Unused; kept for interface compatibility.
        axis_font, x_lim, y_lim, x_ticks, y_ticks: Forwarded positionally to
            ``utils.adjustAxes``.  The list defaults are shared between
            calls; this function never mutates them.
    """
    # Start every frame from empty axes.
    axes[0].clear()
    has_curve_axis = len(axes) > 1
    if has_curve_axis:
        axes[1].clear()
    # Print the points
    for layer_num in range(Ix.shape[2]):
        for net_ind in range(Ix.shape[0]):
            axes[0].scatter(Ix[net_ind, num, layer_num],
                            Iy[net_ind, num, layer_num],
                            color=colors[layer_num],
                            s=35,
                            edgecolors='black',
                            alpha=0.85)
    title_str = 'Information Plane - Epoch number - ' + str(epochsInds[num])
    utils.adjustAxes(axes[0],
                     axis_font,
                     title_str,
                     x_ticks,
                     y_ticks,
                     x_lim,
                     y_lim,
                     set_xlabel=True,
                     set_ylabel=True,
                     x_label='$I(X;T)$',
                     y_label='$I(T;Y)$')
    if not has_curve_axis:
        return
    # Print the loss function and the error
    curve_axis = axes[1]
    curve_axis.plot(epochsInds[:num],
                    1 - np.mean(accuracy_test[:, :num], axis=0),
                    color='g')
    if print_loss:
        curve_axis.plot(epochsInds[:num],
                        np.mean(loss_test_data[:, :num], axis=0),
                        color='y')
    nearest_bin = np.searchsorted(epochs_bins, epochsInds[num], side='right')
    # With side='right' the result equals len(epochs_bins) once the epoch
    # reaches the last bin edge; clamp so the lookup below stays in range.
    nearest_bin = min(int(nearest_bin), len(epochs_bins) - 1)
    curve_axis.set_xlim([0, epochs_bins[nearest_bin]])
    curve_axis.legend(('Accuracy', 'Loss Function'), loc='best')
def plot_by_training_samples(I_XT_array, I_TY_array, axes, epochsInds, f,
                             index_i, index_j, size_ind, font_size, y_ticks,
                             x_ticks, colorbar_axis, title_str, axis_font,
                             bar_font, save_name, samples_labels):
    """Plot the last-epoch information-plane curve of each training-set size.

    One line (markers joined by thin segments) is drawn per index along
    axis 2 of the input arrays, coloured from the ``gnuplot`` colormap.  Each
    point of a line is the mean over axis 0 of
    ``array[:, -1, sample, -1, layer]``.

    Args:
        I_XT_array, I_TY_array: Arrays indexed with five indices as shown
            above; they supply the x and y values respectively.
        axes: 2-D array of axes; the panel ``axes[index_i, index_j]`` is used.
        epochsInds: Forwarded to ``utils.create_color_bar``.
        f: Figure that receives the colour bar and is saved.
        index_i, index_j: Row and column of the panel to draw on.
        size_ind: Number of curves to draw; ``-1`` means
            ``I_XT_array.shape[2] - 1``.
        font_size, y_ticks, x_ticks, title_str, axis_font: Forwarded to
            ``utils.adjustAxes``.
        colorbar_axis, bar_font: Forwarded to ``utils.create_color_bar``.
        save_name: Output path without the ``.png`` suffix.
        samples_labels: Unused.
    """
    if size_ind == -1:
        max_index = I_XT_array.shape[2] - 1
    else:
        max_index = size_ind
    cmap = plt.get_cmap('gnuplot')
    colors = [cmap(frac) for frac in np.linspace(0, 1, max_index + 1)]
    # Only the final epoch is shown.
    last_epoch = -1
    panel = axes[index_i, index_j]
    # One curve per training-set size, each with its own colour.
    for sample_idx in range(max_index):
        layer_ids = range(I_XT_array.shape[4])
        XT = [
            np.mean(I_XT_array[:, -1, sample_idx, last_epoch, layer], axis=0)
            for layer in layer_ids
        ]
        TY = [
            np.mean(I_TY_array[:, -1, sample_idx, last_epoch, layer], axis=0)
            for layer in layer_ids
        ]
        panel.plot(XT,
                   TY,
                   marker='o',
                   linestyle='-',
                   markersize=12,
                   markeredgewidth=0.2,
                   linewidth=0.5,
                   color=colors[sample_idx])
    # Axis labels only on the bottom row / left-most column of the grid.
    in_bottom_row = index_i == axes.shape[0] - 1
    in_first_column = index_j == 0
    utils.adjustAxes(panel,
                     axis_font=axis_font,
                     title_str=title_str,
                     x_ticks=x_ticks,
                     y_ticks=y_ticks,
                     x_lim=None,
                     y_lim=None,
                     set_xlabel=in_bottom_row,
                     set_ylabel=in_first_column,
                     x_label='$I(X;T)$',
                     y_label='$I(T;Y)$',
                     set_xlim=True,
                     set_ylim=True,
                     set_ticks=True,
                     label_size=font_size)
    # The bottom-right panel is the one that adds the colour bar and saves.
    if in_bottom_row and index_j == axes.shape[1] - 1:
        utils.create_color_bar(f,
                               cmap,
                               colorbar_axis,
                               bar_font,
                               epochsInds,
                               title='Training Data')
        f.savefig(save_name + '.png', dpi=150, format='png')
def plot_alphas(str_name, save_name='dist'):
    """Plot the variational I(X;T) estimate against the local one, per sigma.

    Loads the data behind ``str_name`` through ``utils.get_data``, extracts
    the ``'local_IXT'`` and ``'IXT_vartional'`` arrays and opens one figure
    for each of the 20 values in ``np.linspace(0, 0.3, 20)``.  Figure ``i``
    plots ``I_XT_array_var[:, :, i]`` against ``I_XT_array``, with both
    axes limited to ``[0, 15.1]``.  All figures are displayed with a single
    ``plt.show()`` call.

    Args:
        str_name: Identifier passed to ``utils.get_data``.
        save_name: Unused; kept for interface compatibility.

    Note:
        The sigma grid is hard-coded here; it presumably has to match the
        grid used when the variational estimates were computed (TODO
        confirm against the code that produces ``'IXT_vartional'``).
        A further per-layer sigma-sweep figure used to follow an
        unconditional ``return`` and could never run; it was removed.
    """
    data_array = utils.get_data(str_name)
    params = np.squeeze(np.array(data_array['information']))
    I_XT_array = np.squeeze(np.array(extract_array(params, 'local_IXT')))
    I_XT_array_var = np.squeeze(
        np.array(extract_array(params, 'IXT_vartional')))
    sigmas = np.linspace(0, 0.3, 20)

    for i, sigma in enumerate(sigmas):
        print(i, sigma)
        _, axes1 = plt.subplots(1, 1)
        axes1.plot(I_XT_array, I_XT_array_var[:, :, i], linewidth=5)
        # Second line, given in axes coordinates (transAxes).
        axes1.plot([0, 15.1], [0, 15.1], transform=axes1.transAxes)
        axes1.set_title('Sigmma=' + str(sigma))
        axes1.set_ylim([0, 15.1])
        axes1.set_xlim([0, 15.1])
    plt.show()
def update_axes(axes,
                xlabel,
                ylabel,
                xlim,
                ylim,
                title,
                xscale,
                yscale,
                x_ticks,
                y_ticks,
                p_0,
                p_1,
                font_size=30,
                axis_font=25,
                legend_font=16):
    """Attach the mean/STD gradient legends and format ``axes``.

    Two legends are created on the current pyplot axes, one for the
    handles in ``p_0`` and one for those in ``p_1``, each with six empty
    entry labels and a math-text title.  Both are then re-added as artists
    so that the second legend does not replace the first.  Finally
    ``axes`` is formatted through ``utils.adjustAxes``.

    Args:
        axes: Axes passed to ``utils.adjustAxes``.
        xlabel, ylabel: Axis labels.
        xlim, ylim: Axis limits.
        title: Unused; an empty title is always passed on.
        xscale, yscale: Axis scales.
        x_ticks, y_ticks: Tick positions.
        p_0: Legend handles for the mean-gradient legend.
        p_1: Legend handles for the STD-gradient legend.
        font_size: Forwarded as ``label_size``.
        axis_font: Unused; ``axis_font=20`` is always passed on.
        legend_font: Font size of the legend entries.
    """
    blank_entries = 6 * ['']
    y_tick_text = [
        '$10^{-5}$', '$10^{-4}$', '$10^{-3}$', '$10^{-2}$', '$10^{-1}$',
        '$10^0$', '$10^1$'
    ]
    # (handles, title) for the mean legend followed by the std legend.
    legend_specs = (
        (p_0, r'$\|Mean\left(\nabla{W_i}\right)\|$'),
        (p_1, r'$STD\left(\nabla{W_i}\right)$'),
    )
    legends = [
        plt.legend(handles,
                   blank_entries,
                   title=heading,
                   loc='best',
                   fontsize=legend_font,
                   markerfirst=False,
                   handlelength=5) for handles, heading in legend_specs
    ]
    for legend in legends:
        legend.get_title().set_fontsize('21')  # legend title font size
    # Each plt.legend call replaces the previous legend, so add both back.
    for legend in legends:
        plt.gca().add_artist(legend)
    format_options = dict(axis_font=20,
                          title_str='',
                          x_ticks=x_ticks,
                          y_ticks=y_ticks,
                          x_lim=xlim,
                          y_lim=ylim,
                          set_xlabel=True,
                          set_ylabel=True,
                          x_label=xlabel,
                          y_label=ylabel,
                          set_xlim=True,
                          set_ylim=True,
                          set_ticks=True,
                          label_size=font_size,
                          set_yscale=True,
                          set_xscale=True,
                          yscale=yscale,
                          xscale=xscale,
                          ytick_labels=y_tick_text,
                          genreal_scaling=True)
    utils.adjustAxes(axes, **format_options)
def update_line(num,
                print_loss,
                data,
                axes,
                epochsInds,
                test_error,
                test_data,
                epochs_bins,
                loss_train_data,
                loss_test_data,
                colors,
                font_size=18,
                axis_font=16,
                x_lim=[0, 12.2],
                y_lim=[0, 1.08],
                x_ticks=[],
                y_ticks=[]):
    """Redraw one movie frame of the information plane.

    ``axes[0]`` is cleared and receives one scatter call per layer with the
    points ``(data[0, :, num, layer], data[1, :, num, layer])``.  When a
    second axis is present it is cleared, receives
    ``1 - mean(test_error[:, :num], axis=0)`` against ``epochsInds[:num]``
    in red, and is formatted as the "Precision" panel.  With a single axis
    only the information plane is drawn.

    Args:
        num: Frame index into axis 2 of ``data`` and into ``epochsInds``.
        print_loss: Unused; kept for interface compatibility.
        data: Array indexed as ``data[coord, net, num, layer]``.
        axes: Sequence of one or two axes.
        epochsInds: Epoch number of every frame.
        test_error: 2-D array used for the curve on ``axes[1]``.
        test_data, epochs_bins, loss_train_data, loss_test_data: Unused;
            kept for interface compatibility.
        colors: Per-layer colours, indexed by layer number.
        font_size: Unused; kept for interface compatibility.
        axis_font, x_lim, y_lim, x_ticks, y_ticks: Forwarded positionally to
            ``utils.adjustAxes`` for both panels.  The list defaults are
            shared between calls; this function never mutates them.
    """
    # Build the segments joining consecutive layers of every network.
    cmap = ListedColormap(LAYERS_COLORS)
    segs = []
    for i in range(0, data.shape[1]):
        x = data[0, i, num, :]
        y = data[1, i, num, :]
        points = np.array([x, y]).T.reshape(-1, 1, 2)
        segs.append(np.concatenate([points[:-1], points[1:]], axis=1))
    segs = np.array(segs).reshape(-1, 2, 2)
    axes[0].clear()
    has_curve_axis = len(axes) > 1
    if has_curve_axis:
        axes[1].clear()
    # NOTE(review): this collection is built but never added to an axis, so
    # no connecting lines are drawn. Attaching it would change the rendered
    # frames, so it is left as found - confirm the intent.
    lc = LineCollection(segs,
                        cmap=cmap,
                        linestyles='solid',
                        linewidths=0.3,
                        alpha=0.6)
    lc.set_array(np.arange(0, 5))
    # Print the points
    for layer_num in range(data.shape[3]):
        axes[0].scatter(data[0, :, num, layer_num],
                        data[1, :, num, layer_num],
                        color=colors[layer_num],
                        s=35,
                        edgecolors='black',
                        alpha=0.85)
    # Everything touching axes[1] is guarded, matching the clear() above;
    # previously a single-axis call raised IndexError here.
    if has_curve_axis:
        axes[1].plot(epochsInds[:num],
                     1 - np.mean(test_error[:, :num], axis=0),
                     color='r')

    title_str = 'Information Plane - Epoch number - ' + str(epochsInds[num])
    utils.adjustAxes(axes[0],
                     axis_font,
                     title_str,
                     x_ticks,
                     y_ticks,
                     x_lim,
                     y_lim,
                     set_xlabel=True,
                     set_ylabel=True,
                     x_label='$I(X;T)$',
                     y_label='$I(T;Y)$')
    if has_curve_axis:
        title_str = 'Precision as function of the epochs'
        utils.adjustAxes(axes[1],
                         axis_font,
                         title_str,
                         x_ticks,
                         y_ticks,
                         x_lim,
                         y_lim,
                         set_xlabel=True,
                         set_ylabel=True,
                         x_label='# Epochs',
                         y_label='Precision')
def plot_all_epochs(I_XT_array,
                    I_TY_array,
                    axes,
                    epochsInds,
                    f,
                    index_i,
                    index_j,
                    size_ind,
                    font_size,
                    y_ticks,
                    x_ticks,
                    colorbar_axis,
                    title_str,
                    axis_font,
                    bar_font,
                    save_name,
                    epochFlag,
                    plot_error=True,
                    index_to_emphasis=1000):
    """Plot the information plane with one colour per epoch.

    A fresh 12x8 figure with a single subplot is always created, so the
    ``f`` and ``axes`` arguments are replaced and ``index_i``/``index_j``
    must address that 1x1 grid.  Both input arrays are squeezed; if the
    squeezed ``I_TY_array[0]`` still has more than one dimension, both
    arrays are averaged over axis 0.  Row ``k`` of each array is then drawn
    as one line, coloured by ``epochsInds[k]``.

    Args:
        I_XT_array, I_TY_array: x and y values; after the squeeze/mean step
            they are indexed as ``array[k, :]``.
        axes: Ignored; replaced by the newly created subplot.
        epochsInds: Epoch number per row; its maximum sets the number of
            colours, and it is passed sorted to the colour bar.
        f: Ignored; replaced by the newly created figure.
        index_i, index_j: Panel position inside the 1x1 grid.
        size_ind: Number of rows to draw; ``-1`` means
            ``I_XT_array.shape[0] - 2``.
        font_size, y_ticks, x_ticks, title_str, axis_font: Forwarded to
            ``utils.adjustAxes``.
        colorbar_axis, bar_font: Forwarded to ``utils.create_color_bar``.
        save_name: Output path without the ``.png`` suffix.
        epochFlag: When true, colours follow the colormap order and the bar
            is titled ``'Epochs'``; otherwise the colour list and the
            sorted epochs are reversed and the bar is titled ``'Traces'``.
        plot_error: Unused; kept for interface compatibility.
        index_to_emphasis: Unused; kept for interface compatibility.
    """
    f = plt.figure(figsize=(12, 8))
    axes = np.array([[f.add_subplot(111)]])

    I_XT_array = np.squeeze(I_XT_array)
    I_TY_array = np.squeeze(I_TY_array)
    if len(I_TY_array[0].shape) > 1:
        I_XT_array = np.mean(I_XT_array, axis=0)
        I_TY_array = np.mean(I_TY_array, axis=0)
    max_index = size_ind if size_ind != -1 else (I_XT_array.shape[0] - 2)

    cmap = plt.get_cmap('gnuplot')
    # np.linspace needs an integer sample count; epochsInds may hold floats
    # (they are cast with int() when used as colour indices below).
    n_colors = int(np.max(epochsInds)) + 1
    colors = [cmap(frac) for frac in np.linspace(0, 1, n_colors)]
    if not epochFlag:
        colors.reverse()
    panel = axes[index_i, index_j]
    # Go over all the epochs and plot them with the right color
    for index_in_range in range(0, max_index):
        panel.plot(I_XT_array[index_in_range, :],
                   I_TY_array[index_in_range, :],
                   marker='o',
                   linestyle='-',
                   markersize=12,
                   markeredgewidth=0.01,
                   linewidth=0.2,
                   color=colors[int(epochsInds[index_in_range])])
    utils.adjustAxes(panel,
                     axis_font=axis_font,
                     title_str=title_str,
                     x_ticks=x_ticks,
                     y_ticks=y_ticks,
                     x_lim=[0, 25.1],
                     y_lim=None,
                     set_xlabel=index_i == axes.shape[0] - 1,
                     set_ylabel=index_j == 0,
                     x_label='$I(X;T)$',
                     y_label='$I(T;Y)$',
                     set_xlim=False,
                     set_ylim=False,
                     set_ticks=True,
                     label_size=font_size)
    # Save the figure and add color bar
    if index_i == axes.shape[0] - 1 and index_j == axes.shape[1] - 1:
        sorted_epochs = np.sort(epochsInds)
        if epochFlag:
            bar_values, bar_title = sorted_epochs, 'Epochs'
        else:
            bar_values, bar_title = sorted_epochs[::-1], 'Traces'
        utils.create_color_bar(f,
                               cmap,
                               colorbar_axis,
                               bar_font,
                               bar_values,
                               title=bar_title)
        f.savefig(save_name + '.png', dpi=500, format='png')