def network_function():
    """Build the fusion network described by the module-level ``_config``.

    Returns a ``NetworkFusion`` with 3 input channels and ``N_CLASSES``
    outputs, using the conv/search backends named in the config.
    """
    conv_backend = get_conv(_config["network"]["backend_conv"])
    search_backend = get_search(_config["network"]["backend_search"])
    return NetworkFusion(3, N_CLASSES, conv_backend, search_backend, config=_config)
def network_function():
    """Build the plain segmentation network from module-level settings.

    Uses ``input_channels`` / ``N_LABELS`` and the conv/search backends
    named in ``_config``.

    NOTE(review): this re-binds the same name as the ``NetworkFusion``
    factory above — presumably the two live in different config branches
    of the original file; confirm against the un-collapsed source.
    """
    conv_backend = get_conv(_config["network"]["backend_conv"])
    search_backend = get_search(_config["network"]["backend_search"])
    return Network(input_channels, N_LABELS, conv_backend, search_backend)
def training_thread(acont: ArgsContainer) -> None:
    """Run one complete training session described by ``acont``.

    Seeds all RNGs, builds the model selected by the container
    (``ConvAdaptSeg`` / ``SegBig`` / ``SegAdapt``), assembles dataset,
    optimizer, scheduler and loss, then trains via ``Trainer3d`` and
    archives the script plus the argument container next to the
    checkpoints.

    Args:
        acont: Container bundling every hyperparameter and path for this run.

    Raises:
        ValueError: If ``acont.optimizer`` or ``acont.scheduler`` names an
            unknown option.
    """
    torch.cuda.empty_cache()
    # Fixed optimization hyperparameters (deliberately not read from acont).
    lr = 1e-3
    lr_stepsize = 10000
    lr_dec = 0.995
    # max_step_size appears to count samples; convert to number of batches.
    max_steps = int(acont.max_step_size / acont.batch_size)

    # Seed every RNG source so runs with the same container are reproducible.
    torch.manual_seed(acont.random_seed)
    np.random.seed(acont.random_seed)
    random.seed(acont.random_seed)

    if acont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # lcp_flag tells Trainer3d to use the LightConvPoint-style data layout.
    lcp_flag = False
    # load model
    if acont.architecture == 'lcp' or acont.model == 'ConvAdaptSeg':
        kwargs = {}
        if acont.model == 'ConvAdaptSeg':
            kwargs = dict(kernel_num=acont.pl, architecture=acont.architecture,
                          activation=acont.act, norm=acont.norm_type)
        conv = dict(layer=acont.conv[0], kernel_separation=acont.conv[1])
        model = ConvAdaptSeg(acont.input_channels, acont.class_num, get_conv(conv),
                             get_search(acont.search), **kwargs)
        lcp_flag = True
    elif acont.use_big:
        model = SegBig(acont.input_channels, acont.class_num, trs=acont.track_running_stats,
                       dropout=acont.dropout, use_bias=acont.use_bias, norm_type=acont.norm_type,
                       use_norm=acont.use_norm, kernel_size=acont.kernel_size,
                       neighbor_nums=acont.neighbor_nums, reductions=acont.reductions,
                       first_layer=acont.first_layer, padding=acont.padding,
                       nn_center=acont.nn_center, centroids=acont.centroids, pl=acont.pl,
                       normalize=acont.cp_norm)
    else:
        model = SegAdapt(acont.input_channels, acont.class_num, architecture=acont.architecture,
                         trs=acont.track_running_stats, dropout=acont.dropout,
                         use_bias=acont.use_bias, norm_type=acont.norm_type,
                         kernel_size=acont.kernel_size, padding=acont.padding,
                         nn_center=acont.nn_center, centroids=acont.centroids,
                         kernel_num=acont.pl, normalize=acont.cp_norm, act=acont.act)

    batch_size = acont.batch_size
    train_transforms = clouds.Compose(acont.train_transforms)
    train_ds = TorchHandler(data_path=acont.train_path, sample_num=acont.sample_num,
                            nclasses=acont.class_num, feat_dim=acont.input_channels,
                            density_mode=acont.density_mode, ctx_size=acont.chunk_size,
                            bio_density=acont.bio_density, tech_density=acont.tech_density,
                            transform=train_transforms, obj_feats=acont.features,
                            label_mappings=acont.label_mappings, hybrid_mode=acont.hybrid_mode,
                            splitting_redundancy=acont.splitting_redundancy,
                            label_remove=acont.label_remove, sampling=acont.sampling,
                            padding=acont.padding, split_on_demand=acont.split_on_demand,
                            split_jitter=acont.split_jitter, epoch_size=acont.epoch_size,
                            workers=acont.workers, voxel_sizes=acont.voxel_sizes,
                            ssd_exclude=acont.ssd_exclude, ssd_include=acont.ssd_include,
                            ssd_labels=acont.ssd_labels, exclude_borders=acont.exclude_borders,
                            rebalance=acont.rebalance, extend_no_pred=acont.extend_no_pred)

    if acont.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    elif acont.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=0.5e-5)
    else:
        raise ValueError('Unknown optimizer')

    if acont.scheduler == 'steplr':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, lr_stepsize, lr_dec)
    elif acont.scheduler == 'cosannwarm':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5000, T_mult=2)
    else:
        raise ValueError('Unknown scheduler')

    # calculate class weights if necessary
    weights = None
    if acont.class_weights is not None:
        weights = torch.from_numpy(acont.class_weights).float()
    criterion = torch.nn.CrossEntropyLoss(weight=weights)
    if acont.use_cuda:
        criterion.cuda()

    if acont.use_val:
        val_path = acont.val_path
    else:
        val_path = None

    trainer = Trainer3d(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        train_dataset=train_ds,
        v_path=val_path,
        val_freq=acont.val_freq,
        val_red=acont.val_iter,
        channel_num=acont.input_channels,
        batchsize=batch_size,
        num_workers=4,
        save_root=acont.save_root,
        exp_name=acont.name,
        num_classes=acont.class_num,
        schedulers={"lr": scheduler},
        target_names=acont.target_names,
        stop_epoch=acont.stop_epoch,
        enable_tensorboard=False,
        lcp_flag=lcp_flag,
    )
    # Archiving training script, src folder, env info
    Backup(script_path=__file__, save_path=trainer.save_path).archive_backup()
    # Persist the argument container (pickle + human-readable) beside checkpoints.
    acont.save2pkl(trainer.save_path + '/argscont.pkl')
    with open(trainer.save_path + '/argscont.txt', 'w') as f:
        f.write(str(acont.attr_dict))
        f.close()  # redundant inside `with`; kept byte-identical to the original
    trainer.run(max_steps)
def generate_predictions_with_model(argscont: ArgsContainer, model_path: str, cell_path: str, out_path: str,
                                    prediction_redundancy: int = 1, batch_size: int = -1,
                                    chunk_redundancy: int = -1, force_split: bool = False,
                                    training_seed: bool = False,
                                    label_mappings: List[Tuple[int, int]] = None,
                                    label_remove: List[int] = None, border_exclusion: int = 0,
                                    state_dict: str = None, model=None, **args):
    """ Can be used to generate predictions for multiple files using a specific model (either passed
        as path to state_dict or as pre-loaded model).

    Args:
        argscont: argument container for current model.
        model_path: path to model state dict.
        cell_path: path to cells used for prediction.
        out_path: path to folder where predictions of this model should get saved.
        prediction_redundancy: number of times each cell should be processed (using the same chunks
            but different points due to random sampling).
        batch_size: batch size, if -1 this defaults to the batch size used during training.
        chunk_redundancy: number of times each cell should get splitted into a complete chunk set
            (including different chunks each time).
        force_split: split cells even if cached split information exists.
        training_seed: use random seed from training.
        label_mappings: List of tuples like (from, to) where 'from' is label which should get mapped
            to 'to'. Defaults to label_mappings from training or to val_label_mappings of
            ArgsContainer.
        label_remove: List of labels to remove from the cells. Defaults to label_remove from
            training or to val_label_remove of ArgsContainer.
        border_exclusion: nm distance which defines how much of the chunk borders should be excluded
            from predictions.
        state_dict: state dict holding model for prediction.
        model: loaded model to use for prediction.
    """
    # ``**args`` is accepted but never read — presumably kept so callers can
    # pass a shared kwargs dict; confirm before removing.
    if os.path.exists(out_path):
        print(f"{out_path} already exists. Skipping...")
        return

    # Optionally replay the training seeds so sampling matches the training run.
    if training_seed:
        torch.manual_seed(argscont.random_seed)
        np.random.seed(argscont.random_seed)
        random.seed(argscont.random_seed)

    if argscont.use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # lcp_flag marks the LightConvPoint-style code path for predict_cell.
    lcp_flag = False
    if model is not None:
        model = model  # no-op self-assignment, kept byte-identical to the original
        if isinstance(model, ConvAdaptSeg):
            lcp_flag = True
    else:
        # load model
        if argscont.architecture == 'lcp' or argscont.model == 'ConvAdaptSeg':
            kwargs = {}
            if argscont.model == 'ConvAdaptSeg':
                kwargs = dict(kernel_num=argscont.pl, architecture=argscont.architecture,
                              activation=argscont.act, norm=argscont.norm_type)
            conv = dict(layer=argscont.conv[0], kernel_separation=argscont.conv[1])
            model = ConvAdaptSeg(argscont.input_channels, argscont.class_num, get_conv(conv),
                                 get_search(argscont.search), **kwargs)
            lcp_flag = True
        elif argscont.use_big:
            model = SegBig(argscont.input_channels, argscont.class_num,
                           trs=argscont.track_running_stats, dropout=argscont.dropout,
                           use_bias=argscont.use_bias, norm_type=argscont.norm_type,
                           use_norm=argscont.use_norm, kernel_size=argscont.kernel_size,
                           neighbor_nums=argscont.neighbor_nums, reductions=argscont.reductions,
                           first_layer=argscont.first_layer, padding=argscont.padding,
                           nn_center=argscont.nn_center, centroids=argscont.centroids,
                           pl=argscont.pl, normalize=argscont.cp_norm)
        else:
            model = SegAdapt(argscont.input_channels, argscont.class_num,
                             architecture=argscont.architecture, trs=argscont.track_running_stats,
                             dropout=argscont.dropout, use_bias=argscont.use_bias,
                             norm_type=argscont.norm_type, kernel_size=argscont.kernel_size,
                             padding=argscont.padding, nn_center=argscont.nn_center,
                             centroids=argscont.centroids, kernel_num=argscont.pl,
                             normalize=argscont.cp_norm, act=argscont.act)
        # Plain state dicts load directly; full checkpoints raise RuntimeError and
        # are retried via their 'model_state_dict' entry.
        # NOTE(review): if torch.load itself raises RuntimeError (e.g. a CUDA
        # checkpoint on a CPU-only host), `full` is unbound and the except branch
        # raises NameError — confirm intended inputs.
        try:
            full = torch.load(model_path + state_dict)
            model.load_state_dict(full)
        except RuntimeError:
            model.load_state_dict(full['model_state_dict'])
    model.to(device)
    model.eval()

    transforms = clouds.Compose(argscont.val_transforms)
    # Resolve -1 sentinels to the values used during training.
    if chunk_redundancy == -1:
        chunk_redundancy = argscont.splitting_redundancy
    if batch_size == -1:
        batch_size = argscont.batch_size
    # Fall back to validation-time, then training-time label handling.
    if label_remove is None:
        if argscont.val_label_remove is not None:
            label_remove = argscont.val_label_remove
        else:
            label_remove = argscont.label_remove
    if label_mappings is None:
        if argscont.val_label_mappings is not None:
            label_mappings = argscont.val_label_mappings
        else:
            label_mappings = argscont.label_mappings

    torch_handler = TorchHandler(cell_path, argscont.sample_num, argscont.class_num,
                                 density_mode=argscont.density_mode, bio_density=argscont.bio_density,
                                 tech_density=argscont.tech_density, transform=transforms,
                                 specific=True, obj_feats=argscont.features,
                                 ctx_size=argscont.chunk_size, label_mappings=label_mappings,
                                 hybrid_mode=argscont.hybrid_mode, feat_dim=argscont.input_channels,
                                 splitting_redundancy=chunk_redundancy, label_remove=label_remove,
                                 sampling=argscont.sampling, force_split=force_split,
                                 padding=argscont.padding, exclude_borders=border_exclusion)
    prediction_mapper = PredictionMapper(cell_path, out_path, torch_handler.splitfile,
                                         label_remove=label_remove, hybrid_mode=argscont.hybrid_mode)

    obj = None
    # Copy so already-processed/empty objects can be dropped while iterating.
    obj_names = torch_handler.obj_names.copy()
    for obj in torch_handler.obj_names:
        if os.path.exists(out_path + obj + '_preds.pkl'):
            print(obj + " has already been processed. Skipping...")
            obj_names.remove(obj)
            continue
        if torch_handler.get_obj_length(obj) == 0:
            print(obj + " has no chunks to process. Skipping...")
            obj_names.remove(obj)
            continue
        print(f"Processing {obj}")
        predict_cell(torch_handler, obj, batch_size, argscont.sample_num, prediction_redundancy,
                     device, model, prediction_mapper, argscont.input_channels,
                     point_subsampling=argscont.sampling, lcp_flag=lcp_flag)
    # obj stays None only when there were no objects at all; then nothing is saved.
    if obj is not None:
        prediction_mapper.save_prediction()
    else:
        return
    # Save the argument container alongside the predictions for provenance.
    argscont.save2pkl(out_path + 'argscont.pkl')

    del model
    torch.cuda.empty_cache()