Ejemplo n.º 1
0
def main(args=None):
    """Run synthesis (prediction) over every image in the prediction directories.

    Args:
        args: optional argument list forwarded to ``get_args``; ``None`` lets
            ``get_args`` pick its own source (presumably the command line —
            confirm in ``get_args``).

    Returns:
        int: ``0`` on success, ``1`` if any exception was raised. Exceptions
        are logged with their traceback instead of being propagated.
    """
    args, _ = get_args(args)
    setup_log(args.verbosity)
    logger = logging.getLogger(__name__)
    try:
        # set random seeds for reproducibility
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # since prediction only uses one gpu (at most), make the batch size small
        # enough to fit; never let the integer division drive it down to zero
        if args.n_gpus > 1:
            args.batch_size = max(args.batch_size // args.n_gpus, 1)

        learner = Learner.predict_setup(args)

        # determine how many samples we will use in prediction
        nsyn = args.monte_carlo or 1

        # get relevant prediction directories and determine the glob pattern of
        # the (single) supported image type found in the first directory
        predict_dir = args.predict_dir or args.valid_source_dir
        # output_dir is used as a filename *prefix* (see out_fn below)
        output_dir = args.predict_out or os.path.join(os.getcwd(), 'syn_')
        ext = determine_ext(predict_dir[0])

        # validate configuration; check the type before the range so that a
        # non-integer axis is rejected instead of being compared numerically
        axis = args.sample_axis or 0
        if not isinstance(axis, int) or not 0 <= axis <= 2:
            raise ValueError(
                'sample_axis must be an integer between 0 and 2 inclusive')
        if args.net3d and args.patch_size > 0 and args.calc_var:
            raise SynthNNError(
                'Patch-based 3D variance calculation not currently supported.')

        # glob every directory exactly once, with the same pattern, so the
        # count check and the filenames actually used always agree
        fns_per_dir = [glob_imgs(pd, ext) for pd in predict_dir]
        n_imgs = len(fns_per_dir[0])
        if n_imgs == 0:
            raise SynthNNError(
                'Prediction directory does not contain valid images.')
        if any(len(fns) != n_imgs for fns in fns_per_dir):
            raise SynthNNError(
                'Number of images in prediction directories must have an equal number of images in each '
                'directory (e.g., so that img_t1_1 aligns with img_t2_1 etc. for multimodal synth)'
            )
        # each element is a tuple holding the k-th filename of every directory
        predict_fns = zip(*fns_per_dir)

        for k, fn in enumerate(predict_fns):
            # img_ext is the real file extension (distinct from the glob pattern)
            _, base, img_ext = split_filename(fn[0])
            logger.info(
                f'Starting synthesis of image: {base} ({k+1}/{n_imgs})')
            out_imgs = learner.predict(fn, nsyn, args.temperature_map,
                                       args.calc_var)
            for i, oin in enumerate(out_imgs):
                out_fn = output_dir + f'{k}_{i}{img_ext}'
                # objects exposing to_filename (presumably nibabel images) are
                # written with it; anything else (presumably PIL) uses save
                if hasattr(oin, 'to_filename'):
                    oin.to_filename(out_fn)
                else:
                    oin.save(out_fn)
                logger.info(f'Finished synthesis. Saved as: {out_fn}')

        return 0
    except Exception as e:
        logger.exception(e)
        return 1
Ejemplo n.º 2
0
 def setUp(self):
     """Build a temporary training/prediction tree from the bundled test data.

     Takes the first image ``glob_imgs`` returns for the ``nii`` and ``masks``
     test-data directories, plus ``test.tif`` and ``test.png``. It copies each
     of them eight times into ``<out_dir>/imgs``, with masks, tif and png
     copies going into its ``mask``, ``tif`` and ``png`` subdirectories.
     Masks reuse the NIfTI image's extension. It also prepares the argument
     lists and JSON path used by the tests.
     """
     wd = os.path.dirname(os.path.abspath(__file__))
     data_dir = os.path.join(wd, 'test_data')
     self.nii_dir = os.path.join(data_dir, 'nii')
     self.mask_dir = os.path.join(data_dir, 'masks')
     self.tif_dir = os.path.join(data_dir, 'tif')
     self.png_dir = os.path.join(data_dir, 'png')
     self.out_dir = tempfile.mkdtemp()
     os.mkdir(os.path.join(self.out_dir, 'models'))
     self.train_dir = os.path.join(self.out_dir, 'imgs')
     os.mkdir(self.train_dir)
     for sub in ('mask', 'tif', 'png'):
         os.mkdir(os.path.join(self.train_dir, sub))
     nii = glob_imgs(self.nii_dir)[0]
     msk = glob_imgs(self.mask_dir)[0]
     tif = os.path.join(self.tif_dir, 'test.tif')
     png = os.path.join(self.png_dir, 'test.png')
     _, base, ext = split_filename(nii)
     for i in range(8):
         name = f'{base}{i}'
         shutil.copy(nii, os.path.join(self.train_dir, name + ext))
         shutil.copy(msk, os.path.join(self.train_dir, 'mask', name + ext))
         shutil.copy(tif, os.path.join(self.train_dir, 'tif', name + '.tif'))
         shutil.copy(png, os.path.join(self.train_dir, 'png', name + '.png'))
     # build argv lists directly rather than with str.split(), so a path that
     # contains whitespace stays a single argument
     self.train_args = ['-s', self.train_dir, '-t', self.train_dir]
     self.predict_args = ['-s', self.train_dir, '-o', f'{self.out_dir}/test']
     self.jsonfn = f'{self.out_dir}/test.json'
Ejemplo n.º 3
0
def determine_ext(d):
    """Return the glob pattern of the single supported image type in ``d``.

    The candidate patterns are ``'*.nii*'``, ``'*.tif*'`` and ``'*.png'``.
    Each is checked against ``d`` with ``glob_imgs``.

    Args:
        d: directory to inspect.

    Returns:
        str: the one candidate pattern that matches at least one file in ``d``.

    Raises:
        SynthtorchError: if no candidate pattern matches, or if more than one
            does (mixed image types are ambiguous).
    """
    exts = ('*.nii*', '*.tif*', '*.png')
    found = [ext for ext in exts if len(glob_imgs(d, ext)) > 0]
    if not found:
        raise SynthtorchError(f'Directory {d} contains no supported images.')
    if len(found) > 1:
        raise SynthtorchError(f'Directory {d} contains more than one type of supported images '
                              f'({", ".join(found)}), remove unwanted images from directory')
    return found[0]