Пример #1
0
def vol_prior_hack(
        *args,
        proc_vol_fn=None,
        proc_seg_fn=None,
        prior_type='location',  # file-static, file-gen, location
        prior_file=None,  # prior filename
        prior_feed='input',  # input or output
        patch_stride=1,
        patch_size=None,
        batch_size=1,
        collapse_2d=None,
        extract_slice=None,
        force_binary=False,
        nb_input_feats=1,
        verbose=False,
        vol_rand_seed=None,
        **kwargs):
    """
    Generator pairing the output of ``vol_seg_hack`` with patches of a prior.

    Each iteration pulls one item from ``vol_seg_hack`` and one batch from a
    ``patch`` generator built over the prior volume, then yields them together.

    Args:
        *args, **kwargs: forwarded unchanged to ``vol_seg_hack``.
        proc_vol_fn, proc_seg_fn: accepted for signature compatibility; this
            function always passes ``None`` for both to ``vol_seg_hack``.
        prior_type: ``'location'`` builds the prior from a coordinate grid,
            ``'file'`` loads ``prior_file`` with ``np.load`` and reads its
            ``'prior'`` entry; any other value treats ``prior_file`` as an
            array that is already in memory.
        prior_file: filename or array, depending on ``prior_type``.
        prior_feed: ``'input'`` yields ``([input_vol, prior_batch], input_vol)``;
            ``'output'`` yields ``(input_vol, [input_vol, prior_batch])``.
        patch_stride: a scalar (applied to every patch dimension) or a
            sequence with one entry per entry of ``patch_size``.
        patch_size: patch size for the prior; defaults to the first three
            dimensions of the prior volume.
        batch_size, collapse_2d: forwarded to ``vol_seg_hack`` and ``patch``.
        extract_slice: int or slice(s) applied to the third axis of the prior;
            also forwarded to ``vol_seg_hack``.
        force_binary: if true, prior channels ``1:`` are summed into a single
            channel 1 and the rest dropped; also forwarded to ``vol_seg_hack``.
        nb_input_feats, vol_rand_seed: forwarded to ``vol_seg_hack``.
        verbose: forwarded to ``vol_seg_hack``; also enables a message when an
            input volume is entirely zero.

    Yields:
        A 2-tuple arranged according to ``prior_feed`` (see above).
    """
    # prepare the vol_seg
    # NOTE(review): proc_vol_fn / proc_seg_fn are deliberately(?) overridden
    # with None here rather than forwarded -- confirm this is intended.
    gen = vol_seg_hack(*args,
                       **kwargs,
                       proc_vol_fn=None,
                       proc_seg_fn=None,
                       collapse_2d=collapse_2d,
                       extract_slice=extract_slice,
                       force_binary=force_binary,
                       verbose=verbose,
                       patch_size=patch_size,
                       patch_stride=patch_stride,
                       batch_size=batch_size,
                       vol_rand_seed=vol_rand_seed,
                       nb_input_feats=nb_input_feats)

    # get prior
    if prior_type == 'location':
        # NOTE(review): `vol_size` is not defined in this function; unless it
        # exists at module scope this branch raises NameError. The result is
        # also 5-D after expand_dims, which the ndim assertion below rejects.
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            data = np.load(prior_file)
            prior_vol = data['prior'].astype('float16')
    else:  # assumes a volume
        with timer.Timer('astyping prior', verbose):
            prior_vol = prior_file
            if not (prior_vol.dtype == 'float16'):
                prior_vol = prior_vol.astype('float16')

    if force_binary:
        nb_labels = prior_vol.shape[-1]
        # Compute the merged channel first and write it into the copy returned
        # by np.delete, so a caller-supplied float16 array (which is not copied
        # above) is never modified in place.
        merged = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
        prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)
        prior_vol[:, :, :, 1] = merged

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            # np.newaxis keeps the sliced axis so the volume stays 4-D
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(
        prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    # The default stride is the scalar 1; expand scalars so that the length
    # check and the unpacking below work.
    if np.isscalar(patch_stride):
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)
    prior_gen = patch(
        prior_vol,
        [*patch_size, nb_channels],
        patch_stride=[*patch_stride, nb_channels],
        batch_size=batch_size,
        collapse_2d=collapse_2d,
        keep_vol_size=True,
        infinite=True,
        nb_labels_reshape=0)

    # generator loop
    while 1:

        # generate input and output volumes
        input_vol = next(gen)

        if verbose and np.all(input_vol.flat == 0):
            print("all entries are 0")

        # generate prior batch
        prior_batch = next(prior_gen)

        if prior_feed == 'input':
            yield ([input_vol, prior_batch], input_vol)
        else:
            assert prior_feed == 'output'
            yield (input_vol, [input_vol, prior_batch])
Пример #2
0
def add_prior(
        gen,
        proc_vol_fn=None,
        proc_seg_fn=None,
        prior_type='location',  # file-static, file-gen, location
        prior_file=None,  # prior filename
        prior_feed='input',  # input or output
        patch_stride=1,
        patch_size=None,
        batch_size=1,
        collapse_2d=None,
        extract_slice=None,
        force_binary=False,
        verbose=False,
        patch_rand=False,
        patch_rand_seed=None):
    """
    Add a prior generator to a given generator.

    For every sample drawn from ``gen`` a prior batch is requested from a
    ``patch`` generator over the prior volume, with the request sized by
    ``_get_shape(sample)`` so that the number of patches in the batch matches
    the output of ``gen``.

    Args:
        gen: the generator to wrap; advanced with ``next`` once per iteration.
        proc_vol_fn, proc_seg_fn, prior_feed, verbose: accepted but not used
            by this function.
        prior_type: ``'location'`` builds the prior from a coordinate grid,
            ``'file'`` loads ``prior_file`` with ``np.load`` and reads its
            ``'prior'`` entry; any other value treats ``prior_file`` as an
            array and converts it to float16.
        prior_file: filename or array, depending on ``prior_type``.
        patch_stride: a scalar (applied to every patch dimension) or a
            sequence with one entry per entry of ``patch_size``.
        patch_size: patch size for the prior; defaults to the first three
            dimensions of the prior volume.
        batch_size, collapse_2d, patch_rand, patch_rand_seed: forwarded to
            ``patch``.
        extract_slice: int or slice(s) applied to the third axis of the prior.
        force_binary: if true, prior channels ``1:`` are summed into a single
            channel 1 and the rest dropped.

    Yields:
        ``(gen_sample, prior_batch)``.
    """

    # get prior
    if prior_type == 'location':
        # NOTE(review): `vol_size` is not defined in this function; unless it
        # exists at module scope this branch raises NameError. The result is
        # also 5-D after expand_dims, which the ndim assertion below rejects.
        prior_vol = nd.volsize2ndgrid(vol_size)
        prior_vol = np.transpose(prior_vol, [1, 2, 3, 0])
        prior_vol = np.expand_dims(prior_vol, axis=0)  # reshape for model

    elif prior_type == 'file':  # assumes a npz filename passed in prior_file
        with timer.Timer('loading prior', True):
            data = np.load(prior_file)
            prior_vol = data['prior'].astype('float16')

    else:  # assumes a volume
        with timer.Timer('loading prior', True):
            prior_vol = prior_file.astype('float16')

    if force_binary:
        nb_labels = prior_vol.shape[-1]
        prior_vol[:, :, :, 1] = np.sum(prior_vol[:, :, :, 1:nb_labels], 3)
        prior_vol = np.delete(prior_vol, range(2, nb_labels), 3)

    nb_channels = prior_vol.shape[-1]

    if extract_slice is not None:
        if isinstance(extract_slice, int):
            # np.newaxis keeps the sliced axis so the volume stays 4-D
            prior_vol = prior_vol[:, :, extract_slice, np.newaxis, :]
        else:  # assume slices
            prior_vol = prior_vol[:, :, extract_slice, :]

    # get the prior to have the right volume [x, y, z, nb_channels]
    assert np.ndim(prior_vol) == 4 or np.ndim(
        prior_vol) == 3, "prior is the wrong size"

    # prior generator
    if patch_size is None:
        patch_size = prior_vol.shape[0:3]
    # The default stride is the scalar 1; expand scalars so that the length
    # check and the unpacking below work.
    if np.isscalar(patch_stride):
        patch_stride = [patch_stride] * len(patch_size)
    assert len(patch_size) == len(patch_stride)
    prior_gen = patch(prior_vol, [*patch_size, nb_channels],
                      patch_stride=[*patch_stride, nb_channels],
                      batch_size=batch_size,
                      collapse_2d=collapse_2d,
                      keep_vol_size=True,
                      infinite=True,
                      patch_rand=patch_rand,
                      patch_rand_seed=patch_rand_seed,
                      variable_batch_size=True,
                      nb_labels_reshape=0)

    # Prime the generator so that .send() can be used below. The next() call
    # must live outside the assert: asserts are stripped under `python -O`,
    # and sending a value to an unstarted generator raises TypeError.
    primed = next(prior_gen)
    assert primed is None, "bad prior gen setup"

    # generator loop
    while 1:

        # generate input and output volumes
        gen_sample = next(gen)

        # generate prior batch
        gs_sample = _get_shape(gen_sample)
        prior_batch = prior_gen.send(gs_sample)

        yield (gen_sample, prior_batch)