def query_syms_on_mat(query_syms, trans_mat, syms, configs, topn=7):
    """Plot and return the strongest transitions out of the query symbols.

    For each symbol in ``query_syms``, its row ``trans_mat[syms.index(sym), :]``
    is ranked in descending order and the ``topn`` highest-weighted column
    indices are collected. The union of those columns (first-occurrence order,
    duplicates dropped) defines the columns of the returned sub-matrix, which
    has one row per query symbol.

    The sub-matrix is drawn with ``plot_mat`` and saved to
    ``trans-<query syms joined by '_'>-<configs.name>.pdf`` (relative to the
    current working directory; '/' in symbol names is replaced by '_').

    Args:
        query_syms: symbols to look up. Each must be in ``syms``, otherwise
            ``list.index`` raises ValueError.
        trans_mat: 2-D indexable whose rows and columns are both aligned with
            the positions in ``syms`` (assumed to be a numpy ndarray — TODO
            confirm).
        syms: list of symbols labelling the rows/columns of ``trans_mat``.
        configs: object whose ``name`` attribute goes into the output file name.
        topn: number of top-weighted columns taken from each query row.

    Returns:
        ``(trans_subset, top_syms)``: the ndarray of selected weights
        (query symbols x selected columns) and the list of column symbols.
    """
    # Resolve each query symbol to its row once; list.index is a linear scan.
    query_inds = [syms.index(sym) for sym in query_syms]

    # Union of every query row's ``topn`` largest entries, in first-seen order.
    top_inds = []
    seen = set()
    for ind in query_inds:
        # Negate so argsort's ascending order puts the largest weights first.
        for ii in np.argsort(-trans_mat[ind, :])[:topn]:
            if ii not in seen:
                seen.add(ii)
                top_inds.append(ii)

    trans_subset = np.asarray(
        [[trans_mat[ind, ii] for ii in top_inds] for ind in query_inds])
    top_syms = [syms[ii] for ii in top_inds]
    # Drop singleton axes in case element lookups yielded nested arrays.
    if len(trans_subset.shape) > 2:
        trans_subset = np.squeeze(trans_subset)

    plot_mat(trans_subset, '',
             top_syms, x_tick_syms=query_syms)
    fname_tag = '_'.join(query_syms).replace('/', '_')
    # NOTE(review): the figure is not closed after saving; if plot_mat opens a
    # new figure per call, repeated calls accumulate open figures — confirm.
    plt.savefig('trans-%s-%s.pdf' % (fname_tag, configs.name))
    return trans_subset, top_syms
def query_syms_on_matrices(query_syms, trans_matrices, trans_names,
                           syms, configs, topn=8):
    """Compare the query symbols' top successors across several matrices.

    First calls ``query_syms_on_mat`` on every matrix in ``trans_matrices``.
    Each such call plots and saves its own PDF. The matrix name is printed
    together with that call's top symbols. The union of all those top symbols,
    in first-seen order, becomes a shared column set.

    Then, for every (query symbol, matrix) pair — query symbol in the outer
    loop — ``subset(trans_mat, syms, [query_sym], top_syms_list)`` is taken and
    the results are stacked into one matrix. Its rows are labelled
    ``'<trans_name>-<query_sym>'``. The matrix is drawn with ``plot_mat`` and
    saved as ``trans-both-<query syms joined by '_'>-<configs.name>.pdf``.

    Args:
        query_syms: non-empty sequence of symbols; each must be in ``syms``.
        trans_matrices: non-empty sequence of transition matrices aligned with
            ``syms``.
        trans_names: names for the matrices, indexed in parallel with
            ``trans_matrices`` (must be at least as long).
        syms: list of symbols labelling the matrices' rows/columns.
        configs: object whose ``name`` attribute goes into output file names.
        topn: number of top columns per query row passed to
            ``query_syms_on_mat``.

    Raises:
        ValueError: if ``query_syms`` or ``trans_matrices`` is empty.
    """
    if len(query_syms) == 0 or len(trans_matrices) == 0:
        raise ValueError('query_syms and trans_matrices must be non-empty')

    # Union of the per-matrix top symbols, in first-seen order.
    top_syms_list = []
    seen = set()
    for i, trans_mat in enumerate(trans_matrices):
        _, top_syms = query_syms_on_mat(query_syms, trans_mat, syms,
                                        configs, topn=topn)
        # Single-argument print renders identically on Python 2 and 3.
        print('%s %s' % (trans_names[i], top_syms))
        for sym in top_syms:
            if sym not in seen:
                seen.add(sym)
                top_syms_list.append(sym)

    from subset_tools import subset
    rows = []
    query_syms_list = []
    for query_sym in query_syms:
        for i, trans_mat in enumerate(trans_matrices):
            rows.append(subset(trans_mat, syms, [query_sym], top_syms_list))
            query_syms_list.append('%s-%s' % (trans_names[i], query_sym))

    # Stack once rather than re-copying a growing array on every iteration.
    matrices = np.vstack(rows)
    # Keep exactly one row per label. A bare np.squeeze would collapse a
    # single-row or single-column result to 1-D. Assumes each subset() call
    # yields one row's worth of values — TODO confirm against subset_tools.
    matrices = matrices.reshape(len(query_syms_list), -1)

    plot_mat(matrices, '', top_syms_list, query_syms_list)
    fname_tag = '_'.join(query_syms).replace('/', '_')
    plt.savefig('trans-both-%s-%s.pdf' % (fname_tag, configs.name))