# Example #1
def main():
    """Compare paired reconstructions with Fourier Ring Correlation (FRC).

    ``args.pathin`` holds two directories separated by ``":"``: the
    reconstructions with odd indices and those with even indices.  Files
    matching ``*<args.label>*`` are listed in each directory and paired by
    sorted position.  ``args.plot_axes`` defines the result grid as
    ``"label_x:x1:x2:...,label_y:y1:y2:..."`` and ``args.axes_label`` gives
    the axis names as ``"x_name:y_name"``.

    For every grid cell (row-major over x, then y) one pair is read with
    ``io.readImage``, optionally cropped with
    ``an.select_resolution_square`` (``args.resol_square``), and the
    resolution returned by ``an.fourier_ring_corr`` is stored.  The grid is
    written as text to ``args.fileout`` and shown as a 3D bar plot.

    Exits via ``sys.exit`` with an error message on inconsistent input:
    missing ``":"`` separators, only one plot axis, unequal file counts, fewer
    files than grid cells, or an output file that already exists.

    NOTE(review): a second ``def main()`` appears later in this file and
    rebinds the name, so this definition is shadowed at import time.
    """
    print('\n')
    print('###############################################')
    print('###############################################')
    print('####                                       ####')
    print('####  RUN ANALYSIS OF THE RECONSTRUCTIONS  ####')
    print('####                                       ####')
    print('###############################################')
    print('###############################################')
    print('\n')

    def _as_dir(path):
        """Return *path* with a trailing '/', mapping '' to './'."""
        if not path:
            return './'
        return path if path.endswith('/') else path + '/'

    def _list_matching(directory, pattern):
        """Return the sorted file names in *directory* matching *pattern*."""
        curr_dir = os.getcwd()
        os.chdir(directory)
        try:
            return sorted(glob.glob(pattern))
        finally:
            # Always restore the working directory, even if globbing fails.
            os.chdir(curr_dir)

    ##  Get input arguments
    args = getArgs()

    ##  Get input paths ("path_odd:path_even")
    pathin = args.pathin

    if pathin.find(':') == -1:
        sys.exit('\nERROR: path to input reconstructions with odd indeces' +
                 ' must be separated with a ":" from the reconstructions with' +
                 ' even indeces')

    # Only the first two entries are used; each gets a trailing '/' so that
    # file names can be appended directly.
    pathin = [_as_dir(p) for p in pathin.split(':')[:2]]

    print('\nInput paths:\n', pathin[0], '\n', pathin[1])

    ##  Get reconstruction files (paired by sorted position)
    pattern = '*' + args.label + '*'
    rec_list1 = _list_matching(pathin[0], pattern)
    rec_list2 = _list_matching(pathin[1], pattern)
    num_files1 = len(rec_list1)
    num_files2 = len(rec_list2)

    if num_files1 != num_files2:
        sys.exit('\nERROR: number of files in path1 = ' + str(num_files1)
                 + ' differ from number of files in path2 = ' + str(num_files2))

    num_files = num_files1
    print('\nNumber of input files in both paths: ', num_files)

    ##  Organize plot axes: "label_x:x1:x2:...,label_y:y1:y2:..."
    plot_axes = args.plot_axes
    axes = []
    file_label = []

    if plot_axes.find(',') == -1:
        fields = plot_axes.split(':')
        file_label.append(fields[0])
        axes.extend(fields[1:])

        print('\nPlot 2D selected using file label: ', file_label)
        print('Axis x values: ', axes)
        # The analysis below fills an x-by-y grid, so one axis is not enough.
        sys.exit('\nERROR: only the 3D plot is supported; "plot_axes" must'
                 ' contain two axes separated by ","')

    for axis_spec in plot_axes.split(','):
        fields = axis_spec.split(':')
        file_label.append(fields[0])
        axes.append(fields[1:])

    print('\nPlot 3D selected using file labels: ', file_label)
    print('Axis x values: ', axes[0])
    print('Axis y values: ', axes[1])

    xlen = len(axes[0])
    ylen = len(axes[1])
    x = np.array(axes[0]).astype(myFloat)
    y = np.array(axes[1]).astype(myFloat)
    result = np.zeros((xlen, ylen), dtype=myFloat)

    axes_label = args.axes_label
    if axes_label.find(':') == -1:
        print('\nPlot 2D selected using file label: ', axes_label)
        # Both an x and a y label are needed for the output file and the plot.
        sys.exit('\nERROR: "axes_label" must contain an x and a y label'
                 ' separated by ":"')

    axes_label = axes_label.split(':')
    print('\nPlot 3D selected using file labels: ', axes_label)

    # Each grid cell consumes one file pair; fail before reading any image.
    if num_files < xlen * ylen:
        sys.exit('\nERROR: number of input file pairs = ' + str(num_files)
                 + ' is smaller than the number of x values = ' + str(xlen)
                 + ' times the number of y values = ' + str(ylen) + '\n')

    # Check the output file up front so the analysis is not wasted.
    fileout = args.fileout
    if os.path.isfile(fileout):
        sys.exit('\nERROR: output file:\n' + fileout + '\nalready exists!')

    ##  Perform analisys
    if args.resol_square is True:
        print('Enabled analisys inside resolution square')

    print('\n\n')

    # np.ndindex walks the grid row-major, matching the sorted file order.
    for f, (i, j) in enumerate(np.ndindex(xlen, ylen)):
        filein = rec_list1[f]
        image1 = io.readImage(pathin[0] + filein)
        print('\n\nInput image 1:\n', filein)

        filein = rec_list2[f]
        image2 = io.readImage(pathin[1] + filein)
        print('\nInput image 2:\n', filein)

        print('\n')

        ##  Preprocess input images
        if args.resol_square is True:
            image1 = an.select_resolution_square(image1)
            image2 = an.select_resolution_square(image2)

        ##  FRC resolution of this pair
        _frc, resol = an.fourier_ring_corr(image1, image2)
        result[i, j] = resol

    ##  Write data for bar plot 3D to text file
    print('\nWriting data for bar plot 3D to file:\n', fileout)
    with open(fileout, 'w') as fout:
        fout.write('x:' + axes_label[0].upper() + ':' + str(len(x)))
        for value in x:
            fout.write('\n%4f' % value)

        fout.write('\ny:' + axes_label[1].upper() + ':' + str(len(y)))
        for value in y:
            fout.write('\n%4f' % value)

        # z values in column-major order: the x index varies fastest.
        fout.write('\nz:FRC:' + str(result.size))
        for value in result.flatten(order='F'):
            fout.write('\n%4f' % value)

    ##  3D bar plot: result columns (y values) on the plot x axis,
    ##  result rows (x values) on the plot y axis
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers '3d' on old matplotlib

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    ly, lx = result.shape
    xpos, ypos = np.meshgrid(np.arange(lx) + 0.25, np.arange(ly) + 0.25)
    xpos = xpos.flatten()
    ypos = ypos.flatten()
    zpos = np.zeros(lx * ly)

    dx = 0.5 * np.ones_like(zpos)
    dy = dx.copy()
    dz = result.flatten()
    ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b', alpha=0.8)

    # Each bar spans [k + 0.25, k + 0.75]; centre the ticks on the bars.
    ax.set_xticks(np.arange(lx) + 0.5)
    ax.set_xticklabels(axes[1])
    ax.set_yticks(np.arange(ly) + 0.5)
    ax.set_yticklabels(axes[0])
    ax.set_xlabel(axes_label[1].upper())
    ax.set_ylabel(axes_label[0].upper())
    ax.set_zlabel('FRC')
    plt.show()

    print('\n')
def main():
    """Compute and save Fourier Ring Correlation (FRC) curves for image pairs.

    Images are read from ``args.pathin``.  There are two input modes:

    * ``args.imagein = "file1:file2"``: a single image pair.
    * ``args.allin``: every ``*odd*`` file is paired, by sorted position,
      with an ``*even*`` file.

    With exactly one pair, the curve is written to ``args.fileout`` verbatim
    and also plotted.  With several pairs, ``args.fileout`` must be
    ``"pathout:alg:label1:label2:..."`` with one label per pair.  Each curve
    then goes to
    ``pathout + "data_plot2d_frc_" + alg + "_proj_" + label + ".txt"``.

    Each output file has an ``x::N`` header followed by the N spatial-frequency
    values, then a ``y:FRC:M`` header followed by the M FRC values, one value
    per line.  Exits via ``sys.exit`` on inconsistent input or when any output
    file already exists.  All output files are checked before any image is
    read.
    """
    print('\n')
    print('###############################################')
    print('###############################################')
    print('####                                       ####')
    print('####  RUN ANALYSIS OF THE RECONSTRUCTIONS  ####')
    print('####                                       ####')
    print('###############################################')
    print('###############################################')
    print('\n')

    def _as_dir(path):
        """Return *path* with a trailing '/', mapping '' to './'."""
        if not path:
            return './'
        return path if path.endswith('/') else path + '/'

    ##  Get input arguments
    args = getArgs()

    ##  Get input path
    pathin = _as_dir(args.pathin)
    print('\nInput path:\n', pathin)

    ##  Get reconstruction files (always as two parallel lists)
    if args.imagein is not None:
        image_pair = args.imagein.split(':')
        if len(image_pair) < 2:
            sys.exit('\nERROR: "imagein" must be given as "image1:image2"')
        image_file1 = [image_pair[0]]
        image_file2 = [image_pair[1]]

    elif args.allin:
        # Glob inside pathin, then restore the working directory: the image
        # paths (pathin + name) and output paths are relative to the original one.
        curr_dir = os.getcwd()
        os.chdir(pathin)
        try:
            image_file1 = sorted(glob.glob('*odd*'))
            image_file2 = sorted(glob.glob('*even*'))
        finally:
            os.chdir(curr_dir)

        if len(image_file1) != len(image_file2):
            sys.exit('\nERROR: number of labelled "odd" = ' + str(len(image_file1))
                     + ' does not correspond to the number of labelled "even" = '
                     + str(len(image_file2)))

        if not image_file1:
            sys.exit('\nERROR: no "odd"/"even" images found in:\n' + pathin)

    else:
        sys.exit('\nERROR: no input images; set either "imagein" (a single'
                 ' pair) or "allin" (all odd/even images in the input path)')

    num_pair = len(image_file1)

    ##  Create output names
    if num_pair == 1:
        fileout_list = [args.fileout]

    else:
        aux_out = args.fileout.split(':')
        if len(aux_out) < 2:
            sys.exit('\nERROR: with several image pairs "fileout" must be'
                     ' given as "pathout:alg:label1:label2:..."')
        pathout = _as_dir(aux_out[0])
        alg = aux_out[1]
        nproj = aux_out[2:]
        base_name = 'data_plot2d_frc_' + alg + '_proj_'
        if len(nproj) != num_pair:
            sys.exit('\nERROR: number of image pair = ' + str(num_pair)
                     + ' does not match with number of labels = '
                     + str(len(nproj)))
        fileout_list = [pathout + base_name + label + '.txt' for label in nproj]

    # Fail before any work so no partial set of outputs is left behind.
    for fileout in fileout_list:
        if os.path.isfile(fileout):
            sys.exit('\nERROR: output file:\n' + fileout + '\nalready exists!')

    ##  Perform analisys
    if args.resol_square is True:
        print('Enabled analisys inside resolution square')

    print('\n\n')

    for name1, name2, fileout in zip(image_file1, image_file2, fileout_list):
        ##  Read images
        image1 = io.readImage(pathin + name1)
        print('\n\nInput image 1:\n', name1)

        image2 = io.readImage(pathin + name2)
        print('\nInput image 2:\n', name2)

        print('\n')

        ##  Preprocess input images
        if args.resol_square is True:
            image1 = an.select_resolution_square(image1)
            image2 = an.select_resolution_square(image2)

        ##  FRC curve and its spatial-frequency axis
        frc, _resol, x = an.fourier_ring_corr(image1, image2)

        ##  Write output file
        print('\nWriting data for plot 2D to file:\n', fileout)
        with open(fileout, 'w') as fout:
            # NOTE(review): the x label field is empty ("x::N"), unlike
            # "y:FRC:M" below -- confirm whether a label was intended.
            fout.write('x::' + str(len(x)))
            for value in x:
                fout.write('\n%4f' % value)

            fout.write('\ny:FRC:' + str(len(frc)))
            for value in frc:
                fout.write('\n%4f' % value)

        ##  Plot FRC curve (single-pair mode only)
        if num_pair == 1:
            plt.figure()
            plt.plot(x, frc, 'o', color='b', markersize=12)
            plt.xlabel('Spatial frequencies', fontsize=20, position=(0.5, -0.2))
            plt.ylabel('FRC', fontsize=20, position=(0.5, 0.5))
            plt.show()

    print('\n')