Esempio n. 1
0
def inference_recursive_V3(l_input_path=None,
                           conserve_nodes=None,
                           paths=None,
                           hyper=None,
                           norm=1e-3):
    """Run slice-wise inference across MPI ranks and save the result as TIFFs.

    Rank 0 freezes/optimizes the graph, then receives one result per input
    slice from the other ranks, hands it to the reconstructor and finally
    writes every slice of the reconstructed volume to
    ``paths['inference_dir']``. Every other rank calls
    ``_inference_recursive_V3`` on a contiguous share of the slice indices.

    The function must be launched with at least two MPI processes, since
    rank 0 only gathers and never computes.

    Args:
        l_input_path: list of input image paths, one per slice.
        conserve_nodes: list, forwarded to the freeze/optimize helpers and to
            the workers.
        paths: dict; the keys read here are ``'inference_dir'``,
            ``'optimized_pb_path'`` and ``'step'``.
        hyper: dict; the keys read here are ``'nb_classes'`` and
            ``'maxp_times'``.
        norm: forwarded to the workers as ``normalization``.

    Raises:
        AssertionError: if an argument has the wrong container type.
        ValueError: if ``l_input_path`` is empty.
        RuntimeError: if fewer than two MPI processes are available.
    """
    assert isinstance(conserve_nodes, list), 'conserve nodes should be a list'
    assert isinstance(
        l_input_path, list
    ), 'inputs is expected to be a list of images for heterogeneous image size!'
    assert isinstance(paths, dict), 'paths should be a dict'
    assert isinstance(hyper, dict), 'hyper should be a dict'

    # Checked on every rank before any barrier so that no rank is left
    # waiting for a peer that crashed on l_input_path[0].
    if not l_input_path:
        raise ValueError('l_input_path is empty, nothing to infer')

    from mpi4py import MPI
    # prevent GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    communicator = MPI.COMM_WORLD
    rank = communicator.Get_rank()
    nb_process = communicator.Get_size()

    # Rank 0 is a pure gatherer; without workers the split below would
    # divide by zero.
    nb_workers = nb_process - 1
    if nb_workers < 1:
        raise RuntimeError(
            'inference_recursive_V3 needs at least 2 MPI processes '
            '(1 gatherer + workers), got {}'.format(nb_process))

    # optimize ckpt to pb for inference
    if rank == 0:

        for img in l_input_path:
            logger.debug(img)

        check_N_mkdir(paths['inference_dir'])
        freeze_ckpt_for_inference(paths=paths,
                                  hyper=hyper,
                                  conserve_nodes=conserve_nodes)
        optimize_pb_for_inference(paths=paths, conserve_nodes=conserve_nodes)
        # NOTE(review): the original read a free name `hyperparams` here;
        # `hyper` is the dict this function actually receives.
        reconstructor = reconstructor_V3_cls(
            image_size=load_img(l_input_path[0]).shape,
            z_len=len(l_input_path),
            nb_class=hyper['nb_classes'],
            maxp_times=hyper['maxp_times'])
        pbar1 = tqdm(total=len(l_input_path))

    # ************************************************************************************************ I'm a Barrier
    communicator.Barrier()

    # reconstruct volume
    remaining = len(l_input_path)
    nb_img_per_rank = remaining // nb_workers
    rest_img = remaining % nb_workers
    print(nb_img_per_rank, rest_img, nb_process)

    if rank == 0:
        # Gather one result per slice. recv blocks and filters on the tag
        # itself, so there is no stale Status object to spin on.
        while remaining > 0:
            slice_id, out_batch = communicator.recv(source=MPI.ANY_SOURCE,
                                                    tag=tag_compute)
            logger.debug(slice_id)
            reconstructor.write_slice(out_batch, slice_id)

            # progress
            remaining -= 1
            pbar1.update(1)
        pbar1.close()

    else:
        try:
            # The first `rest_img` workers take one extra slice each so the
            # remainder of the integer division is covered.
            if (rank - 1) < rest_img:
                start_id = (rank - 1) * (nb_img_per_rank + 1)
                id_list = np.arange(start_id, start_id + nb_img_per_rank + 1,
                                    1)
            else:
                start_id = (rank - 1) * nb_img_per_rank + rest_img
                id_list = np.arange(start_id, start_id + nb_img_per_rank, 1)

            logger.debug('{}: {}'.format(rank, id_list))
            _inference_recursive_V3(l_input_path=l_input_path,
                                    id_list=id_list,
                                    pb_path=paths['optimized_pb_path'],
                                    conserve_nodes=conserve_nodes,
                                    hyper=hyper,
                                    comm=communicator,
                                    maxp_times=hyper['maxp_times'],
                                    normalization=norm)
        except Exception as e:
            logger.error(e)
            # Tear down the whole job: rank 0 would otherwise wait forever
            # for the slices this worker will never send.
            communicator.Abort(1)

    # ************************************************************************************************ I'm a Barrier
    communicator.Barrier()

    # save recon
    if rank == 0:
        recon = reconstructor.get_volume()
        for i in tqdm(range(len(l_input_path)), desc='writing data'):
            out_name = 'step{}_{}.tif'.format(paths['step'], i)
            # os.path.join works whether or not inference_dir ends with a
            # separator.
            Image.fromarray(recon[i]).save(
                os.path.join(paths['inference_dir'], out_name))