def save_all_patches(grid_size=6):
    """Crop and save the receptive-field patch of every grid location for
    every image listed in the compiled model's test file.

    Patches are written to ``out_dir`` as ``<image index>_<x>_<y>.jpg``.

    Args:
        grid_size: number of x and y positions to visit (``range(grid_size)``
            on each axis). Defaults to 6, the value that was hard-coded.
    """
    out_dir = '../experiments/figures/primary_caps_viz/im_all_patches/train'
    util.makedirs(out_dir)
    _, test_file, convnet, imsize = get_caps_compiled()
    # Only the first space-separated token of each line (the image path) is used.
    test_im = [
        scipy.misc.imread(line_curr.split(' ')[0])
        for line_curr in util.readLinesFromFile(test_file)
    ]

    # The crop window depends only on (x, y), never on the image, so compute
    # it once per location rather than once per (image, location).
    windows = {}
    for x in range(grid_size):
        for y in range(grid_size):
            rec_field, center = receptive_field.get_receptive_field(
                convnet, imsize,
                len(convnet) - 1, x, y)
            center = [int(round(val)) for val in center]
            # NOTE: `rec_field / 2` relies on Python 2 integer division when
            # rec_field is an int (this file is Python 2 code).
            range_x = [
                max(0, center[0] - rec_field / 2),
                min(imsize, center[0] + rec_field / 2)
            ]
            range_y = [
                max(0, center[1] - rec_field / 2),
                min(imsize, center[1] + rec_field / 2)
            ]
            windows[(x, y)] = (range_x, range_y)

    for idx_test_im_curr, im_curr in enumerate(test_im):
        for x in range(grid_size):
            for y in range(grid_size):
                out_file_curr = os.path.join(
                    out_dir,
                    '_'.join(str(val)
                             for val in [idx_test_im_curr, x, y]) + '.jpg')
                print(out_file_curr)
                range_x, range_y = windows[(x, y)]
                patch = im_curr[range_y[0]:range_y[1], range_x[0]:range_x[1]]
                scipy.misc.imsave(out_file_curr, patch)
def k_means(caps,
            num_clusters,
            filter_num,
            x,
            y,
            test_im,
            out_dir_curr,
            out_file_html,
            convnet,
            imsize,
            rewrite=False,
            normalize=False):
    """Cluster the capsule vectors at one (filter, x, y) location with k-means
    and write an HTML page showing the image patches of each cluster.

    Args:
        caps: array indexed as ``caps[i, filter_num, x, y, :]``; the last axis
            is the vector that is clustered. The first axis is assumed to
            index images, since ``test_im`` is indexed by it.
        num_clusters: number of k-means clusters (one HTML row per cluster).
        filter_num, x, y: location in ``caps`` to analyse.
        test_im: sequence of images indexed by the first axis of ``caps``.
        out_dir_curr: directory receiving the ``<image index>.jpg`` patches.
        out_file_html: path of the HTML page to write.
        convnet, imsize: passed to ``receptive_field.get_receptive_field`` to
            locate the input patch that corresponds to (x, y).
        rewrite: unused; patches are always (re)written.
        normalize: if True, scale each feature (column) to unit L2 norm across
            images before clustering. Defaults to False, which matches what
            the previous code effectively did (it computed the normalized
            array and then discarded it).

    Returns:
        (im_row, caption_row): per-cluster lists of patch paths (relative to
        the module-level ``dir_server``) and captions of the form
        ``"<cluster> <norm of the unnormalized vector>"``.
    """
    vec_rel_org = caps[:, filter_num, x, y, :]
    if normalize:
        vec_rel = sklearn.preprocessing.normalize(vec_rel_org, axis=0)
    else:
        vec_rel = vec_rel_org

    k_meaner = sklearn.cluster.KMeans(n_clusters=num_clusters)
    bins = k_meaner.fit_predict(vec_rel)
    print(bins)
    for val in np.unique(bins):
        print('%s %s' % (val, np.sum(bins == val)))

    # The crop window depends only on (x, y), so compute it once.
    rec_field, center = receptive_field.get_receptive_field(
        convnet, imsize,
        len(convnet) - 1, x, y)
    center = [int(round(val)) for val in center]
    # NOTE: `rec_field / 2` relies on Python 2 integer division when
    # rec_field is an int (this file is Python 2 code).
    range_x = [
        max(0, center[0] - rec_field / 2),
        min(imsize, center[0] + rec_field / 2)
    ]
    range_y = [
        max(0, center[1] - rec_field / 2),
        min(imsize, center[1] + rec_field / 2)
    ]

    im_row = [[] for _ in range(num_clusters)]
    caption_row = [[] for _ in range(num_clusters)]
    for idx_idx, bin_curr in enumerate(bins):
        # Patch files are named by image index.
        out_file_curr = os.path.join(out_dir_curr, str(idx_idx) + '.jpg')
        patch = test_im[idx_idx][range_y[0]:range_y[1], range_x[0]:range_x[1]]
        scipy.misc.imsave(out_file_curr, patch)
        im_row[bin_curr].append(util.getRelPath(out_file_curr, dir_server))
        caption_row[bin_curr].append(
            '%d %.4f' % (bin_curr, np.linalg.norm(vec_rel_org[idx_idx])))

    visualize.writeHTML(out_file_html, im_row, caption_row, 40, 40)
    print(out_file_html)
    return im_row, caption_row
def save_ims(mags,
             filter_num,
             x,
             y,
             test_im,
             out_dir_curr,
             convnet,
             imsize,
             rewrite=False):
    """Save the (x, y) receptive-field patch of every image, ordered by
    descending value of ``mags[:, filter_num, x, y]``.

    Args:
        mags: array indexed as ``mags[i, filter_num, x, y]``; ``test_im[i]``
            is the image for index ``i``.
        filter_num, x, y: location in ``mags`` to rank by.
        test_im: sequence of images indexed like the first axis of ``mags``.
        out_dir_curr: directory receiving ``<rank>.jpg`` patch files.
        convnet, imsize: passed to ``receptive_field.get_receptive_field``.
        rewrite: if False, an existing ``<rank>.jpg`` is kept as is. Files are
            keyed by rank, not by image, so a file left over from an earlier
            run may show a different image; pass True to refresh.

    Returns:
        (im_row, caption_row): patch paths in rank order and captions of the
        form ``"<rank> <value>"``.
    """
    vec_rel = mags[:, filter_num, x, y]
    print(vec_rel.shape)
    idx_sort = np.argsort(vec_rel)[::-1]
    if len(idx_sort):
        # Largest and smallest values; skipped for empty input (this used to
        # raise IndexError).
        print(vec_rel[idx_sort[0]])
        print(vec_rel[idx_sort[-1]])

    # The crop window depends only on (x, y). It is computed lazily, at most
    # once, so no receptive-field call is made when every file is cached.
    window = None
    im_row = []
    caption_row = []
    for idx_idx, idx_curr in enumerate(idx_sort):
        out_file_curr = os.path.join(out_dir_curr, str(idx_idx) + '.jpg')
        if rewrite or not os.path.exists(out_file_curr):
            if window is None:
                rec_field, center = receptive_field.get_receptive_field(
                    convnet, imsize,
                    len(convnet) - 1, x, y)
                center = [int(round(val)) for val in center]
                # NOTE: `rec_field / 2` relies on Python 2 integer division
                # when rec_field is an int (this file is Python 2 code).
                window = (
                    [max(0, center[0] - rec_field / 2),
                     min(imsize, center[0] + rec_field / 2)],
                    [max(0, center[1] - rec_field / 2),
                     min(imsize, center[1] + rec_field / 2)],
                )
            range_x, range_y = window
            patch = test_im[idx_curr][range_y[0]:range_y[1],
                                      range_x[0]:range_x[1]]
            scipy.misc.imsave(out_file_curr, patch)
        im_row.append(out_file_curr)
        caption_row.append('%d %.4f' % (idx_idx, vec_rel[idx_curr]))
    return im_row, caption_row
def make_ent_rec_win_accu_individual(pause=True):
    """For each emotion, blend a per-image heat map built from the rows of the
    entropy file ``<emo>_<emo>_1.npy`` onto the matching correctly classified
    test images, and write an HTML index per emotion folder.

    Each heat-map pixel is the average of the rescaled, inverted value at
    every (x, y) in ``range(1, 6) x range(1, 6)`` whose receptive field
    covers that pixel.

    Args:
        pause: if True (the default, matching the previous behaviour), wait
            for Enter via ``raw_input()`` after each emotion is written.
    """
    ent_dir = '../experiments/figures/ck_routing'
    out_dir = '../experiments/figures_rebuttal/individual_ent_heat_map'
    util.mkdir(out_dir)

    labels = np.load(os.path.join(ent_dir, 'labels.py.npy'))
    preds = np.load(os.path.join(ent_dir, 'preds.py.npy'))

    # Image paths from all ten test folds, in fold order (first token per line).
    test_pre = '../data/ck_96/train_test_files/test_'
    im_files = []
    for num in range(10):
        lines = util.readLinesFromFile(test_pre + str(num) + '.txt')
        im_files.extend(line_curr.split(' ')[0] for line_curr in lines)
    im_files = np.array(im_files)

    emo_strs = ['Neutral', 'Anger', 'Contempt', 'Disgust', 'Fear',
                'Happiness', 'Sadness', 'Surprise']
    _, _, convnet, imsize = pcv.get_caps_compiled()

    # The crop windows depend only on (x, y), so compute them once for all
    # emotions and images instead of once per image.
    windows = []
    for x in range(1, 6):
        for y in range(1, 6):
            rec_field, center = receptive_field.get_receptive_field(
                convnet, imsize, len(convnet) - 1, x, y)
            center = [int(round(val)) for val in center]
            # NOTE: `rec_field / 2` relies on Python 2 integer division when
            # rec_field is an int (this file is Python 2 code).
            range_x = [max(0, center[0] - rec_field / 2),
                       min(imsize, center[0] + rec_field / 2)]
            range_y = [max(0, center[1] - rec_field / 2),
                       min(imsize, center[1] + rec_field / 2)]
            windows.append((x, y, range_x, range_y))

    alpha = 0.5
    for emo_idx, emo_str in enumerate(emo_strs):
        out_dir_curr = os.path.join(out_dir, emo_str.lower())
        util.mkdir(out_dir_curr)

        # Correctly classified images of this emotion. Assumed to be in the
        # same order as the rows of the entropy file - TODO confirm.
        rel_idx = np.logical_and(labels == emo_idx, preds == labels)
        rel_im = im_files[rel_idx]

        ent_file = os.path.join(
            ent_dir,
            '_'.join(str(val) for val in [emo_idx, emo_idx, 1]) + '.npy')
        ent = np.load(ent_file)
        for idx, mean_ent in enumerate(ent):
            print(mean_ent.shape)
            print('%s %s' % (np.mean(mean_ent), np.max(mean_ent)))

            # Rescale to [0, 1] and invert, so the smallest input maps to 1.
            # A constant map used to divide by zero (NaNs); it now becomes 1s.
            mean_ent = mean_ent - np.min(mean_ent)
            max_ent = np.max(mean_ent)
            if max_ent > 0:
                mean_ent = mean_ent / max_ent
            mean_ent = 1 - mean_ent

            accu = np.zeros((96, 96))
            count = np.zeros((96, 96))
            for x, y, range_x, range_y in windows:
                accu[range_y[0]:range_y[1], range_x[0]:range_x[1]] += mean_ent[y, x]
                count[range_y[0]:range_y[1], range_x[0]:range_x[1]] += 1
            # Pixels outside every window stay 0 instead of dividing by zero.
            count[count == 0] = 1.
            heat_map = visualize.getHeatMap(accu / count) * 255

            # Single-channel image replicated to 3 channels for blending.
            im_gray = scipy.misc.imread(rel_im[idx])
            im_rgb = np.concatenate([im_gray[:, :, np.newaxis]] * 3, 2)

            file_pre = os.path.join(out_dir_curr, str(idx) + '_')
            out_file_heat = file_pre + '_heat_' + str(alpha) + '.png'
            print(out_file_heat)
            visualize.fuseAndSave(im_rgb, heat_map, alpha,
                                  out_file_curr=out_file_heat)

        visualize.writeHTMLForFolder(out_dir_curr, '.png')
        if pause:
            raw_input()
def make_ent_rec_win_accu():
    """For each emotion, build a heat map from the mean over rows of the
    entropy file ``<emo>_<emo>_1.npy``. Plot it and blend it onto that
    emotion's mean face at several alpha values.

    Each heat-map pixel is the average of the rescaled, inverted value at
    every (x, y) in ``range(1, 6) x range(1, 6)`` whose receptive field
    covers that pixel. Writes ``<emo>.png`` and ``<emo>_heat_<alpha>.png``
    into ``out_dir``, plus an HTML index of the folder.
    """
    ent_dir = '../experiments/figures/ck_routing'
    out_dir = '../experiments/figures_rebuttal/ent_heat_map'
    in_dir_mean = '../data/ck_96/mean_expressions'
    util.mkdir(out_dir)

    emo_strs = ['Neutral', 'Anger', 'Contempt', 'Disgust', 'Fear',
                'Happiness', 'Sadness', 'Surprise']

    # The compiled model and the crop windows do not depend on the emotion.
    # Compute them once (they used to be recomputed on every iteration).
    _, _, convnet, imsize = pcv.get_caps_compiled()
    print(imsize)
    windows = []
    for x in range(1, 6):
        for y in range(1, 6):
            rec_field, center = receptive_field.get_receptive_field(
                convnet, imsize, len(convnet) - 1, x, y)
            center = [int(round(val)) for val in center]
            # NOTE: `rec_field / 2` relies on Python 2 integer division when
            # rec_field is an int (this file is Python 2 code).
            range_x = [max(0, center[0] - rec_field / 2),
                       min(imsize, center[0] + rec_field / 2)]
            range_y = [max(0, center[1] - rec_field / 2),
                       min(imsize, center[1] + rec_field / 2)]
            windows.append((x, y, range_x, range_y))

    for emo_idx, emo_str in enumerate(emo_strs):
        ent_file = os.path.join(
            ent_dir,
            '_'.join(str(val) for val in [emo_idx, emo_idx, 1]) + '.npy')
        ent = np.load(ent_file)
        mean_ent = np.mean(ent, axis=0)
        # Rescale to [0, 1] and invert, so the smallest input maps to 1.
        # A constant map used to divide by zero (NaNs); it now becomes 1s.
        mean_ent = mean_ent - np.min(mean_ent)
        max_ent = np.max(mean_ent)
        if max_ent > 0:
            mean_ent = mean_ent / max_ent
        print('%s %s' % (np.min(mean_ent), np.max(mean_ent)))
        mean_ent = 1 - mean_ent
        print('%s %s' % (np.min(mean_ent), np.max(mean_ent)))

        accu = np.zeros((96, 96))
        count = np.zeros((96, 96))
        for x, y, range_x, range_y in windows:
            accu[range_y[0]:range_y[1], range_x[0]:range_x[1]] += mean_ent[y, x]
            count[range_y[0]:range_y[1], range_x[0]:range_x[1]] += 1
        # Pixels outside every window stay 0 instead of dividing by zero.
        count[count == 0] = 1.
        heat_map = accu / count

        out_file = os.path.join(out_dir, emo_str.lower() + '.png')
        visualize.plot_colored_mats(out_file, heat_map, 0, 1, title=emo_str)
        heat_map = visualize.getHeatMap(heat_map) * 255

        # Single-channel mean face replicated to 3 channels for blending.
        avg_im = scipy.misc.imread(
            os.path.join(in_dir_mean, emo_str.lower() + '.png'))
        avg_im = np.concatenate([avg_im[:, :, np.newaxis]] * 3, 2)

        # One blend per alpha. The loop variable must be passed through:
        # previously 0.5 was always used, so every file was identical.
        for alpha in np.arange(0.1, 1.0, 0.1):
            out_file_heat = os.path.join(
                out_dir, emo_str.lower() + '_heat_' + str(alpha) + '.png')
            visualize.fuseAndSave(avg_im, heat_map, alpha,
                                  out_file_curr=out_file_heat)

    visualize.writeHTMLForFolder(out_dir, '.png')

    # Shape of the last emotion's map (emo_strs is non-empty, so it is set).
    print(mean_ent.shape)
def plot_specific_patches(label=3, locations=((2, 3), (3, 2), (4, 3))):
    """Save receptive-field patches at chosen locations for every correctly
    classified test image of one label, and write an HTML grid of them.

    Also loads the per-directory ``labels``/``preds``/``routes_*`` arrays,
    prints their shapes and the overall accuracy.

    Args:
        label: index into the emotion list to visualise. Defaults to 3, the
            previously hard-coded value.
        locations: sequence of (x, y) locations; each becomes one HTML row.
            Defaults to the previously hard-coded ((2, 3), (3, 2), (4, 3)).
    """
    dirs_rel = get_ck_16_dirs()

    # Image paths from all ten test folds, in fold order (first token per line).
    test_pre = '../data/ck_96/train_test_files/test_'
    im_files = []
    for num in range(10):
        lines = util.readLinesFromFile(test_pre + str(num) + '.txt')
        im_files.extend(line_curr.split(' ')[0] for line_curr in lines)
    print(len(im_files))
    print(im_files[0])

    out_dir = '../experiments/figures/ck_routing'
    out_dir_im_meta = '../experiments/figures_rebuttal/ck_routing'
    util.makedirs(out_dir)
    util.makedirs(out_dir_im_meta)

    # labels/preds are concatenated along axis 0, routes along axis 1.
    mats_names = ['labels', 'preds', 'routes_0', 'routes_1']
    axis_combine = [0, 0, 1, 1]
    mat_arrs = []
    for mat_name, axis_curr in zip(mats_names, axis_combine):
        arrs = [np.load(os.path.join(dir_curr, mat_name + '.npy'))
                for dir_curr in dirs_rel]
        mat_arrs.append(np.concatenate(arrs, axis_curr))
    for mat_arr in mat_arrs:
        print('%s %s' % (mat_arr.shape, len(im_files)))

    labels, preds = mat_arrs[0], mat_arrs[1]
    accuracy = np.sum(labels == preds) / float(labels.size)
    print('accuracy %s' % accuracy)

    emo_strs = ['Neutral', 'Anger', 'Contempt', 'Disgust', 'Fear',
                'Happiness', 'Sadness', 'Surprise']
    _, _, convnet, imsize = pcv.get_caps_compiled()
    print('%s %s' % (convnet, imsize))

    out_dir_im = os.path.join(out_dir_im_meta, emo_strs[label])
    util.mkdir(out_dir_im)
    print(out_dir_im)

    # Correctly classified images of this label.
    idx_keep = np.logical_and(labels == label, labels == preds)
    files_keep = [im_curr for idx_im_curr, im_curr in enumerate(im_files)
                  if idx_keep[idx_im_curr]]
    # Read each image once and reuse it for every location (it used to be
    # re-read once per location).
    ims_keep = [scipy.misc.imread(file_curr) for file_curr in files_keep]

    out_file_html = os.path.join(out_dir_im, 'patches.html')
    html_rows = []
    caption_rows = []
    for x, y in locations:
        # The crop window depends only on (x, y); compute it once per row.
        rec_field, center = receptive_field.get_receptive_field(
            convnet, imsize, len(convnet) - 1, x, y)
        center = [int(round(val)) for val in center]
        # NOTE: `rec_field / 2` relies on Python 2 integer division when
        # rec_field is an int (this file is Python 2 code).
        range_x = [max(0, center[0] - rec_field / 2),
                   min(imsize, center[0] + rec_field / 2)]
        range_y = [max(0, center[1] - rec_field / 2),
                   min(imsize, center[1] + rec_field / 2)]

        html_row = []
        caption_row = []
        for idx_test_im_curr, im_curr in enumerate(ims_keep):
            out_file_curr = os.path.join(
                out_dir_im,
                '_'.join(str(val) for val in [idx_test_im_curr, x, y]) + '.jpg')
            patch = im_curr[range_y[0]:range_y[1], range_x[0]:range_x[1]]
            scipy.misc.imsave(out_file_curr, patch)
            # str_replace and dir_server are module-level names defined
            # outside this function.
            html_row.append(util.getRelPath(
                out_file_curr.replace(str_replace[0], str_replace[1]),
                dir_server))
            caption_row.append(
                ' '.join(str(val) for val in [idx_test_im_curr, x, y]))
        html_rows.append(html_row)
        caption_rows.append(caption_row)

    visualize.writeHTML(out_file_html, html_rows, caption_rows, 40, 40)
def pca(caps,
        num_clusters,
        filter_num,
        x,
        y,
        test_im,
        out_dir_curr,
        out_file_html,
        convnet,
        imsize,
        rewrite=False):
    """Write an HTML page with one row per component of the vector at
    ``caps[:, filter_num, x, y, :]``. Each row shows every image's (x, y)
    patch, sorted in ascending order of that component.

    Despite the name, no PCA is performed. ``num_clusters`` and ``rewrite``
    are unused; they keep the signature parallel to ``k_means``.

    Args:
        caps: array indexed as ``caps[i, filter_num, x, y, :]``; ``test_im[i]``
            is the image for index ``i``.
        filter_num, x, y: location in ``caps`` to analyse.
        test_im: sequence of images indexed like the first axis of ``caps``.
        out_dir_curr: directory receiving ``<image index>.jpg`` patches.
        out_file_html: path of the HTML page to write.
        convnet, imsize: passed to ``receptive_field.get_receptive_field``.
    """
    vec_rel = caps[:, filter_num, x, y, :]

    # The crop window depends only on (x, y), so compute it once.
    rec_field, center = receptive_field.get_receptive_field(
        convnet, imsize,
        len(convnet) - 1, x, y)
    center = [int(round(val)) for val in center]
    # NOTE: `rec_field / 2` relies on Python 2 integer division when
    # rec_field is an int (this file is Python 2 code).
    range_x = [
        max(0, center[0] - rec_field / 2),
        min(imsize, center[0] + rec_field / 2)
    ]
    range_y = [
        max(0, center[1] - rec_field / 2),
        min(imsize, center[1] + rec_field / 2)
    ]

    # Save each image's patch once, named by image index (as k_means does).
    # Files used to be named by rank within a row, so each row overwrote the
    # previous row's files: every row displayed the last row's ordering while
    # its captions described its own ordering.
    rel_paths = []
    for idx_im in range(vec_rel.shape[0]):
        out_file_curr = os.path.join(out_dir_curr, str(idx_im) + '.jpg')
        patch = test_im[idx_im][range_y[0]:range_y[1], range_x[0]:range_x[1]]
        scipy.misc.imsave(out_file_curr, patch)
        rel_paths.append(util.getRelPath(out_file_curr, dir_server))

    im_rows = []
    caption_rows = []
    for vec_curr_idx in range(vec_rel.shape[1]):
        directions = vec_rel[:, vec_curr_idx]
        idx_sort = np.argsort(directions)  # ascending
        im_rows.append([rel_paths[idx_curr] for idx_curr in idx_sort])
        caption_rows.append(['%d %.2f' % (idx_curr, directions[idx_curr])
                             for idx_curr in idx_sort])

    visualize.writeHTML(out_file_html, im_rows, caption_rows, 40, 40)
    print(out_file_html)