def inference_recursive_V3(l_input_path=None, conserve_nodes=None, paths=None, hyper=None, norm=1e-3):
    """Run CPU-only inference over a list of images, distributed with MPI.

    Rank 0 freezes/optimizes the checkpoint, then collects one result per
    input image from the other ranks, writes each into the reconstructor and
    finally saves every slice of the reconstructed volume as a ``.tif``.
    Every other rank computes a contiguous range of image indices and hands
    it to ``_inference_recursive_V3``.

    Args:
        l_input_path: list of input image paths (one output slice per entry).
        conserve_nodes: list forwarded to the freeze/optimize helpers and to
            the per-rank inference routine.
        paths: dict; this function reads ``'inference_dir'``,
            ``'optimized_pb_path'`` and ``'step'`` from it and forwards the
            whole dict to the freeze/optimize helpers.
        hyper: dict; this function reads ``'nb_classes'`` and
            ``'maxp_times'`` from it and forwards the whole dict to the
            helpers.
        norm: forwarded as ``normalization`` to ``_inference_recursive_V3``.

    Raises:
        AssertionError: if an argument has the wrong container type.
        RuntimeError: if launched with fewer than 2 MPI processes (rank 0
            only gathers, so at least one worker rank is required).
    """
    assert isinstance(conserve_nodes, list), 'conserve nodes should be a list'
    assert isinstance(
        l_input_path, list
    ), 'inputs is expected to be a list of images for heterogeneous image size!'
    assert isinstance(paths, dict), 'paths should be a dict'
    assert isinstance(hyper, dict), 'hyper should be a dict'

    from mpi4py import MPI

    # prevent GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    communicator = MPI.COMM_WORLD
    rank = communicator.Get_rank()
    nb_process = communicator.Get_size()

    # Rank 0 never computes; without a worker the split below would divide
    # by zero, so fail early with an explicit message.
    if nb_process < 2:
        raise RuntimeError(
            'inference_recursive_V3 needs at least 2 MPI processes '
            '(1 gatherer + >=1 worker), got {}'.format(nb_process)
        )

    # optimize ckpt to pb for inference
    if rank == 0:
        for img in l_input_path:
            logger.debug(img)
        check_N_mkdir(paths['inference_dir'])
        freeze_ckpt_for_inference(paths=paths, hyper=hyper, conserve_nodes=conserve_nodes)
        optimize_pb_for_inference(paths=paths, conserve_nodes=conserve_nodes)

    # NOTE(review): built on every rank although only rank 0 uses it below;
    # left in place because the constructor's side effects are not visible here.
    reconstructor = reconstructor_V3_cls(
        image_size=load_img(l_input_path[0]).shape,
        z_len=len(l_input_path),
        nb_class=hyper['nb_classes'],
        maxp_times=hyper['maxp_times'],
    )

    # Barrier: workers must not start before rank 0 has written the optimized pb
    communicator.Barrier()

    # reconstruct volume
    remaining = len(l_input_path)
    nb_workers = nb_process - 1
    nb_img_per_rank = remaining // nb_workers
    rest_img = remaining % nb_workers
    print(nb_img_per_rank, rest_img, nb_process)

    if rank == 0:
        # Gather exactly one (slice_id, batch) message per input image. A
        # blocking recv filtered on tag_compute replaces the single up-front
        # Probe, which could spin forever if the first message had another tag.
        pbar1 = tqdm(total=len(l_input_path))
        try:
            while remaining > 0:
                slice_id, out_batch = communicator.recv(source=MPI.ANY_SOURCE, tag=tag_compute)
                logger.debug(slice_id)
                reconstructor.write_slice(out_batch, slice_id)

                # progress
                remaining -= 1
                pbar1.update(1)
        finally:
            pbar1.close()

    else:
        try:
            # The first `rest_img` workers take one extra image so that the
            # remainder of the integer division is distributed.
            worker_idx = rank - 1
            if worker_idx < rest_img:
                start_id = worker_idx * (nb_img_per_rank + 1)
                id_list = np.arange(start_id, start_id + nb_img_per_rank + 1, 1)
            else:
                start_id = worker_idx * nb_img_per_rank + rest_img
                id_list = np.arange(start_id, start_id + nb_img_per_rank, 1)

            logger.debug('{}: {}'.format(rank, id_list))
            _inference_recursive_V3(
                l_input_path=l_input_path,
                id_list=id_list,
                pb_path=paths['optimized_pb_path'],
                conserve_nodes=conserve_nodes,
                hyper=hyper,
                comm=communicator,
                maxp_times=hyper['maxp_times'],
                normalization=norm,
            )

        except Exception as e:
            logger.error(e)
            # Tear down the whole MPI job; otherwise the remaining ranks
            # would wait forever at the barrier / in recv.
            communicator.Abort(1)

    # Barrier: all slices received before the volume is saved
    communicator.Barrier()

    # save recon
    if rank == 0:
        recon = reconstructor.get_volume()
        for i in tqdm(range(len(l_input_path)), desc='writing data'):
            Image.fromarray(recon[i]).save(
                paths['inference_dir'] + 'step{}_{}.tif'.format(paths['step'], i)
            )