コード例 #1
0
ファイル: predict.py プロジェクト: mpriessner/pytorch_fnet
def get_dataset(args: argparse.Namespace) -> torch.utils.data.Dataset:
    """Returns the dataset selected by the parsed arguments.

    Exactly one of ``args.dataset`` and ``args.path_tif`` must be set.

    Parameters
    ----------
    args
        Parsed arguments. If ``args.dataset`` is set, it is resolved with
        ``str_to_object`` and the result is called with
        ``args.dataset_kwargs``. If ``args.path_tif`` is set, it is either a
        single file or a directory (expanded with ``files_from_dir``), and
        every path becomes a row of a ``TiffDataset`` with no target.

    Returns
    -------
    torch.utils.data.Dataset
        Dataset object.

    Raises
    ------
    ValueError
        If zero or both input sources are given, if ``args.dataset`` does not
        resolve to a callable, or if ``args.path_tif`` does not exist.

    """
    if sum([args.dataset is not None, args.path_tif is not None]) != 1:
        raise ValueError("Must specify one input source type")
    if args.dataset is not None:
        ds_fn = str_to_object(args.dataset)
        if not callable(ds_fn):
            raise ValueError(f"{args.dataset} must be callable")
        return ds_fn(**args.dataset_kwargs)
    # The check above guarantees that path_tif is the (only) source here.
    if not os.path.exists(args.path_tif):
        raise ValueError(f"Path does not exist: {args.path_tif}")
    if os.path.isdir(args.path_tif):
        paths_tif = files_from_dir(args.path_tif)
    else:
        paths_tif = [args.path_tif]
    return TiffDataset(
        dataframe=pd.DataFrame({
            "path_bf": paths_tif,
            "path_target": None
        }),
        transform_signal=[norm_around_center],
        transform_target=[norm_around_center],
        col_signal="path_bf",
    )
コード例 #2
0
    def __init__(
        self,
        betas=(0.5, 0.999),
        criterion_class='fnet.losses.WeightedMSE',
        init_weights=True,
        lr=0.001,
        nn_class='fnet.nn_modules.fnet_nn_3d.Net',
        # NOTE(review): mutable default argument; every instance created
        # without nn_kwargs shares this same dict (stored as self.nn_kwargs).
        nn_kwargs={},
        scheduler=None,
        weight_decay=0,
        gpu_ids=-1,
    ):
        """Store the model configuration, pick the device and build the net.

        Parameters
        ----------
        betas
            Stored as ``self.betas``; presumably the Adam betas used when the
            optimizer is built in ``_init_model`` -- confirm.
        criterion_class
            Dotted path resolved with ``str_to_object``; the result is called
            with no arguments and stored as ``self.criterion``.
        init_weights
            Stored as ``self.init_weights``; presumably toggles custom weight
            initialization in ``_init_model`` -- confirm.
        lr
            Stored as ``self.lr``; presumably the optimizer learning rate.
        nn_class
            Dotted path of the network class, stored as ``self.nn_class``.
        nn_kwargs
            Keyword arguments for the network class, stored as
            ``self.nn_kwargs``.
        scheduler
            Stored as ``self.scheduler``; presumably a learning-rate scheduler
            spec interpreted by ``_init_model`` -- confirm.
        weight_decay
            Stored as ``self.weight_decay``.
        gpu_ids
            A single int or a sequence of ints. Only the first id selects the
            device: negative means CPU, otherwise ``cuda:<id>``.
        """
        self.betas = betas
        self.criterion = str_to_object(criterion_class)()
        # Normalize a bare int into a list so gpu_ids[0] works below.
        self.gpu_ids = [gpu_ids] if isinstance(gpu_ids, int) else gpu_ids
        self.init_weights = init_weights
        self.lr = lr
        self.nn_class = nn_class
        self.nn_kwargs = nn_kwargs
        self.scheduler = scheduler
        self.weight_decay = weight_decay

        self.count_iter = 0
        # A negative first GPU id selects the CPU.
        self.device = (torch.device('cuda', self.gpu_ids[0])
                       if self.gpu_ids[0] >= 0 else torch.device('cpu'))
        self.optimizer = None
        # Builds the network (and presumably the optimizer) from the fields
        # set above, so it must run after they are assigned.
        self._init_model()
        # NOTE(review): get_args is not shown here; it presumably inspects
        # this constructor's frame to capture the call arguments (hence the
        # 'self' entry removed below). Keep this call directly inside
        # __init__ and confirm before moving or reordering it.
        self.fnet_model_kwargs, self.fnet_model_posargs = get_args()
        self.fnet_model_kwargs.pop('self')
コード例 #3
0
 def _init_model(self):
     """Instantiate the network, its Adam optimizer and optional LR scheduler.

     Sets ``self.net`` and ``self.optimizer``. When ``self.scheduler`` holds
     a spec, it is replaced in place by a torch scheduler object:

     * ``('snapshot', period)``: ``LambdaLR`` whose factor follows a cosine
       from 1.0 down towards 0.01, restarting every ``period`` scheduler steps.
     * ``('step', step_size)``: ``StepLR`` with that ``step_size``.

     Raises
     ------
     NotImplementedError
         For any other scheduler kind.
     """
     net_factory = str_to_object(self.nn_class)
     self.net = net_factory(**self.nn_kwargs)
     if self.init_weights:
         self.net.apply(_weights_init)
     self.net.to(self.device)
     param_groups = get_per_param_options(self.net, wd=self.weight_decay)
     self.optimizer = torch.optim.Adam(param_groups, lr=self.lr, betas=self.betas)
     if self.scheduler is None:
         return
     # NOTE(review): self.scheduler is overwritten with a scheduler object
     # below, so a second call of this method would fail on the [0] lookup.
     kind = self.scheduler[0]
     if kind == 'snapshot':
         period = self.scheduler[1]
         floor = 0.01

         def snapshot_factor(step):
             # Cosine decay inside each period; back to 1.0 at every restart.
             return floor + (1 - floor) * (
                 0.5 + 0.5 * math.cos(math.pi * (step % period) / period))

         self.scheduler = torch.optim.lr_scheduler.LambdaLR(
             self.optimizer, snapshot_factor)
     elif kind == 'step':
         self.scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer, self.scheduler[1])
     else:
         raise NotImplementedError
コード例 #4
0
def get_bpds_train(args: argparse.Namespace) -> BufferedPatchDataset:
    """Creates data provider for training.

    ``args.dataset_train`` is resolved to a dataset factory, which is called
    with ``args.dataset_train_kwargs``. The resulting dataset is wrapped in a
    ``BufferedPatchDataset`` configured by ``args.bpds_kwargs``.

    Raises
    ------
    ValueError
        If ``args.dataset_train`` does not resolve to a callable.
    """
    factory = str_to_object(args.dataset_train)
    if isinstance(factory, Callable):
        return BufferedPatchDataset(
            dataset=factory(**args.dataset_train_kwargs),
            **args.bpds_kwargs,
        )
    raise ValueError('Dataset function should be Callable')
コード例 #5
0
ファイル: predict.py プロジェクト: hsiaoyi0504/pytorch_fnet
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Predicts using model.

    Runs every model listed in ``args.path_model_dir`` over the dataset items
    chosen by ``get_indices``. For each item it saves the signal, target and
    per-channel prediction TIFFs to ``args.path_save_dir`` and records their
    paths and the metric value in an entry. ``predictions.csv`` is rewritten
    after every item. The arguments are saved as JSON at the end.

    Parameters
    ----------
    args
        Parsed arguments; parsed from the command line when None. If
        ``args.json`` is set and names a file that does not exist, default
        predict options are written there and nothing is predicted.
    """
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    if args.json and not args.json.exists():
        save_default_predict_options(args.json)
        return
    load_from_json(args)
    metric = str_to_object(args.metric)
    dataset = get_dataset(args)
    entries = []
    model = None
    indices = get_indices(args, dataset)
    for count, idx in enumerate(indices, 1):
        # `!s:>3` matches `:3d` for ints but also accepts non-integer labels.
        logger.info(f"Processing: {idx!s:>3} ({count}/{len(indices)})")
        entry = {}
        entry["index"] = idx
        signal, target = item_from_dataset(dataset, idx)
        if not args.no_signal:
            entry["path_signal"] = save_tif(f"{idx}_signal.tif",
                                            signal.numpy()[0, ],
                                            args.path_save_dir)
        if not args.no_target and target is not None:
            entry["path_target"] = save_tif(f"{idx}_target.tif",
                                            target.numpy()[0, ],
                                            args.path_save_dir)
        for path_model_dir in args.path_model_dir:
            # With a single model it is loaded once and reused for every
            # item; with several models each one is reloaded per item.
            if model is None or len(args.path_model_dir) > 1:
                model_def = parse_model(path_model_dir)
                model = load_model(model_def["path"], no_optim=True)
                model.to_gpu(args.gpu_ids)
                logger.info(f'Loaded model: {model_def["name"]}')
            prediction = model.predict_piecewise(
                signal, tta=("no_tta" not in model_def["options"]))
            evaluation = metric(target, prediction)
            entry[args.metric + f'.{model_def["name"]}'] = evaluation
            if not args.no_prediction:
                # Convert once; each channel along axis 0 goes to its own TIFF.
                pred_np = prediction.numpy()
                for idx_c in range(pred_np.shape[0]):
                    tag = f'prediction_c{idx_c}.{model_def["name"]}'
                    entry[f"path_{tag}"] = save_tif(f"{idx}_{tag}.tif",
                                                    pred_np[idx_c, ],
                                                    args.path_save_dir)
        entries.append(entry)
        save_predictions_csv(
            path_csv=os.path.join(args.path_save_dir, "predictions.csv"),
            pred_records=entries,
            dataset=dataset,
        )
    save_args_as_json(args.path_save_dir, args)
コード例 #6
0
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Predicts using model.

    Runs every model in ``args.path_model_dir`` over the selected items of
    the dataset returned by ``get_dataset``. For each item it saves the
    signal, target and per-channel prediction TIFFs to ``args.path_save_dir``
    and collects a row of file paths and metric values. The rows are written
    to ``predictions.csv`` every 8 items and after the last item.

    Parameters
    ----------
    args
        Parsed arguments; parsed from the command line when None.
        ``args.idx_sel`` (when not None) replaces the dataset index as the
        items to process, and a positive ``args.n_images`` truncates that
        selection. ``args.add_sigmoid`` applies a sigmoid to predictions.
    """
    if args is None:
        parser = argparse.ArgumentParser()
        add_parser_arguments(parser)
        args = parser.parse_args()
    metric = str_to_object(args.metric)
    dataset = get_dataset(args)
    index_name = dataset.df.index.name or 'index'
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(args.path_save_dir, exist_ok=True)
    entries = []
    model = None
    indices = (args.idx_sel if args.idx_sel is not None else
               dataset.df.index)[:args.n_images if args.n_images > 0 else None]
    n_indices = len(indices)
    for count, idx in enumerate(indices, 1):
        # `!s:>3` matches `:3d` for ints but also accepts string index labels.
        print(f'Processing: {idx!s:>3} ({count}/{n_indices})')
        entry = {}
        entry[index_name] = idx
        data = dataset.loc[idx]
        signal = data[0]
        # Items of length 1 carry no target.
        target = data[1] if len(data) > 1 else None
        if not args.no_signal:
            entry['path_signal'] = save_tif(f'{idx}_signal.tif',
                                            signal.numpy()[0, ],
                                            args.path_save_dir)
        if not args.no_target and target is not None:
            entry['path_target'] = save_tif(f'{idx}_target.tif',
                                            target.numpy()[0, ],
                                            args.path_save_dir)
        for path_model_dir in args.path_model_dir:
            # With a single model it is loaded once and reused for every
            # item; with several models each one is reloaded per item.
            if model is None or len(args.path_model_dir) > 1:
                model_def = parse_model(path_model_dir)
                model = load_model(model_def['path'], no_optim=True)
                model.to_gpu(args.gpu_ids)
                print('Predicting with:', model_def['name'])
            prediction = model.predict_piecewise(
                signal,
                tta=('no_tta' not in model_def['options']),
            )
            if args.add_sigmoid and prediction is not None:
                # torch.nn.functional.sigmoid is deprecated; torch.sigmoid is
                # the supported equivalent.
                prediction = torch.sigmoid(prediction)
            evaluation = metric(target, prediction)
            entry[args.metric + f'.{model_def["name"]}'] = evaluation
            if not args.no_prediction and prediction is not None:
                # Convert once; each channel along axis 0 goes to its own TIFF.
                pred_np = prediction.numpy()
                for idx_c in range(pred_np.shape[0]):
                    tag = f'prediction_c{idx_c}.{model_def["name"]}'
                    entry[f'path_{tag}'] = save_tif(f'{idx}_{tag}.tif',
                                                    pred_np[idx_c, ],
                                                    args.path_save_dir)
        entries.append(entry)
        # Checkpoint every 8 items and always after the final one. Using the
        # counter instead of comparing idx to indices[-1] is robust to
        # repeated indices in the selection.
        if count % 8 == 0 or count == n_indices:
            save_csv(
                os.path.join(args.path_save_dir, 'predictions.csv'),
                pd.DataFrame(entries).set_index(index_name),
            )
    save_args_as_json(args.path_save_dir, args)
コード例 #7
0
def get_bpds_val(args: argparse.Namespace) -> Optional[BufferedPatchDataset]:
    """Creates data provider for validation, or None if none is configured.

    The options in ``args.bpds_kwargs`` are deep-copied, and then the buffer
    size is capped at 4 (or the dataset length if smaller) and buffer
    switching is disabled (-1).

    Raises
    ------
    ValueError
        If ``args.dataset_val`` does not resolve to a callable.
    """
    if args.dataset_val is not None:
        options = copy.deepcopy(args.bpds_kwargs)
        factory = str_to_object(args.dataset_val)
        if not isinstance(factory, Callable):
            raise ValueError('Dataset function should be Callable')
        val_ds = factory(**args.dataset_val_kwargs)
        options.update(
            buffer_size=min(4, len(val_ds)),
            buffer_switch_interval=-1,
        )
        return BufferedPatchDataset(dataset=val_ds, **options)
    return None
コード例 #8
0
ファイル: predict.py プロジェクト: vallurumk/pytorch_fnet
def main(args: Optional[argparse.Namespace] = None) -> None:
    """Predicts using model.

    Parses arguments when none are passed. If ``args.json`` is set and names
    a missing file, default predict options are written there and the
    function stops. Otherwise each model in ``args.path_model_dir`` is run
    over the items chosen by ``get_indices``. Signal, target and per-channel
    prediction TIFFs are written to ``args.path_save_dir``, and
    ``predictions.csv`` is rewritten after every item.
    """
    if args is None:
        cli = argparse.ArgumentParser()
        add_parser_arguments(cli)
        args = cli.parse_args()
    if args.json and not args.json.exists():
        save_default_predict_options(args.json)
        return
    load_from_json(args)
    metric_fn = str_to_object(args.metric)
    dataset = get_dataset(args)
    records = []
    model = None
    selected = get_indices(args, dataset)
    for count, idx in enumerate(selected, 1):
        logger.info(f'Processing: {idx:3d} ({count}/{len(selected)})')
        record = {'index': idx}
        signal, target = item_from_dataset(dataset, idx)
        if not args.no_signal:
            record['path_signal'] = save_tif(
                f'{idx}_signal.tif', signal.numpy()[0, ], args.path_save_dir)
        if not args.no_target and target is not None:
            record['path_target'] = save_tif(
                f'{idx}_target.tif', target.numpy()[0, ], args.path_save_dir)
        for path_model_dir in args.path_model_dir:
            # A lone model is loaded once; multiple models reload per item.
            needs_load = model is None or len(args.path_model_dir) > 1
            if needs_load:
                model_def = parse_model(path_model_dir)
                model = load_model(model_def['path'], no_optim=True)
                model.to_gpu(args.gpu_ids)
                logger.info(f'Loaded model: {model_def["name"]}')
            use_tta = 'no_tta' not in model_def['options']
            prediction = model.predict_piecewise(signal, tta=use_tta)
            model_name = model_def['name']
            record[args.metric + f'.{model_name}'] = metric_fn(target, prediction)
            if args.no_prediction:
                continue
            for idx_c in range(prediction.size()[0]):
                tag = f'prediction_c{idx_c}.{model_name}'
                record[f'path_{tag}'] = save_tif(
                    f'{idx}_{tag}.tif',
                    prediction.numpy()[idx_c, ],
                    args.path_save_dir,
                )
        records.append(record)
        save_predictions_csv(
            path_csv=os.path.join(args.path_save_dir, 'predictions.csv'),
            pred_records=records,
            dataset=dataset,
        )
    save_args_as_json(args.path_save_dir, args)
コード例 #9
0
def get_dataloader_train(
        args: argparse.Namespace,
        n_iter_remaining: int,
) -> torch.utils.data.DataLoader:
    """Creates DataLoader for training.

    The patch dataset is sized to yield exactly
    ``n_iter_remaining * args.batch_size`` patches, so the loader covers the
    remaining iterations at ``args.batch_size`` patches per batch.

    Raises
    ------
    ValueError
        If ``args.dataset_train`` does not resolve to a callable.
    """
    patch_opts = {
        **copy.deepcopy(args.bpds_kwargs),
        'npatches': n_iter_remaining * args.batch_size,
    }
    factory = str_to_object(args.dataset_train)
    if not isinstance(factory, Callable):
        raise ValueError('Dataset function should be Callable')
    patch_ds = fnet.data.BufferedPatchDataset(
        dataset=factory(**args.dataset_train_kwargs),
        **patch_opts,
    )
    return torch.utils.data.DataLoader(patch_ds, batch_size=args.batch_size)
コード例 #10
0
def get_dataloader_val(
        args: argparse.Namespace) -> Optional[torch.utils.data.DataLoader]:
    """Creates DataLoader for validation, or None if none is configured.

    The buffer holds at most 4 items (fewer for smaller datasets), buffer
    switching is disabled (-1), and 16 batches' worth of patches are drawn.

    Raises
    ------
    ValueError
        If ``args.dataset_val`` does not resolve to a callable.
    """
    if args.dataset_val is None:
        return None
    patch_opts = copy.deepcopy(args.bpds_kwargs)
    factory = str_to_object(args.dataset_val)
    if not isinstance(factory, Callable):
        raise ValueError('Dataset function should be Callable')
    val_ds = factory(**args.dataset_val_kwargs)
    patch_opts.update(
        buffer_size=min(4, len(val_ds)),
        buffer_switch_frequency=-1,
        npatches=16 * args.batch_size,
    )
    patch_ds = fnet.data.BufferedPatchDataset(dataset=val_ds, **patch_opts)
    return torch.utils.data.DataLoader(patch_ds, batch_size=args.batch_size)
コード例 #11
0
def test_str_to_object():
    """str_to_object must return the very object each name string refers to."""
    cases = {'_dummy': _dummy, 'random.randrange': random.randrange}
    for as_str, expected in cases.items():
        obj = str_to_object(as_str)
        assert obj is expected, f'{obj} is not {expected}'
コード例 #12
0
ファイル: test_utils.py プロジェクト: wayne980/pytorch_fnet
def test_str_to_object():
    """Test string-to-object conversion for a local and a module-level name."""
    names = ['_dummy', 'random.randrange']
    targets = [_dummy, random.randrange]
    for name, target in zip(names, targets):
        resolved = str_to_object(name)
        assert resolved is target, f'{resolved} is not {target}'