def _initialize_model(self, model: BaseModel, weight_name):
    """Restore ``model`` from the checkpoint held by this object.

    Nothing happens when the checkpoint reports itself as empty. Otherwise
    the state dict stored under ``weight_name`` is loaded into ``model`` and,
    when ``self._resume`` is set, the optimizer and schedulers are taken from
    the checkpoint as well.

    Args:
        model: Model that is updated in place.
        weight_name: Key forwarded to ``self._checkpoint.get_state_dict``.
    """
    ckpt = self._checkpoint
    if ckpt.is_empty:
        return

    model.load_state_dict(ckpt.get_state_dict(weight_name))

    if not self._resume:
        return
    model.optimizer = ckpt.get_optimizer(model)
    model.schedulers = ckpt.get_schedulers(model)
def test_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run the model over every test loader and report tracker metrics.

    Each loader is treated as its own stage (named after
    ``loader.dataset.name``): the tracker is reset, the loader is iterated
    ``voting_runs`` times, then the tracker is finalised and its summary
    printed.

    Args:
        model: Model to run; ``set_input`` and ``forward`` are called under
            ``torch.no_grad()``.
        dataset: Object exposing ``test_dataloaders``.
        device: Device forwarded to ``model.set_input``.
        tracker: Tracker receiving the model after every batch.
        checkpoint: Unused here; kept for interface compatibility.
        voting_runs: Number of passes over each loader.
        tracker_options: Extra keyword arguments forwarded to
            ``tracker.track`` and ``tracker.finalise``. ``None`` (the default)
            means no extra options.
    """
    # A ``{}`` default would be shared between calls; build a fresh one instead.
    if tracker_options is None:
        tracker_options = {}

    for loader in dataset.test_dataloaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        for _ in range(voting_runs):
            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                    tracker.track(model, **tracker_options)
                    tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

        tracker.finalise(**tracker_options)
        tracker.print_summary()
def track(self, model: BaseModel, **kwargs):
    """Add the current model predictions (usually one batch) to the tracking.

    Entries whose label equals ``self._ignore_label`` are dropped. The
    remaining outputs feed the average-precision meter and the confusion
    matrix, from which accuracy, mean class accuracy, mean IoU and mAP are
    refreshed (all scaled by 100).

    When no label survives the masking, the AP meter has already been updated
    but the confusion matrix and the cached metrics are left untouched.
    """
    super().track(model)

    predictions = model.get_output()
    labels = model.get_labels()

    # Drop every entry carrying the ignored label.
    keep = labels != self._ignore_label
    predictions, labels = predictions[keep], labels[keep]

    predictions = SegmentationTracker.detach_tensor(predictions)
    labels = SegmentationTracker.detach_tensor(labels)
    if not torch.is_tensor(labels):
        labels = torch.from_numpy(labels)
    one_hot_labels = F.one_hot(labels, self._num_classes).bool()
    self._ap_meter.add(predictions, one_hot_labels)

    predictions = self._convert(predictions)
    labels = self._convert(labels)
    if len(labels) == 0:
        return

    assert predictions.shape[0] == len(labels)
    confusion = self._confusion_matrix
    confusion.count_predicted_batch(labels, np.argmax(predictions, 1))

    self._acc = 100 * confusion.get_overall_accuracy()
    self._macc = 100 * confusion.get_mean_class_accuracy()
    self._miou = 100 * confusion.get_average_intersection_union()
    self._map = 100 * self._ap_meter.value().mean().item()
def track(self, model: BaseModel, **kwargs):
    """Add the current model predictions (usually one batch) to the tracking.

    Entries whose label equals ``self._ignore_label`` are dropped, the rest is
    counted in the confusion matrix, and accuracy, mean class accuracy and
    mean IoU are refreshed (all scaled by 100). Nothing is counted when no
    label survives the masking.
    """
    super().track(model)

    predictions = model.get_output()
    labels = model.get_labels()

    # Drop every entry carrying the ignored label.
    keep = labels != self._ignore_label
    predictions = self._convert(predictions[keep])
    labels = self._convert(labels[keep])

    if len(labels) == 0:
        return

    assert predictions.shape[0] == len(labels)
    confusion = self._confusion_matrix
    confusion.count_predicted_batch(labels, np.argmax(predictions, 1))

    self._acc = 100 * confusion.get_overall_accuracy()
    self._macc = 100 * confusion.get_mean_class_accuracy()
    self._miou = 100 * confusion.get_average_intersection_union()
def run(cfg, model: BaseModel, dataset: BaseDataset, device, measurement_name: str):
    """Collect spatial-op histograms per split and pickle them to disk.

    ``run_epoch`` is executed on the train loader, on the validation loader
    (when ``dataset.has_val_loader`` is true) and on each test loader. After
    each run, ``extract_histogram(model.get_spatial_ops(), normalize=False)``
    is stored under the split name (``"train"``, ``"val"`` or the test
    loader's ``dataset.name``).

    Args:
        cfg: Configuration; ``cfg.debugging.num_batches`` caps the number of
            batches per loader (unbounded when absent).
        model: Model whose spatial ops are measured.
        dataset: Dataset providing the loaders.
        device: Device forwarded to ``run_epoch``.
        measurement_name: Base name of the pickle written to
            ``DIR/measurements/<measurement_name>.pickle``.
    """
    num_batches = getattr(cfg.debugging, "num_batches", np.inf)
    measurements = {}

    run_epoch(model, dataset.train_dataloader, device, num_batches)
    measurements["train"] = extract_histogram(model.get_spatial_ops(), normalize=False)

    if dataset.has_val_loader:
        run_epoch(model, dataset.val_dataloader, device, num_batches)
        measurements["val"] = extract_histogram(model.get_spatial_ops(), normalize=False)

    for loader in dataset.test_dataloaders:
        # Run on this loader only; passing the whole list of loaders would
        # hand ``run_epoch`` something that does not yield batches.
        run_epoch(model, loader, device, num_batches)
        measurements[loader.dataset.name] = extract_histogram(model.get_spatial_ops(), normalize=False)

    output_file = os.path.join(DIR, "measurements/{}.pickle".format(measurement_name))
    with open(output_file, "wb") as f:
        pickle.dump(measurements, f)
def track(self, model: BaseModel, full_res=False, **kwargs):
    """Add the current model predictions (usually one batch) to the tracking.

    The parent tracker is always updated. When ``full_res`` is true and the
    stage is not ``"train"``, the model outputs are also accumulated as votes
    on a copy of the dataset's test data, indexed by the original point ids
    stored in the model input under ``SaveOriginalPosId.KEY``.

    Args:
        model: Model providing ``get_input``, ``get_output`` and ``device``.
        full_res: Whether to accumulate full-resolution votes.

    Raises:
        ValueError: If the test data has no labels (``y`` is ``None``), or if
            the original ids are missing or exceed the number of points in
            the test area.
    """
    super().track(model)

    # Train mode or low res, nothing special to do
    if self._stage == "train" or not full_res:
        return

    # Test mode, compute votes in order to get full res predictions
    if self._test_area is None:
        self._test_area = self._dataset.test_data.clone()
        if self._test_area.y is None:
            raise ValueError("It seems that the test area data does not have labels (attribute y).")
        num_points = self._test_area.y.shape[0]
        self._test_area.prediction_count = torch.zeros(num_points, dtype=torch.int)
        self._test_area.votes = torch.zeros((num_points, self._num_classes), dtype=torch.float)
        self._test_area.to(model.device)

    # Gather input to the model and check that it fits with the test set
    inputs = model.get_input()
    original_ids = inputs[SaveOriginalPosId.KEY]
    if original_ids is None or original_ids.max() >= self._test_area.pos.shape[0]:
        # Adjacent literals instead of a backslash continuation inside the
        # string, which leaked source formatting into the message.
        raise ValueError(
            "The inputs given to the model do not have a %s attribute or this attribute does "
            "not correspond to the number of points in the test area point cloud." % SaveOriginalPosId.KEY
        )

    # Set predictions
    outputs = model.get_output()
    self._test_area.votes[original_ids] += outputs
    self._test_area.prediction_count[original_ids] += 1
def __init__(self, option, model_type, dataset, modules):
    """Build the model's loss configuration and its final fully connected layers.

    Args:
        option: Model options. Reads ``loss_mode``, ``normalize_feature``,
            ``mlp_cls`` and, when present, ``metric_loss`` and ``miner``.
        model_type: Unused here.
        dataset: Unused here.
        modules: Unused here.

    When ``option.mlp_cls`` is ``None`` the final layer is an identity.
    Otherwise it is a stack of ``Linear -> FastBatchNorm1d -> LeakyReLU(0.2)``
    blocks following the widths in ``option.mlp_cls.nn``, an optional
    ``Dropout`` and a closing square ``Linear`` layer.
    """
    BaseModel.__init__(self, option)
    self.mode = option.loss_mode
    self.normalize_feature = option.normalize_feature
    self.loss_names = ["loss_reg", "loss"]
    self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(
        getattr(option, "metric_loss", None), getattr(option, "miner", None)
    )

    # Last Layer
    if option.mlp_cls is not None:
        last_mlp_opt = option.mlp_cls
        in_feat = last_mlp_opt.nn[0]
        self.FC_layer = Seq()
        for i in range(1, len(last_mlp_opt.nn)):
            out_feat = last_mlp_opt.nn[i]
            # NOTE(review): this call used to pass ``str(i)`` as an extra first
            # argument, unlike every other ``append`` on this object. Assumes
            # ``Seq.append`` takes a single module; confirm against ``Seq``.
            self.FC_layer.append(
                Sequential(
                    Linear(in_feat, out_feat, bias=False),
                    FastBatchNorm1d(out_feat, momentum=last_mlp_opt.bn_momentum),
                    LeakyReLU(0.2),
                )
            )
            in_feat = out_feat

        if last_mlp_opt.dropout:
            self.FC_layer.append(Dropout(p=last_mlp_opt.dropout))

        self.FC_layer.append(Linear(in_feat, in_feat, bias=False))
    else:
        self.FC_layer = torch.nn.Identity()
def run_epoch(model: BaseModel, loader, device: str, num_batches: int):
    """Process at most ``num_batches`` batches of ``loader`` in eval mode.

    Args:
        model: Model switched to eval mode and handed to ``process``.
        loader: Iterable of batches, wrapped in a ``Ctq`` progress context.
        device: Device forwarded to ``process``.
        num_batches: Iteration stops at the first batch index that is not
            strictly below this value.
    """
    model.eval()
    with Ctq(loader) as tq_loader:
        for step, batch in enumerate(tq_loader):
            if not (step < num_batches):
                break
            process(model, batch, device)
def track(self, model: BaseModel):
    """Update registration metrics from the model's current batch.

    The parent tracker is always updated. Outside the ``"train"`` stage, each
    element of the batch is handled separately: a reference transform is
    estimated from the index pairs returned by ``model.get_ind()``, a
    predicted transform is obtained with ``fast_global_registration`` on
    matches between randomly sampled features, and the hit ratio, feature
    match ratio, translation error and rotation error meters are updated.
    """
    super().track(model)
    if self._stage == "train":
        return

    batch_idx, batch_idx_target = model.get_batch_idx()
    batch_xyz, batch_xyz_target = model.get_xyz()
    batch_ind, batch_ind_target = model.get_ind()
    batch_feat, batch_feat_target = model.get_outputs()

    nb_batches = batch_idx.max() + 1
    offset = 0
    offset_target = 0
    for b in range(nb_batches):
        in_source = batch_idx == b
        in_target = batch_idx_target == b

        xyz = batch_xyz[in_source]
        xyz_target = batch_xyz_target[in_target]
        feat = batch_feat[in_source]
        feat_target = batch_feat_target[in_target]

        # The indices were concatenated over the batch, so the running point
        # count is subtracted to index into this element alone.
        ind = batch_ind[in_source] - offset
        ind_target = batch_ind_target[in_target] - offset_target
        offset += len(xyz)
        offset_target += len(xyz_target)

        rand = torch.randperm(len(feat))[: self.num_points]
        rand_target = torch.randperm(len(feat_target))[: self.num_points]

        matches_gt = torch.stack([ind, ind_target]).T
        T_gt = estimate_transfo(xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]])

        matches_pred = get_matches(feat[rand], feat_target[rand_target])
        matched_xyz = xyz[rand][matches_pred[:, 0]]
        matched_xyz_target = xyz_target[rand_target][matches_pred[:, 1]]
        T_pred = fast_global_registration(matched_xyz, matched_xyz_target)

        hit_ratio = compute_hit_ratio(matched_xyz, matched_xyz_target, T_gt, self.tau_1)
        trans_error, rot_error = compute_transfo_error(T_pred, T_gt)

        self._hit_ratio.add(hit_ratio.item())
        self._feat_match_ratio.add(float(hit_ratio.item() > self.tau_2))
        self._trans_error.add(trans_error.item())
        self._rot_error.add(rot_error.item())
def run(model: BaseModel, dataset: BaseDataset, device, output_path, cfg):
    """Compute model features on the test data and save them.

    Two modes, selected by ``cfg.data.is_patch``:

    * patch mode: for every fragment, the patches are set on the dataset, the
      model is run on every batch, and the concatenated features are saved
      together with ``dataset.base_dataset[i]``;
    * otherwise: the test loader is iterated with a batch size of 1 and the
      features of each sample are saved with the sample itself.

    Args:
        model: Model used for the forward passes (under ``torch.no_grad()``).
        dataset: Dataset providing the loaders and ``get_name``.
        device: Device forwarded to ``model.set_input``.
        output_path: Destination forwarded to ``save``.
        cfg: Configuration; reads ``data.is_patch``, ``batch_size`` and
            ``num_workers``.
    """
    # Set dataloaders
    num_fragment = dataset.num_fragment
    if cfg.data.is_patch:
        for i in range(num_fragment):
            dataset.set_patches(i)
            dataset.create_dataloaders(
                model, cfg.batch_size, False, cfg.num_workers, False,
            )
            loader = dataset.test_dataloaders()[0]
            features = []
            scene_name, pc_name = dataset.get_name(i)

            with Ctq(loader) as tq_test_loader:
                for data in tq_test_loader:
                    with torch.no_grad():
                        model.set_input(data, device)
                        model.forward()
                        features.append(model.get_output().cpu())
            features = torch.cat(features, 0).numpy()
            log.info("save {} from {} in {}".format(pc_name, scene_name, output_path))
            save(output_path, scene_name, pc_name, dataset.base_dataset[i].to("cpu"), features)
    else:
        dataset.create_dataloaders(
            model, 1, False, cfg.num_workers, False,
        )
        loader = dataset.test_dataloaders()[0]
        with Ctq(loader) as tq_test_loader:
            for i, data in enumerate(tq_test_loader):
                # The names were never assigned in this branch, so ``save``
                # raised NameError. NOTE(review): assumes the loader index
                # (batch size 1) matches the index ``get_name`` expects.
                scene_name, pc_name = dataset.get_name(i)
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                    features = model.get_output()[0]  # batch of 1
                    save(output_path, scene_name, pc_name, data.to("cpu"), features)
def save_best_models_under_current_metrics(
    self, model: BaseModel, metrics_holder: dict, metric_func: dict, **kwargs
):
    """Save a checkpoint according to the current metrics.

    For the ``"train"`` stage the current weights are stored as the latest
    model. For any other stage each metric is compared with the best value
    recorded in the previous stats entry of that stage. When the stage is the
    selection stage and the current value equals the new best, the weights
    are stored under ``best_<metric_name>``. The first time a stage is seen,
    every metric is recorded as its own best.

    Arguments:
        model {[BaseModel]} -- [Model]
        metrics_holder {[Dict]} -- [Need to contain stage, epoch, current_metrics]
        metric_func {[Dict]} -- [Passed to ``find_func_from_metric_name`` to
            resolve, for each metric name, the function that picks the better
            of two values]
        **kwargs -- [Forwarded to ``self._checkpoint.save_objects``]
    """
    metrics = metrics_holder["current_metrics"]
    stage = metrics_holder["stage"]
    epoch = metrics_holder["epoch"]

    stats = self._checkpoint.stats
    state_dict = copy.deepcopy(model.state_dict())

    current_stat = {"epoch": epoch}
    models_to_save = self._checkpoint.models

    if stage not in stats:
        stats[stage] = []

    if stage == "train":
        models_to_save[Checkpoint._LATEST] = state_dict
    elif len(stats[stage]) > 0:
        latest_stats = stats[stage][-1]
        improvements = []

        for metric_name, current_metric_value in metrics.items():
            current_stat[metric_name] = current_metric_value
            best_key = "best_{}".format(metric_name)

            # Resolve into a separate local: rebinding ``metric_func`` here
            # replaced the lookup dict with a function, breaking the lookup
            # for every metric after the first.
            pick_best = self.find_func_from_metric_name(metric_name, metric_func)
            best_metric_from_stats = latest_stats.get(best_key, current_metric_value)
            best_value = pick_best(best_metric_from_stats, current_metric_value)
            current_stat[best_key] = best_value

            # This new value seems to be better under the comparison function
            if (self._selection_stage == stage) and (current_metric_value == best_value):
                # Update the model weights
                models_to_save[best_key] = state_dict
                improvements.append("{}: {} -> {}".format(metric_name, best_metric_from_stats, best_value))

        if improvements:
            colored_print(COLORS.VAL_COLOR, ", ".join(improvements))
    else:
        # stats[stage] is empty.
        for metric_name, metric_value in metrics.items():
            best_key = "best_{}".format(metric_name)
            current_stat[metric_name] = metric_value
            current_stat[best_key] = metric_value
            models_to_save[best_key] = state_dict

    self._checkpoint.stats[stage].append(current_stat)
    self._checkpoint.save_objects(models_to_save, stage, current_stat, model.optimizer, model.schedulers, **kwargs)
def run(model: BaseModel, dataset: BaseDataset, device, output_path):
    """Run the model on every test loader and save the merged predictions.

    For each batch the result of ``dataset.predict_original_samples`` is
    merged into a single dict. Later batches overwrite earlier entries with
    the same key. The dict is passed to ``save`` once all loaders are done.

    Args:
        model: Model used for the forward passes (under ``torch.no_grad()``).
        dataset: Dataset providing ``test_dataloaders`` and
            ``predict_original_samples``.
        device: Device forwarded to ``model.set_input``.
        output_path: Destination forwarded to ``save``.
    """
    predicted = {}
    for loader in dataset.test_dataloaders:
        with Ctq(loader) as tq_test_loader:
            for data in tq_test_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                # Update in place rather than rebuilding the whole dict on
                # every batch.
                predicted.update(
                    dataset.predict_original_samples(data, model.conv_type, model.get_output())
                )

    save(output_path, predicted)
def track(self, model: BaseModel):
    """Add model predictions (accuracy).

    The converted model output is split into two halves at ``len // 2`` and
    ``compute_accuracy`` is evaluated on the first half against the second.
    """
    super().track(model)

    converted = self._convert(model.get_output())
    half = len(converted) // 2
    first_half, second_half = converted[:half], converted[half:]
    self._acc = compute_accuracy(first_half, second_half)
def _init_from_compact_format(self, opt, model_type, dataset, modules_lib):
    """Create a unetbasedmodel from the compact options format.

    In the compact format the same convolution class is used for every layer
    and the per-layer arguments are given as lists. Down modules are built
    from ``opt.down_conv``, up modules from ``opt.up_conv`` (when present) and
    inner modules from ``opt.innermost`` (an ``Identity`` when absent).
    """
    self.down_modules = nn.ModuleList()
    self.inner_modules = nn.ModuleList()
    self.up_modules = nn.ModuleList()

    self.save_sampling_id = opt.down_conv.get('save_sampling_id')

    # Factory for creating up and down modules
    factory_module_cls = self._get_factory(model_type, modules_lib)
    down_conv_cls_name = opt.down_conv.module_name
    if opt.get('up_conv') is not None:
        up_conv_cls_name = opt.up_conv.module_name
    else:
        up_conv_cls_name = None
    # Create the factory object
    self._factory_module = factory_module_cls(down_conv_cls_name, up_conv_cls_name, modules_lib)

    # Inner (global) modules
    if hasattr(opt, "innermost") and opt.innermost is not None:
        for inner_module in self._create_inner_modules(opt.innermost, modules_lib):
            self.inner_modules.append(inner_module)
    else:
        self.inner_modules.append(Identity())

    # Down modules
    for layer_idx in range(len(opt.down_conv.down_conv_nn)):
        layer_args = self._fetch_arguments(opt.down_conv, layer_idx, "DOWN")
        layer_cls = self._get_from_kwargs(layer_args, "conv_cls")
        down_module = layer_cls(**layer_args)
        self._save_sampling_and_search(down_module)
        self.down_modules.append(down_module)

    # Up modules
    if up_conv_cls_name:
        for layer_idx in range(len(opt.up_conv.up_conv_nn)):
            layer_args = self._fetch_arguments(opt.up_conv, layer_idx, "UP")
            layer_cls = self._get_from_kwargs(layer_args, "conv_cls")
            up_module = layer_cls(**layer_args)
            self._save_upsample(up_module)
            self.up_modules.append(up_module)

    metric_loss_opt = getattr(opt, "metric_loss", None)
    miner_opt = getattr(opt, "miner", None)
    self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(metric_loss_opt, miner_opt)
def run(model: BaseModel, dataset: BaseDataset, device, cfg):
    """Evaluate registration metrics on the first test loader and save a CSV.

    The loader is rebuilt with a batch size of 1. For each pair a reference
    transform is estimated from the index pairs of the model input, and
    ``compute_metrics`` is evaluated on randomly sampled points and features.
    One row per pair is written to
    ``<checkpoint_dir>/<data.name>/matches/final_res.csv``, and the per-scene
    mean of the numeric columns is printed.

    Args:
        model: Model used for the forward passes (under ``torch.no_grad()``).
        dataset: Dataset providing the loaders and the pair names.
        device: Device forwarded to ``model.set_input``.
        cfg: Configuration; reads ``training.num_workers``,
            ``training.checkpoint_dir`` and several ``data.*`` options.
    """
    dataset.create_dataloaders(
        model, 1, False, cfg.training.num_workers, False,
    )
    loader = dataset.test_dataloaders[0]
    list_res = []
    with Ctq(loader) as tq_test_loader:
        for i, data in enumerate(tq_test_loader):
            with torch.no_grad():
                model.set_input(data, device)
                model.forward()

                name_scene, name_pair_source, name_pair_target = dataset.test_dataset[0].get_name(i)
                # Named ``source`` / ``target`` so the builtin ``input`` is not shadowed.
                source, target = model.get_input()
                xyz, xyz_target = source.pos, target.pos
                ind, ind_target = source.ind, target.ind
                matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
                feat, feat_target = model.get_output()
                rand = torch.randperm(len(feat))[: cfg.data.num_points]
                rand_target = torch.randperm(len(feat_target))[: cfg.data.num_points]
                res = dict(
                    name_scene=name_scene, name_pair_source=name_pair_source, name_pair_target=name_pair_target
                )
                T_gt = estimate_transfo(xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]])
                metric = compute_metrics(
                    xyz[rand],
                    xyz_target[rand_target],
                    feat[rand],
                    feat_target[rand_target],
                    T_gt,
                    sym=cfg.data.sym,
                    tau_1=cfg.data.tau_1,
                    tau_2=cfg.data.tau_2,
                    rot_thresh=cfg.data.rot_thresh,
                    trans_thresh=cfg.data.trans_thresh,
                    use_ransac=cfg.data.use_ransac,
                    ransac_thresh=cfg.data.first_subsampling,
                    use_teaser=cfg.data.use_teaser,
                    noise_bound_teaser=cfg.data.noise_bound_teaser,
                )
                res = dict(**res, **metric)
                list_res.append(res)

    df = pd.DataFrame(list_res)
    output_path = os.path.join(cfg.training.checkpoint_dir, cfg.data.name, "matches")
    # ``exist_ok=True`` already covers an existing directory; a separate
    # existence check only adds a race.
    os.makedirs(output_path, exist_ok=True)
    df.to_csv(os.path.join(output_path, "final_res.csv"))
    # The frame also holds string columns (pair names); averaging them raises
    # on recent pandas, so restrict the mean to numeric columns.
    print(df.groupby("name_scene").mean(numeric_only=True))
def train_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device: str,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Train the model for one epoch, then publish metrics and checkpoint.

    Per batch: the input is set, parameters are optimised, the tracker is
    updated every 10th iteration, the progress bar is refreshed with the
    tracker metrics plus data-loading and iteration timings, and visuals are
    saved when the visualizer is active.

    ``debugging.early_break`` stops after the first batch.
    ``debugging.profiling`` returns ``0`` once the batch index exceeds
    ``debugging.num_batches`` (default 50), skipping the publish and
    checkpoint step.
    """
    early_break = getattr(debugging, "early_break", False)
    profiling = getattr(debugging, "profiling", False)

    model.train()
    tracker.reset("train")
    visualizer.reset(epoch, "train")
    train_loader = dataset.train_dataloader

    iter_data_time = time.time()
    with Ctq(train_loader) as tq_train_loader:
        for step, batch in enumerate(tq_train_loader):
            model.set_input(batch, device)
            t_data = time.time() - iter_data_time

            iter_start_time = time.time()
            model.optimize_parameters(epoch, dataset.batch_size)
            if step % 10 == 0:
                tracker.track(model)

            tq_train_loader.set_postfix(
                **tracker.get_metrics(),
                data_loading=float(t_data),
                iteration=float(time.time() - iter_start_time),
                color=COLORS.TRAIN_COLOR,
            )

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            iter_data_time = time.time()

            if early_break:
                break

            if profiling and step > getattr(debugging, "num_batches", 50):
                return 0

    metrics = tracker.publish(epoch)
    checkpoint.save_best_models_under_current_metrics(model, metrics, tracker.metric_func)
    log.info("Learning rate = %f" % model.learning_rate)
def __init__(self, option, model_type, dataset, modules):
    """Build the sparse-conv U-Net, the optional MLP and the classification head.

    Args:
        option: Model options. Reads ``option_unet``, ``normalize_feature``,
            ``backend``, ``mlp_cls`` and ``output_nc``.
        model_type: Unused here.
        dataset: Provides ``num_classes`` for the output layer.
        modules: Unused here.
    """
    BaseModel.__init__(self, option)
    unet_opt = option.option_unet
    self.normalize_feature = option.normalize_feature
    self.grid_size = unet_opt.grid_size
    self.unet = UnetMSparseConv3d(
        unet_opt.backbone,
        input_nc=unet_opt.input_nc,
        pointnet_nn=unet_opt.pointnet_nn,
        post_mlp_nn=unet_opt.post_mlp_nn,
        pre_mlp_nn=unet_opt.pre_mlp_nn,
        add_pos=unet_opt.add_pos,
        add_pre_x=unet_opt.add_pre_x,
        aggr=unet_opt.aggr,
        backend=option.backend,
    )

    if option.mlp_cls is None:
        self.FC_layer = torch.nn.Identity()
    else:
        mlp_opt = option.mlp_cls
        self.FC_layer = Seq()
        for layer_idx in range(1, len(mlp_opt.nn)):
            in_nc = mlp_opt.nn[layer_idx - 1]
            out_nc = mlp_opt.nn[layer_idx]
            self.FC_layer.append(
                nn.Sequential(
                    nn.Linear(in_nc, out_nc, bias=False),
                    FastBatchNorm1d(out_nc, momentum=mlp_opt.bn_momentum),
                    nn.LeakyReLU(0.2),
                )
            )
        if mlp_opt.dropout:
            self.FC_layer.append(nn.Dropout(p=mlp_opt.dropout))

    self.head = nn.Sequential(nn.Linear(option.output_nc, dataset.num_classes))
    self.loss_names = ["loss_seg"]
def eval_epoch(
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    voting_runs=1,
    tracker_options=None,
):
    """Run the model over the validation loader and report tracker metrics.

    The tracker is reset for the ``"val"`` stage, the validation loader is
    iterated ``voting_runs`` times, then the tracker is finalised and its
    summary printed.

    Args:
        model: Model to run; ``set_input`` and ``forward`` are called under
            ``torch.no_grad()``.
        dataset: Object exposing ``val_dataloader``.
        device: Device forwarded to ``model.set_input``.
        tracker: Tracker receiving the model after every batch.
        checkpoint: Unused here; kept for interface compatibility.
        voting_runs: Number of passes over the loader.
        tracker_options: Extra keyword arguments forwarded to
            ``tracker.track`` and ``tracker.finalise``. ``None`` (the default)
            means no extra options.
    """
    # A ``{}`` default would be shared between calls; build a fresh one instead.
    if tracker_options is None:
        tracker_options = {}

    tracker.reset("val")
    loader = dataset.val_dataloader
    for _ in range(voting_runs):
        with Ctq(loader) as tq_val_loader:
            for data in tq_val_loader:
                with torch.no_grad():
                    model.set_input(data, device)
                    model.forward()
                tracker.track(model, **tracker_options)
                tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

    tracker.finalise(**tracker_options)
    tracker.print_summary()
def _init_from_compact_format(self, opt, model_type, dataset, modules_lib):
    """Create a backbonebasedmodel from the compact options format.

    In the compact format the same convolution class is used for every layer
    and the per-layer arguments are given as lists. Only down modules are
    built here; the factory is created without an up-convolution class.
    """
    num_convs = len(opt.down_conv.down_conv_nn)
    self.down_modules = nn.ModuleList()

    factory_module_cls = self._get_factory(model_type, modules_lib)
    self._factory_module = factory_module_cls(opt.down_conv.module_name, None, modules_lib)

    # Down modules
    for layer_idx in range(num_convs):
        layer_args = self._fetch_arguments(opt.down_conv, layer_idx, "DOWN")
        layer_cls = self._get_from_kwargs(layer_args, "conv_cls")
        down_module = layer_cls(**layer_args)
        self._save_sampling_and_search(down_module)
        self.down_modules.append(down_module)

    metric_loss_opt = getattr(opt, "metric_loss", None)
    miner_opt = getattr(opt, "miner", None)
    self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner(metric_loss_opt, miner_opt)
def test_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate the model on every test loader and checkpoint on the metrics.

    Each loader is treated as its own stage (named after
    ``loader.dataset.name``): tracker and visualizer are reset, every batch
    is forwarded under ``torch.no_grad()`` and tracked, then the tracker is
    finalised, its metrics published and summarised, and the checkpoint
    updated. ``debugging.early_break`` stops each loader after one batch.
    """
    stop_after_first = getattr(debugging, "early_break", False)
    model.eval()

    for loader in dataset.test_dataloaders:
        stage_name = loader.dataset.name
        tracker.reset(stage_name)
        visualizer.reset(epoch, stage_name)

        with Ctq(loader) as tq_test_loader:
            for batch in tq_test_loader:
                with torch.no_grad():
                    model.set_input(batch, device)
                    model.forward()

                tracker.track(model)
                tq_test_loader.set_postfix(**tracker.get_metrics(), color=COLORS.TEST_COLOR)

                if visualizer.is_active:
                    visualizer.save_visuals(model.get_current_visuals())

                if stop_after_first:
                    break

        tracker.finalise()
        metrics = tracker.publish(epoch)
        tracker.print_summary()
        checkpoint.save_best_models_under_current_metrics(model, metrics, tracker.metric_func)
def eval_epoch(
    epoch: int,
    model: BaseModel,
    dataset,
    device,
    tracker: BaseTracker,
    checkpoint: ModelCheckpoint,
    visualizer: Visualizer,
    debugging,
):
    """Evaluate the model on the validation loader and checkpoint on the metrics.

    Tracker and visualizer are reset for the ``"val"`` stage, every batch is
    forwarded under ``torch.no_grad()`` and tracked, then the metrics are
    published and summarised and the checkpoint updated.
    ``debugging.early_break`` stops after one batch.
    """
    stop_after_first = getattr(debugging, "early_break", False)

    model.eval()
    tracker.reset("val")
    visualizer.reset(epoch, "val")

    with Ctq(dataset.val_dataloader) as tq_val_loader:
        for batch in tq_val_loader:
            with torch.no_grad():
                model.set_input(batch, device)
                model.forward()

            tracker.track(model)
            tq_val_loader.set_postfix(**tracker.get_metrics(), color=COLORS.VAL_COLOR)

            if visualizer.is_active:
                visualizer.save_visuals(model.get_current_visuals())

            if stop_after_first:
                break

    metrics = tracker.publish(epoch)
    tracker.print_summary()
    checkpoint.save_best_models_under_current_metrics(model, metrics, tracker.metric_func)
def run(model: BaseModel, dataset: BaseDataset, device, cfg):
    """Evaluate registration metrics with timings and save a timestamped CSV.

    The first test loader is rebuilt with a batch size of 1. For each pair
    the forward pass is timed, a reference transform is estimated from the
    index pairs of the model input, and ``compute_metrics`` is evaluated on
    randomly sampled points and features. One row per pair is written to
    ``<checkpoint_dir>/<data.name>/matches/final_res_<timestamp>.csv``, and
    the per-scene mean of the numeric columns is printed.

    Args:
        model: Model used for the forward passes (under ``torch.no_grad()``).
        dataset: Dataset providing the loaders and the pair names.
        device: Device forwarded to ``model.set_input``.
        cfg: Configuration; reads ``training.num_workers``,
            ``training.checkpoint_dir`` and several ``data.*`` options.
            ``data.registration_recall_thresh`` falls back to 0.2 when
            ``None``.
    """
    reg_thresh = cfg.data.registration_recall_thresh
    if reg_thresh is None:
        reg_thresh = 0.2

    print(time.strftime("%Y%m%d-%H%M%S"))
    dataset.create_dataloaders(
        model, 1, False, cfg.training.num_workers, False,
    )
    loader = dataset.test_dataloaders[0]
    list_res = []
    with Ctq(loader) as tq_test_loader:
        for i, data in enumerate(tq_test_loader):
            with torch.no_grad():
                t0 = time.time()
                model.set_input(data, device)
                model.forward()
                t1 = time.time()

                name_scene, name_pair_source, name_pair_target = dataset.test_dataset[0].get_name(i)
                # Named ``source`` / ``target`` so the builtin ``input`` is not shadowed.
                source, target = model.get_input()
                xyz, xyz_target = source.pos, target.pos
                ind, ind_target = source.ind, target.ind
                matches_gt = torch.stack([ind, ind_target]).transpose(0, 1)
                feat, feat_target = model.get_output()
                # rand = voxel_selection(xyz, grid_size=0.06, min_points=cfg.data.min_points)
                # rand_target = voxel_selection(xyz_target, grid_size=0.06, min_points=cfg.data.min_points)
                rand = torch.randperm(len(feat))[: cfg.data.num_points]
                rand_target = torch.randperm(len(feat_target))[: cfg.data.num_points]
                res = dict(
                    name_scene=name_scene, name_pair_source=name_pair_source, name_pair_target=name_pair_target
                )
                T_gt = estimate_transfo(xyz[matches_gt[:, 0]], xyz_target[matches_gt[:, 1]])
                t2 = time.time()
                metric = compute_metrics(
                    xyz[rand],
                    xyz_target[rand_target],
                    feat[rand],
                    feat_target[rand_target],
                    T_gt,
                    sym=cfg.data.sym,
                    tau_1=cfg.data.tau_1,
                    tau_2=cfg.data.tau_2,
                    rot_thresh=cfg.data.rot_thresh,
                    trans_thresh=cfg.data.trans_thresh,
                    use_ransac=cfg.data.use_ransac,
                    ransac_thresh=cfg.data.first_subsampling,
                    use_teaser=cfg.data.use_teaser,
                    noise_bound_teaser=cfg.data.noise_bound_teaser,
                    xyz_gt=xyz[matches_gt[:, 0]],
                    xyz_target_gt=xyz_target[matches_gt[:, 1]],
                    registration_recall_thresh=reg_thresh,
                )
                res = dict(**res, **metric)
                res["time_feature"] = t1 - t0
                res["time_feature_per_point"] = (t1 - t0) / (len(source.pos) + len(target.pos))
                res["time_prep"] = t2 - t1
                list_res.append(res)

    df = pd.DataFrame(list_res)
    output_path = os.path.join(cfg.training.checkpoint_dir, cfg.data.name, "matches")
    # ``exist_ok=True`` already covers an existing directory; a separate
    # existence check only adds a race.
    os.makedirs(output_path, exist_ok=True)
    df.to_csv(os.path.join(output_path, "final_res_{}.csv".format(time.strftime("%Y%m%d-%H%M%S"))))
    # The frame also holds string columns (pair names); averaging them raises
    # on recent pandas, so restrict the mean to numeric columns.
    print(df.groupby("name_scene").mean(numeric_only=True))
def __init__(self, option):
    """Initialise the model by delegating to ``BaseModel.__init__`` with ``option``."""
    BaseModel.__init__(self, option)
def run(model: BaseModel, dataset: BaseDataset, device, cfg):
    """Visualise predicted matches for a single pair of the first test dataset.

    The pair at index ``cfg.ind`` (default 0) is forwarded through the model.
    Matches between randomly sampled features are computed and flagged as
    inliers when their residual under the reference transform is below
    ``cfg.data.tau_1``. The clouds are then displayed three ways: as given,
    after the transform returned by ``teaser_pp_registration``, and finally
    through ``match_visualizer`` with the source moved by the reference
    transform and up to 250 random matches drawn.

    ``cfg.t`` (default 5) and ``cfg.r`` (default 0.1) are forwarded to
    ``match_visualizer`` as ``t`` and ``radius``.
    """
    print(time.strftime("%Y%m%d-%H%M%S"))
    dataset.create_dataloaders(
        model, 1, False, cfg.training.num_workers, False,
    )
    test_set = dataset.test_dataset[0]

    # Fall back to the defaults when an option is left unset (None).
    pair_index = cfg.ind if cfg.ind is not None else 0
    t = cfg.t if cfg.t is not None else 5
    r = cfg.r if cfg.r is not None else 0.1

    print(test_set)
    print(pair_index)
    data = test_set[pair_index]
    data.batch = torch.zeros(len(data.pos)).long()
    data.batch_target = torch.zeros(len(data.pos_target)).long()
    print(data)

    with torch.no_grad():
        model.set_input(data, device)
        model.forward()

        name_scene, name_pair_source, name_pair_target = dataset.test_dataset[0].get_name(pair_index)
        print(name_scene, name_pair_source, name_pair_target)

        source, target = model.get_input()
        xyz, xyz_target = source.pos, target.pos
        matches_gt = torch.stack([source.ind, target.ind]).transpose(0, 1)
        feat, feat_target = model.get_output()
        # rand = voxel_selection(xyz, grid_size=0.06, min_points=cfg.data.min_points)
        # rand_target = voxel_selection(xyz_target, grid_size=0.06, min_points=cfg.data.min_points)
        rand = torch.randperm(len(feat))[: cfg.data.num_points]
        rand_target = torch.randperm(len(feat_target))[: cfg.data.num_points]

        T_gt = estimate_transfo(xyz[matches_gt[:, 0]].clone(), xyz_target[matches_gt[:, 1]].clone())
        matches_pred = get_matches(feat[rand], feat_target[rand_target], sym=cfg.data.sym)

        # For color: a match is an inlier when the source point moved by
        # T_gt lands within tau_1 of its matched target point.
        residual = (
            xyz[rand][matches_pred[:, 0]] @ T_gt[:3, :3].T
            + T_gt[:3, 3]
            - xyz_target[rand_target][matches_pred[:, 1]]
        )
        inliers = torch.norm(residual, dim=1) < cfg.data.tau_1

        # compute transformation
        T_teaser = teaser_pp_registration(
            xyz[rand][matches_pred[:, 0]],
            xyz_target[rand_target][matches_pred[:, 1]],
            noise_bound=cfg.data.tau_1,
        )

        pcd_source = torch2o3d(source, [1, 0.7, 0.1])
        pcd_target = torch2o3d(target, [0, 0.15, 0.9])
        open3d.visualization.draw_geometries([pcd_source, pcd_target])

        # Show the estimated alignment, then undo it before applying T_gt.
        pcd_source.transform(T_teaser.cpu().numpy())
        open3d.visualization.draw_geometries([pcd_source, pcd_target])
        pcd_source.transform(np.linalg.inv(T_teaser.cpu().numpy()))

        rand_ind = torch.randperm(len(rand[matches_pred[:, 0]]))[:250]
        pcd_source.transform(T_gt.cpu().numpy())
        kp_s = torch2o3d(source, ind=rand[matches_pred[:, 0]][rand_ind])
        kp_s.transform(T_gt.cpu().numpy())
        kp_t = torch2o3d(target, ind=rand_target[matches_pred[:, 1]][rand_ind])
        match_visualizer(
            pcd_source, kp_s, pcd_target, kp_t, inliers[rand_ind].cpu().numpy(), radius=r, t=t
        )