Example #1
0
 def network_function():
     """Build the fusion network specified by the global ``_config``."""
     conv_backend = get_conv(_config["network"]["backend_conv"])
     search_backend = get_search(_config["network"]["backend_search"])
     # 3 input channels; N_CLASSES output classes — both fixed for this setup.
     return NetworkFusion(3, N_CLASSES, conv_backend, search_backend,
                          config=_config)
Example #2
0
 def network_function():
     """Instantiate the segmentation network from the global ``_config``."""
     conv_backend = get_conv(_config["network"]["backend_conv"])
     search_backend = get_search(_config["network"]["backend_search"])
     return Network(input_channels, N_LABELS, conv_backend, search_backend)
Example #3
0
def training_thread(acont: ArgsContainer):
    """Run one full training session described by ``acont``.

    Seeds all RNGs, builds the model (ConvAdaptSeg / SegBig / SegAdapt,
    depending on the container's architecture flags), the training dataset,
    optimizer, scheduler and loss, archives the run configuration next to
    the trainer's save path, and then starts the training loop.

    Args:
        acont: Container bundling all training hyperparameters and paths.

    Raises:
        ValueError: If ``acont.optimizer`` or ``acont.scheduler`` names an
            unknown option.
    """
    torch.cuda.empty_cache()
    # Fixed learning-rate schedule parameters for this training setup.
    lr = 1e-3
    lr_stepsize = 10000
    lr_dec = 0.995
    max_steps = int(acont.max_step_size / acont.batch_size)

    # Seed every RNG source for reproducibility.
    torch.manual_seed(acont.random_seed)
    np.random.seed(acont.random_seed)
    random.seed(acont.random_seed)

    if acont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # lcp_flag tells the trainer that the LightConvPoint-style model is used.
    lcp_flag = False
    # load model
    if acont.architecture == 'lcp' or acont.model == 'ConvAdaptSeg':
        kwargs = {}
        if acont.model == 'ConvAdaptSeg':
            kwargs = dict(kernel_num=acont.pl, architecture=acont.architecture, activation=acont.act, norm=acont.norm_type)
        conv = dict(layer=acont.conv[0], kernel_separation=acont.conv[1])
        model = ConvAdaptSeg(acont.input_channels, acont.class_num, get_conv(conv), get_search(acont.search), **kwargs)
        lcp_flag = True
    elif acont.use_big:
        model = SegBig(acont.input_channels, acont.class_num, trs=acont.track_running_stats, dropout=acont.dropout,
                       use_bias=acont.use_bias, norm_type=acont.norm_type, use_norm=acont.use_norm,
                       kernel_size=acont.kernel_size, neighbor_nums=acont.neighbor_nums, reductions=acont.reductions,
                       first_layer=acont.first_layer, padding=acont.padding, nn_center=acont.nn_center,
                       centroids=acont.centroids, pl=acont.pl, normalize=acont.cp_norm)
    else:
        model = SegAdapt(acont.input_channels, acont.class_num, architecture=acont.architecture,
                         trs=acont.track_running_stats, dropout=acont.dropout, use_bias=acont.use_bias,
                         norm_type=acont.norm_type, kernel_size=acont.kernel_size, padding=acont.padding,
                         nn_center=acont.nn_center, centroids=acont.centroids, kernel_num=acont.pl,
                         normalize=acont.cp_norm, act=acont.act)

    batch_size = acont.batch_size

    train_transforms = clouds.Compose(acont.train_transforms)
    train_ds = TorchHandler(data_path=acont.train_path, sample_num=acont.sample_num, nclasses=acont.class_num,
                            feat_dim=acont.input_channels, density_mode=acont.density_mode,
                            ctx_size=acont.chunk_size, bio_density=acont.bio_density,
                            tech_density=acont.tech_density, transform=train_transforms,
                            obj_feats=acont.features, label_mappings=acont.label_mappings,
                            hybrid_mode=acont.hybrid_mode, splitting_redundancy=acont.splitting_redundancy,
                            label_remove=acont.label_remove, sampling=acont.sampling, padding=acont.padding,
                            split_on_demand=acont.split_on_demand, split_jitter=acont.split_jitter,
                            epoch_size=acont.epoch_size, workers=acont.workers, voxel_sizes=acont.voxel_sizes,
                            ssd_exclude=acont.ssd_exclude, ssd_include=acont.ssd_include,
                            ssd_labels=acont.ssd_labels, exclude_borders=acont.exclude_borders,
                            rebalance=acont.rebalance, extend_no_pred=acont.extend_no_pred)

    if acont.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif acont.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.5e-5)
    else:
        raise ValueError('Unknown optimizer')

    if acont.scheduler == 'steplr':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
    elif acont.scheduler == 'cosannwarm':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5000, T_mult=2)
    else:
        raise ValueError('Unknown scheduler')

    # calculate class weights if necessary
    weights = None
    if acont.class_weights is not None:
        weights = torch.from_numpy(acont.class_weights).float()

    criterion = torch.nn.CrossEntropyLoss(weight=weights)
    if acont.use_cuda:
        criterion.cuda()

    if acont.use_val:
        val_path = acont.val_path
    else:
        val_path = None

    trainer = Trainer3d(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        train_dataset=train_ds,
        v_path=val_path,
        val_freq=acont.val_freq,
        val_red=acont.val_iter,
        channel_num=acont.input_channels,
        batchsize=batch_size,
        num_workers=4,
        save_root=acont.save_root,
        exp_name=acont.name,
        num_classes=acont.class_num,
        schedulers={"lr": scheduler},
        target_names=acont.target_names,
        stop_epoch=acont.stop_epoch,
        enable_tensorboard=False,
        lcp_flag=lcp_flag,
    )
    # Archiving training script, src folder, env info
    Backup(script_path=__file__, save_path=trainer.save_path).archive_backup()
    acont.save2pkl(trainer.save_path + '/argscont.pkl')
    # The context manager closes the file; the original's trailing
    # f.close() was redundant and has been removed.
    with open(trainer.save_path + '/argscont.txt', 'w') as f:
        f.write(str(acont.attr_dict))

    trainer.run(max_steps)
Example #4
0
def generate_predictions_with_model(argscont: ArgsContainer,
                                    model_path: str,
                                    cell_path: str,
                                    out_path: str,
                                    prediction_redundancy: int = 1,
                                    batch_size: int = -1,
                                    chunk_redundancy: int = -1,
                                    force_split: bool = False,
                                    training_seed: bool = False,
                                    label_mappings: List[Tuple[int,
                                                               int]] = None,
                                    label_remove: List[int] = None,
                                    border_exclusion: int = 0,
                                    state_dict: str = None,
                                    model=None,
                                    **args):
    """
    Can be used to generate predictions for multiple files using a specific model (either passed as path to state_dict or as pre-loaded model).

    Args:
        argscont: argument container for current model.
        model_path: path to model state dict.
        cell_path: path to cells used for prediction.
        out_path: path to folder where predictions of this model should get saved.
        prediction_redundancy: number of times each cell should be processed (using the same chunks but different points due to random sampling).
        batch_size: batch size, if -1 this defaults to the batch size used during training.
        chunk_redundancy: number of times each cell should get splitted into a complete chunk set (including different chunks each time).
        force_split: split cells even if cached split information exists.
        training_seed: use random seed from training.
        label_mappings: List of tuples like (from, to) where 'from' is label which should get mapped to 'to'.
            Defaults to label_mappings from training or to val_label_mappings of ArgsContainer.
        label_remove: List of labels to remove from the cells. Defaults to label_remove from training or to val_label_remove of ArgsContainer.
        border_exclusion: nm distance which defines how much of the chunk borders should be excluded from predictions.
        state_dict: state dict holding model for prediction. Required (non-None)
            when ``model`` is not passed, since it is appended to ``model_path``.
        model: loaded model to use for prediction.
    """
    if os.path.exists(out_path):
        print(f"{out_path} already exists. Skipping...")
        return

    if training_seed:
        # Reproduce the RNG state used during training.
        torch.manual_seed(argscont.random_seed)
        np.random.seed(argscont.random_seed)
        random.seed(argscont.random_seed)

    if argscont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # lcp_flag marks the LightConvPoint-style model for predict_cell.
    lcp_flag = False
    if model is not None:
        # A pre-loaded model was passed in directly; nothing to load.
        if isinstance(model, ConvAdaptSeg):
            lcp_flag = True
    else:
        # load model
        if argscont.architecture == 'lcp' or argscont.model == 'ConvAdaptSeg':
            kwargs = {}
            if argscont.model == 'ConvAdaptSeg':
                kwargs = dict(kernel_num=argscont.pl,
                              architecture=argscont.architecture,
                              activation=argscont.act,
                              norm=argscont.norm_type)
            conv = dict(layer=argscont.conv[0],
                        kernel_separation=argscont.conv[1])
            model = ConvAdaptSeg(argscont.input_channels, argscont.class_num,
                                 get_conv(conv), get_search(argscont.search),
                                 **kwargs)
            lcp_flag = True
        elif argscont.use_big:
            model = SegBig(argscont.input_channels,
                           argscont.class_num,
                           trs=argscont.track_running_stats,
                           dropout=argscont.dropout,
                           use_bias=argscont.use_bias,
                           norm_type=argscont.norm_type,
                           use_norm=argscont.use_norm,
                           kernel_size=argscont.kernel_size,
                           neighbor_nums=argscont.neighbor_nums,
                           reductions=argscont.reductions,
                           first_layer=argscont.first_layer,
                           padding=argscont.padding,
                           nn_center=argscont.nn_center,
                           centroids=argscont.centroids,
                           pl=argscont.pl,
                           normalize=argscont.cp_norm)
        else:
            model = SegAdapt(argscont.input_channels,
                             argscont.class_num,
                             architecture=argscont.architecture,
                             trs=argscont.track_running_stats,
                             dropout=argscont.dropout,
                             use_bias=argscont.use_bias,
                             norm_type=argscont.norm_type,
                             kernel_size=argscont.kernel_size,
                             padding=argscont.padding,
                             nn_center=argscont.nn_center,
                             centroids=argscont.centroids,
                             kernel_num=argscont.pl,
                             normalize=argscont.cp_norm,
                             act=argscont.act)
        # Load the checkpoint OUTSIDE the try block: if torch.load itself
        # raised, the original code's except branch referenced the unbound
        # name 'full' and crashed with a NameError that masked the real error.
        full = torch.load(model_path + state_dict)
        try:
            model.load_state_dict(full)
        except RuntimeError:
            # Checkpoint is a full training snapshot; the weights live
            # under the 'model_state_dict' key.
            model.load_state_dict(full['model_state_dict'])
        model.to(device)
        model.eval()

    transforms = clouds.Compose(argscont.val_transforms)
    # Fall back to the training-time settings where callers passed defaults.
    if chunk_redundancy == -1:
        chunk_redundancy = argscont.splitting_redundancy
    if batch_size == -1:
        batch_size = argscont.batch_size
    if label_remove is None:
        if argscont.val_label_remove is not None:
            label_remove = argscont.val_label_remove
        else:
            label_remove = argscont.label_remove
    if label_mappings is None:
        if argscont.val_label_mappings is not None:
            label_mappings = argscont.val_label_mappings
        else:
            label_mappings = argscont.label_mappings

    torch_handler = TorchHandler(cell_path,
                                 argscont.sample_num,
                                 argscont.class_num,
                                 density_mode=argscont.density_mode,
                                 bio_density=argscont.bio_density,
                                 tech_density=argscont.tech_density,
                                 transform=transforms,
                                 specific=True,
                                 obj_feats=argscont.features,
                                 ctx_size=argscont.chunk_size,
                                 label_mappings=label_mappings,
                                 hybrid_mode=argscont.hybrid_mode,
                                 feat_dim=argscont.input_channels,
                                 splitting_redundancy=chunk_redundancy,
                                 label_remove=label_remove,
                                 sampling=argscont.sampling,
                                 force_split=force_split,
                                 padding=argscont.padding,
                                 exclude_borders=border_exclusion)
    prediction_mapper = PredictionMapper(cell_path,
                                         out_path,
                                         torch_handler.splitfile,
                                         label_remove=label_remove,
                                         hybrid_mode=argscont.hybrid_mode)

    obj = None
    obj_names = torch_handler.obj_names.copy()
    for obj in torch_handler.obj_names:
        if os.path.exists(out_path + obj + '_preds.pkl'):
            print(obj + " has already been processed. Skipping...")
            obj_names.remove(obj)
            continue
        if torch_handler.get_obj_length(obj) == 0:
            print(obj + " has no chunks to process. Skipping...")
            obj_names.remove(obj)
            continue
        print(f"Processing {obj}")
        predict_cell(torch_handler,
                     obj,
                     batch_size,
                     argscont.sample_num,
                     prediction_redundancy,
                     device,
                     model,
                     prediction_mapper,
                     argscont.input_channels,
                     point_subsampling=argscont.sampling,
                     lcp_flag=lcp_flag)
    # Only persist results if at least one object was iterated.
    if obj is not None:
        prediction_mapper.save_prediction()
    else:
        return
    argscont.save2pkl(out_path + 'argscont.pkl')
    del model
    torch.cuda.empty_cache()