def assign(Metric, project_filepath, cutoff, centerfile, flag_nopreprocess):
    """ Assign the remaining frames to the given cluster centers.

    Every frame within `cutoff` of a center (and not already claimed by an
    earlier center) joins that center's cluster; frames matching no center
    become singleton clusters.
    """
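    # The returned mapping has the shape {center_id: [member_ids, ...]},
    # e.g. (illustrative values only): {12: [12, 40, 77], 5: [5]}.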
    # Initialize MPI.
    comm = MPI.COMM_WORLD
    mpi_size = comm.Get_size()
    my_rank = comm.Get_rank()
    comm.Barrier()
 
    logger.info("refine start")
    # Say hello.
    print0(rank=my_rank,msg=" Initialized MPI.")
    logger.debug("Hello, from node %s",my_rank)
 
    # Load project file.
    print0(rank=my_rank,msg=" Reading project yaml file.")
    project = Project(existing_project_file = project_filepath)
 
    logger.debug("Initializing manager at node %s",my_rank)
    manager = Loadmanager(project.get_trajectory_lengths(),
                          project.get_trajectory_filepaths(),
                          mpi_size,my_rank)
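    # The Loadmanager partitions the trajectory frames across the MPI ranks.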

    # Instantiate the Metric on rank 0 only: its construction may need user
    # input (e.g. index groups), so it is built once and then broadcast.
    if my_rank == 0:
        metric = Metric(tpr_filepath = project.get_tpr_filepath(),
                        stx_filepath = project.get_gro_filepath(),
                        ndx_filepath = project.get_ndx_filepath(),
                        number_dimensions = project.get_number_dimensions())
        # Pointers to arrays cannot survive pickling, so drop them before
        # the broadcast.
        metric.destroy_pointers()
    else:
        metric = None
 
    metric = comm.bcast(metric, root=0)
    print0(rank=my_rank, msg="Metric object broadcast to all ranks.")
    # Recreate the pointers on every rank now that the object has arrived.
    metric.create_pointers()
    
    # Partition the frames across ranks and take this rank's share.
    manager.do_partition()
    my_partition = manager.myworkshare

    (my_trajectory_filepaths, my_trajectory_lengths,
     my_trajectory_ID_offsets, my_trajectory_ID_ranges) = map(list, my_partition)
 
    logger.info("Reading trajectories at %s",my_rank)
    my_frames = Framecollection.from_files(
            stride = 1,
            trajectory_type = project.get_trajectory_type(),
            trajectory_globalID_offsets = my_trajectory_ID_offsets,
            trajectory_filepath_list = my_trajectory_filepaths,
            trajectory_length_list = my_trajectory_lengths, )

        
    if not flag_nopreprocess:
        # Metric preprocessing: modifies the trajectories in place.
        logger.debug("Preprocessing trajectories at rank %s", my_rank)
        metric.preprocess( frame_array_pointer = my_frames.get_first_frame_pointer(),
                           number_frames = my_frames.number_frames,
                           number_atoms = my_frames.number_atoms)
    else:
        print0(rank=my_rank, msg="Will not preprocess trajectories.")

         
    # Read the cluster centers; column 1 of the centers file is taken to hold
    # the global frame IDs of the centers.
    clustercenters = txtreader.readcols(centerfile)[:, 1]
 
    clusters = {}  # FrameID (cluster center) -> [FrameID] (members of the cluster)
    my_unclustered = set(my_frames.globalIDs_iter)
    removed_vertices = set()
 
    for center_id in clustercenters:
        # Locate the rank that holds the center frame.
        center_host_node = manager.find_node_of_frame(center_id)
        if center_host_node is None:
            raise KeyError("Cluster center ID {0} not found on any node.".format(center_id))
             
        
        # Broadcast the center frame: the host rank sends its copy, every
        # other rank allocates an empty receive buffer of matching shape
        # and dtype.
        if my_rank == center_host_node:
            center_frame = my_frames.get_frame(center_id)
        else:
            shape = (my_frames.number_atoms, 3)
            center_frame = np.empty(shape, dtype=my_frames.frames.dtype)

        comm.Bcast([center_frame, my_frames.mpi_frametype], root=center_host_node)
        if my_rank == 0 : logger.debug("Searching for members for center id %s", center_id)
        
        center_frame_pointer = center_frame.ctypes.data_as(ctypes.POINTER(gp_grompy.rvec))
        rmsd_buffer = np.empty(my_frames.number_frames, dtype=my_frames.frames.dtype)
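        # compute_distances below fills rmsd_buffer with the distance from
        # the broadcast center to every frame held by this rank; the center
        # is passed to the compiled routine as an rvec pointer.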
        
        metric.compute_distances( 
            reference_frame_pointer = center_frame_pointer,
            frame_array_pointer = my_frames.get_first_frame_pointer(),
            number_frames = my_frames.number_frames,
            number_atoms = my_frames.number_atoms,
            real_output_buffer = rmsd_buffer, # writes results to this buffer.
            mask_ptr = None,
            mask_dummy_value = -1.0,
            )
         
        # Keep only frames within the cutoff that have not already been
        # claimed by an earlier center; zipping against the global IDs keeps
        # the IDs correct under striding.
        my_members = [frame_id
                      for frame_id, rmsd in zip(my_frames.globalIDs_iter, rmsd_buffer)
                      if frame_id not in removed_vertices and 0.0 <= rmsd <= cutoff]
        
        # Gather the members found on every rank and record the cluster.
        members_gathered = comm.allgather(my_members)
        members = list(itertools.chain(*members_gathered))
        if my_rank == 0:
            logger.debug("Found %s members", len(members))
        removed_vertices.update(members)
        clusters[center_id] = list(members)

        my_unclustered.difference_update(members)
        
    # Any frame never claimed by a center becomes a singleton cluster.
    unclustered_gathered = comm.allgather(my_unclustered)
    unclustered = list(itertools.chain(*unclustered_gathered))
    logger.debug("Unclustered %s", unclustered)

    for i in unclustered:
        clusters[i] = [i]

    logger.info("refine end")
         
    return clusters 
def cluster(Metric, project_filepath, cutoff, checkpoint_filepath=None,
            flag_nopreprocess = False):
    """ Cluster data
    
    Parameters
    ---------- 
    project_filepath : String
       The path to the YAML project file.
       
    cutoff           : Floating 
        The cutoff distance passed to the Metric class.
        
    chekpoint_filepath : string,optional
    flag_nopreprocess : bool,optional
        switches off preprocessing
    
    Returns
    -------
    clusters: dict
        a map (dictionary) from cluster center vertices C to lists of vertices,
        where the lists represent the set of vertices belonging to the cluster with center C.
        i.e. a map: FrameID -> [FrameID]
        
    """

    # ================================================================
    # Instantiation of helper classes.
    
    # Initialize MPI.
    comm = MPI.COMM_WORLD
    mpi_size = comm.Get_size()
    my_rank = comm.Get_rank()
    comm.Barrier()

    # Say hello.
    print0(rank=my_rank,msg="Initialized MPI.")
    logger.debug("Hello, from node %s",my_rank)

    logger.info("Reading project file at node %s",my_rank)
    project = Project(existing_project_file = project_filepath)
    
    logger.debug("Initializing manager at node %s",my_rank)
    manager = Loadmanager(project.get_trajectory_lengths(),
                          project.get_trajectory_filepaths(),
                          mpi_size,my_rank)
    # The Metric must be instantiated by a single MPI process, since its
    # construction may require user input (e.g. index groups).
    if my_rank == 0:
        metric = Metric(tpr_filepath = project.get_tpr_filepath(),
                        stx_filepath = project.get_gro_filepath(),
                        ndx_filepath = project.get_ndx_filepath(),
                        number_dimensions = project.get_number_dimensions())
        # Since the metric must be broadcast (pickled), destroy all pointers
        # to arrays first.
        metric.destroy_pointers()
        logger.debug("Metric initialized at node 0")
    else:
        metric = None
    metric = comm.bcast(metric, root=0)
    print0(rank=my_rank,msg="metric object broadcasted.")
    # recreate all pointers in the object's instance
    metric.create_pointers()
    
    # Partition the frames across ranks and take this rank's share.
    manager.do_partition()
    my_partition = manager.myworkshare

    (my_trajectory_filepaths, my_trajectory_lengths,
     my_trajectory_ID_offsets, my_trajectory_ID_ranges) = map(list, my_partition)


    logger.info("Reading trajectories at %s",my_rank)
    my_frames = Framecollection.from_files(
            stride = project.get_stride(),
            trajectory_type = project.get_trajectory_type(),
            trajectory_globalID_offsets = my_trajectory_ID_offsets,
            trajectory_filepath_list = my_trajectory_filepaths,
            trajectory_length_list = my_trajectory_lengths, )
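    # Each rank loads only its own share of the trajectories here, using the
    # project-wide stride (cf. stride = 1 in assign()).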

        
    if not flag_nopreprocess:
        # Metric preprocessing: modifies the trajectories in place.
        logger.debug("Preprocessing trajectories at rank %s", my_rank)
        metric.preprocess( frame_array_pointer = my_frames.get_first_frame_pointer(),
                           number_frames = my_frames.number_frames,
                           number_atoms = my_frames.number_atoms)
    else:
        print0(rank=my_rank, msg="Will not preprocess trajectories.")


    # ================================================================
    # Initial round of all-to-all neighbour counting.


    # Count the number of neighbours for all frames.
    # If frames are vertices and edges join frames having rmsd within the cutoff,
    # then we compute and record the degree of each vertex.
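
    # Expressed serially, this step is just a degree count (a sketch, with a
    # hypothetical pairwise distance function `dist`):
    #
    #     counts = {i: sum(1 for j in frame_ids
    #                      if dist(frames[i], frames[j]) <= cutoff)
    #               for i in frame_ids}
    #
    # Here each rank counts neighbours over its own block of frames and the
    # per-rank partial counts are summed after an allgather.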

    if checkpoint_filepath is None:
        print0(rank=my_rank,msg="Counting 'neighbours' for all frames.")

        my_neighbour_count = allToAll_neighbourCount(cutoff, comm, mpi_size, my_rank,
                                metric, my_frames,manager) # :: Map Integer Integer

        print0(rank=my_rank,msg="Synchronizing neighbour counts.")
        neighbour_count_recvList = comm.allgather(my_neighbour_count)

        # Sum the per-rank partial counts into one global count per frame.
        neighbour_counts = {}
        for node_neighbour_counts in neighbour_count_recvList:
            for frameID, count in node_neighbour_counts.items():
                neighbour_counts[frameID] = neighbour_counts.get(frameID, 0) + count
    else:
        print0(rank=my_rank, msg="Using checkpoint file; skipping neighbour counting.")
        neighbour_counts = None
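
    # daura_clustering implements Daura-style clustering: repeatedly promote
    # the frame with the most neighbours within the cutoff to a cluster
    # center, remove it and its neighbours from the pool, and repeat until
    # every frame is assigned (Daura et al.).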



    print0(rank=my_rank,msg="Start clustering.")
    
    T=time()

    clusters = daura_clustering(neighbour_counts,
                    cutoff, comm, mpi_size, my_rank, manager, 
                    metric, my_frames, checkpoint_filepath)
    
    print0(rank=my_rank,msg=" Finished ... Total time: {0}".format(time()-T))

                    
    return clusters