def get_overlap_points(pc, hist_info, config, c=1, pb_pos=1):
    # Pull points out of `pc` from overlap information to be used in dataset
    # creation.
    #   `hist_info`: tuple (hist, bins)
    #   `c`: the count of overlap points required in a bin for the bin to
    #        count as being "in the overlap." Higher values of `c` grab
    #        points more likely to be in the overlap.
    indices = np.full(pc.shape[0], False, dtype=bool)
    process_list = []
    hist, (xedges, yedges, zedges) = hist_info
    my_func = partial(get_indices, pc)
    workers = config['workers']

    # if str(type(pc)) != "<class 'sharedmem.sharedmem.anonymousmemmap'>":
    #     exit("ERROR: PC isn't in shared memory!")

    h_iter = np.array(np.meshgrid(
        np.arange(hist.shape[0]),
        np.arange(hist.shape[1]),
        np.arange(hist.shape[2])
    )).T.reshape(-1, 3)

    pbar = get_pbar(
        h_iter,
        len(h_iter),
        "Building Processes",
        pb_pos, disable=config['tqdm'])

    # collect the bounds of every voxel whose count exceeds `c`
    for t in pbar:
        i, j, k = t
        if hist[i][j][k] > c:
            x1, x2 = xedges[i], xedges[i+1]
            y1, y2 = yedges[j], yedges[j+1]
            z1, z2 = zedges[k], zedges[k+1]
            process_list.append((x1, x2, y1, y2, z1, z2))

    # multiprocessing - this is maybe 2x as fast with 8 workers?
    with Pool(workers) as p:
        sub_pbar = get_pbar(
            p.imap_unordered(my_func, process_list),
            len(process_list),
            "Querying AOI",
            pb_pos,
            disable=config['tqdm']
        )

        for new_indices in sub_pbar:
            indices = indices | new_indices

    # single threaded
    # for t in tqdm(process_list, desc=" "*pb_pos + "Querying AOI",
    #               position=pb_pos, leave=False):
    #     indices = indices | get_indices(t)

    return indices
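
# A minimal sketch of how `hist_info` can be produced for get_overlap_points
# above: np.histogramdd over the overlap region returns exactly the
# (hist, (xedges, yedges, zedges)) structure unpacked in that function.
# The `overlap_points` argument, the bin count, and the config values here
# are illustrative assumptions, not values taken from the project config.
def _example_overlap_query(pc, overlap_points):
    hist, edges = np.histogramdd(overlap_points[:, :3], bins=50)
    hist_info = (hist, tuple(edges))
    config = {'workers': 8, 'tqdm': False}  # assumed keys used above
    return get_overlap_points(pc, hist_info, config, c=1)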
def save_neighborhoods_hdf5(aoi, query, source_scan, config,
                            chunk_size=5000, pb_pos=2):
    with h5py.File(config['dataset_path'], "a") as f:
        # the goal is to load as little of this into memory at once
        aoi_ = {}
        query_ = {}

        # set some percentage for training, shuffle
        train_idx = np.full(len(aoi), False, dtype=bool)
        train_size = int(config['splits']['train'] * aoi.shape[0])
        train_idx[:train_size] = True
        np.random.shuffle(train_idx)

        aoi_['train'] = aoi[train_idx]
        aoi_['test'] = aoi[~train_idx]
        query_['train'] = query[train_idx]
        query_['test'] = query[~train_idx]

        chunk_size = f['train'].chunks[0]  # this is the same for test/train

        sub_pbar = get_pbar(
            list(f.keys()),
            len(list(f.keys())),
            "Saving Neighborhoods []",
            pb_pos, disable=config['tqdm'])

        for split in sub_pbar:
            start = f[split].shape[0]
            end = start + aoi_[split].shape[0]
            slices = (slice(start, end),
                      slice(0, config['max_n_size']+1),
                      slice(0, 9))

            f[split].resize(end, axis=0)

            sub_pbar2 = get_pbar(
                f[split].iter_chunks(sel=slices),
                int(np.ceil(aoi_[split].shape[0] / chunk_size)),
                "Saving chunks",
                pb_pos+1, disable=config['tqdm'])

            # indices for the h5 chunks begin at `start`;
            # indices for aoi and query begin at 0
            for idx, chunk in enumerate(sub_pbar2):
                # `chunk` is a tuple of slices, one per dataset dimension;
                # element 0 is the first axis.
                aoi_chunk = aoi_[split][chunk[0].start-start:chunk[0].stop-start]
                aoi_chunk = np.expand_dims(aoi_chunk, 1)
                query_chunk = query_[split][chunk[0].start-start:chunk[0].stop-start]

                neighborhoods = np.concatenate(
                    (aoi_chunk, source_scan[query_chunk]), axis=1)

                f[split][chunk] = neighborhoods
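
# A minimal sketch of the HDF5 layout save_neighborhoods_hdf5 assumes:
# resizable, chunked "train" and "test" datasets of shape
# (N, max_n_size + 1, 9). The file is opened in append mode above, so the
# datasets must already exist; the chunk size and dtype here are assumptions.
def _example_create_hdf5(config):
    with h5py.File(config['dataset_path'], "w") as f:
        for split in ("train", "test"):
            f.create_dataset(
                split,
                shape=(0, config['max_n_size'] + 1, 9),
                maxshape=(None, config['max_n_size'] + 1, 9),
                chunks=(5000, config['max_n_size'] + 1, 9),
                dtype="f8")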
def filter_aoi(kd, aoi, config, pb_pos=1):
    # Querying uses a large amount of memory; use chunking to keep
    # the footprint small.
    keep = []
    max_chunk_size = config['max_chunk_size']
    max_idx = int(np.ceil(aoi.shape[0] / max_chunk_size))

    sub_pbar = get_pbar(
        range(0, aoi.shape[0], max_chunk_size),
        max_idx,
        "Filtering AOI",
        pb_pos,
        disable=config['tqdm']
    )

    for i in sub_pbar:
        current_chunk = aoi[i:i+max_chunk_size]
        query = kdtree._query(kd,
                              current_chunk[:, :3],
                              k=config['max_n_size'],
                              dmax=1)

        # keep only points with a full neighborhood within dmax
        for j in range(len(query)):
            if len(query[j]) == config['max_n_size']:
                keep.append(i+j)

    return aoi[keep]
def save_neighborhoods_hdf5_eval(aoi, query, source_scan, config,
                                 chunk_size=5000, pb_pos=2):
    with h5py.File(config['eval_dataset'], "a") as f:
        max_idx = int(np.ceil(aoi.shape[0] / chunk_size))

        sub_pbar = get_pbar(
            range(0, aoi.shape[0], chunk_size),
            max_idx,
            "Saving Neighborhoods",
            pb_pos, disable=config['tqdm'])

        for i in sub_pbar:
            aoi_chunk = aoi[i:i+chunk_size, :]
            query_chunk = query[i:i+chunk_size, :]

            aoi_chunk = np.expand_dims(aoi_chunk, 1)
            neighborhoods = np.concatenate(
                (aoi_chunk, source_scan[query_chunk]), axis=1)

            f['eval'][i:i+chunk_size] = neighborhoods
def resample_aoi(aoi, igroup_bounds, max_size, config, pb_pos=2):
    # Resample the AOI so that it is balanced across the range of
    # intensities.
    aoi_resampled = np.empty((0, aoi.shape[1]))

    sub_pbar = get_pbar(
        igroup_bounds,
        len(igroup_bounds),
        "Resampling AOI",
        pb_pos, disable=config['tqdm'])

    for (l, h) in sub_pbar:
        strata = aoi[(l <= aoi[:, 3]) & (aoi[:, 3] < h)]

        # append everything if the stratum is small
        if strata.shape[0] <= max_size:
            aoi_resampled = np.concatenate((aoi_resampled, strata))

        # random sample if large
        else:
            sample = np.random.choice(len(strata), max_size)
            aoi_resampled = np.concatenate((aoi_resampled, strata[sample]))

    return aoi_resampled
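
# A minimal sketch of building `igroup_bounds` for resample_aoi above as
# equal-width intensity bins; the [0, 1] range and the bin count are
# assumptions (intensity lives in column 3 and is clipped to [0, 1]
# elsewhere in this module).
def _example_igroup_bounds(n_groups=10):
    edges = np.linspace(0.0, 1.0, n_groups + 1)
    return list(zip(edges[:-1], edges[1:]))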
def harmonize(model, source_scan_path, target_scan_num, config,
              save=False, sample_size=None):
    harmonized_path = Path(config['dataset']['harmonized_path'])
    plots_path = harmonized_path / "plots"
    plots_path.mkdir(exist_ok=True, parents=True)

    n_size = config['train']['neighborhood_size']
    b_size = config['train']['batch_size']
    chunk_size = config['dataset']['dataloader_size']
    transforms = get_transforms(config)
    G = GlobalShift(**config["dataset"])

    source_scan = np.load(source_scan_path)
    if config['dataset']['shift']:
        source_scan = G(source_scan)

    source_scan_num = int(source_scan[0, 8])

    if sample_size is not None:
        sample = np.random.choice(source_scan.shape[0], sample_size)
    else:
        sample = np.arange(source_scan.shape[0])

    model = model.to(config['train']['device'])
    model.eval()

    kd = kdtree._build(source_scan[:, :3])
    query = kdtree._query(kd, source_scan[sample, :3], k=n_size)
    query = np.array(query)

    size = len(query)
    hz = torch.empty(size).double()
    ip = torch.empty(size).double()
    cr = torch.empty(size).double()

    running_loss = 0

    pbar1 = get_pbar(
        range(0, len(query), chunk_size),
        int(np.ceil(len(query) / chunk_size)),
        f"Hzing Scan {source_scan_num}-->{target_scan_num}",
        0, leave=True, disable=config['dataset']['tqdm'])

    for i in pbar1:
        query_chunk = query[i:i + chunk_size, :]
        source_chunk = source_scan[i:i + chunk_size, :]
        source_chunk = np.expand_dims(source_chunk, 1)

        neighborhoods = np.concatenate(
            (source_chunk, source_scan[query_chunk]), axis=1)

        dataset = LidarDatasetNP(neighborhoods, transform=transforms)
        dataloader = DataLoader(dataset,
                                batch_size=b_size,
                                num_workers=config['train']['num_workers'])

        pbar2 = get_pbar(
            dataloader,
            len(dataloader),
            " Processing Chunk",
            1, disable=config['dataset']['tqdm'])

        with torch.no_grad():
            for j, batch in enumerate(pbar2):
                # specify that we wish to harmonize to the target scan
                batch[:, 0, -1] = target_scan_num
                # batch = torch.tensor(np.expand_dims(ex, 0))
                batch = batch.to(config['train']['device'])

                # dublin specific?
                h_target = batch[:, 0, 3].clone()
                i_target = batch[:, 1, 3].clone()

                harmonization, interpolation, _ = model(batch)

                ldx = i + (j * b_size)
                hdx = i + (j + 1) * b_size
                hz[ldx:hdx] = harmonization.cpu().squeeze()
                ip[ldx:hdx] = interpolation.cpu().squeeze()
                cr[ldx:hdx] = i_target.cpu()  # corruption

                loss = torch.mean(torch.abs(harmonization.squeeze() - h_target))
                running_loss += loss.item()
                pbar2.set_postfix({"loss": f"{running_loss/(i+j+1):.3f}"})

    # visualize results
    hz = hz.numpy()
    hz = np.clip(hz, 0, 1)
    ip = ip.numpy()
    ip = np.clip(ip, 0, 1)
    cr = cr.numpy()
    cr = np.expand_dims(cr, 1)

    if config['dataset']['name'] == "dublin" and sample_size is None:
        create_kde(source_scan[sample, 3],
                   hz.squeeze(),
                   xlabel="ground truth harmonization",
                   ylabel="predicted harmonization",
                   output_path=plots_path / f"{source_scan_num}-{target_scan_num}_harmonization.png")

        create_kde(cr.squeeze(),
                   ip.squeeze(),
                   xlabel="ground truth interpolation",
                   ylabel="predicted interpolation",
                   output_path=plots_path / f"{source_scan_num}-{target_scan_num}_interpolation.png")

        create_kde(source_scan[sample, 3],
                   cr.squeeze(),
                   xlabel="ground truth",
                   ylabel="corruption",
                   output_path=plots_path / f"{source_scan_num}-{target_scan_num}_corruption.png")

    # insert results into original scan
    harmonized_scan = np.hstack(
        (source_scan[sample, :3],
         np.expand_dims(hz, 1),
         source_scan[sample, 4:]))

    if config['dataset']['name'] == "dublin":
        scan_error = np.mean(np.abs(source_scan[sample, 3] - hz.squeeze()))
        print(f"Scan {source_scan_num} Harmonize MAE: {scan_error}")

    if save:
        np.save(harmonized_path / f"{source_scan_num}.npy", harmonized_scan)

    return harmonized_scan
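
# A minimal usage sketch for harmonize() above; the scan path and target scan
# number are placeholders, and `config` is assumed to carry the keys the
# function reads (the 'dataset' and 'train' sub-dicts).
def _example_harmonize_scan(model, config):
    harmonized = harmonize(model,
                           "path/to/source_scan.npy",  # placeholder path
                           target_scan_num=1,          # placeholder target
                           config=config,
                           save=False,
                           sample_size=100000)
    return harmonized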
def train(dataloaders, config):
    # `dataloaders` is a dict, e.g. {'train': dataloader, 'val': dataloader, 'test': ...}
    ckpt_path = None  # not implemented yet
    epochs = config['train']['epochs']
    n_size = config['train']['neighborhood_size']
    b_size = config['train']['batch_size']

    results_path = Path(
        f"{config['train']['results_path']}"
        f"{config['dataset']['use_ss_str']}"
        f"{config['dataset']['shift_str']}"
    )
    results_path.mkdir(parents=True, exist_ok=True)
    print(results_path)

    phases = [k for k in dataloaders]
    for k, v in dataloaders.items():
        print(f"{k}: {len(v)}")

    device = config['train']['device']
    model = IntensityNet(
        n_size, interpolation_method="pointnet").double().to(device)

    criterions = {
        'harmonization': nn.SmoothL1Loss(),
        'interpolation': nn.SmoothL1Loss()
    }

    optimizer = Adam(model.parameters())
    scheduler = CyclicLR(
        optimizer,
        config['train']['min_lr'],
        config['train']['max_lr'],
        step_size_up=len(dataloaders['train']) // 2,
        # mode='triangular2',
        scale_fn=lambda x: 1 / ((5 / 4.)**(x - 1)),
        cycle_momentum=False)

    best_loss = 1000

    # pbar1 = tqdm(range(epochs), total=epochs, desc=f"Best Loss: {best_loss}",
    #              disable=config['train']['tqdm'])
    pbar1 = get_pbar(
        range(epochs),
        epochs,
        f"Best Loss: {best_loss}",
        0, disable=config['train']['tqdm'], leave=True)

    loss_history = {phase: [] for phase in phases}

    for epoch in pbar1:
        for phase in phases:
            if phase == "train":
                model.train()
            else:
                model.eval()

            data = []
            running_loss = 0.0
            total = 0.0

            pbar2 = get_pbar(
                dataloaders[phase],
                len(dataloaders[phase]),
                f"{phase.capitalize()}: {epoch+1}/{epochs}",
                1, disable=config['train']['tqdm'])

            for idx, batch in enumerate(pbar2):
                output = forward_pass(
                    model, phase, batch, criterions, optimizer, scheduler, device)

                data.append(output)
                running_loss += output['loss'].item()
                # total += config['train']['batch_size']

                pbar2.set_postfix({
                    "loss": f"{running_loss/(idx+1):.3f}",
                    "lr": f"{optimizer.param_groups[0]['lr']:.2E}"
                })

            running_loss /= len(dataloaders[phase])
            loss_history[phase].append(running_loss)

            if phase in ['val', 'test']:
                if running_loss < best_loss:
                    new_ckpt_path = results_path / f"{n_size}_epoch={epoch}.pt"
                    torch.save(model.state_dict(), new_ckpt_path)
                    if ckpt_path:
                        ckpt_path.unlink()  # delete previous checkpoint
                    ckpt_path = new_ckpt_path
                    best_loss = running_loss

                    pbar1.set_description(f"Best Loss: {best_loss:.3f}")
                    pbar1.set_postfix({"ckpt": f"{epoch}"})

                    # create kde for val
                    h_target, h_preds, h_mae, i_target, i_preds, i_mae = get_metrics(data)
                    create_kde(h_target, h_preds, h_mae,
                               i_target, i_preds, i_mae,
                               phase,
                               results_path / f"val_kde_{n_size}.png")

    create_loss_plot(loss_history, results_path / "loss.png")

    return model, ckpt_path
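
# A minimal usage sketch for train() above, wrapping in-memory neighborhood
# arrays with the LidarDatasetNP class used in harmonize(); the array shapes,
# split names, and config contents are illustrative assumptions.
def _example_train(train_neighborhoods, test_neighborhoods, config):
    transforms = get_transforms(config)
    dataloaders = {
        split: DataLoader(LidarDatasetNP(arr, transform=transforms),
                          batch_size=config['train']['batch_size'],
                          shuffle=(split == "train"))
        for split, arr in (("train", train_neighborhoods),
                           ("test", test_neighborhoods))
    }
    model, ckpt_path = train(dataloaders, config)
    return model, ckpt_path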
def harmonization_visualization(gt_collection_path, hm_collection_path,
                                sample_size=100000, shift=False, view=True):
    # Note: relies on the module-level `config` for Corruption/GlobalShift
    # and the target scan number.
    gt_files = {f.stem: f for f in gt_collection_path.glob("*.npy")}
    hm_files = {f.stem: f for f in hm_collection_path.glob("*.npy")}

    collection = np.empty((len(hm_files) * sample_size, 11))

    C = Corruption(**config)
    G = GlobalShift(**config)

    pbar = get_pbar(enumerate(hm_files), len(hm_files),
                    "loading collection", 0, leave=True)

    for idx, scan_num in pbar:
        scan = np.load(gt_files[scan_num])
        if shift:
            scan = G(scan)

        if scan_num != config['target_scan']:
            corruption = C(scan)[1:, 3]  # no copy of first point
            harmonization = np.load(hm_files[scan_num])[:, 3]  # intensity only

            scan = np.concatenate((
                scan[:, :4],
                np.expand_dims(corruption, 1),
                np.expand_dims(harmonization, 1),
                scan[:, 4:]), axis=1)
        else:
            # target scan: ground truth, corruption, and harmonization
            # are all the same intensity channel
            scan = np.concatenate((
                scan[:, :4],
                np.expand_dims(scan[:, 3], 1),
                np.expand_dims(scan[:, 3], 1),
                scan[:, 4:]), axis=1)

        sample = np.random.choice(len(scan), size=sample_size)
        collection[idx * sample_size:(idx + 1) * sample_size] = scan[sample]

    if view:
        # view the collection

        # center
        collection[:, :3] -= np.mean(collection[:, :3], axis=0)

        v = viewer(collection[:, :3])

        # show gt, corruption, and harmonization
        attr = [collection[:, 3], collection[:, 4], collection[:, 5]]
        v.color_map('jet', scale=[0, 1])
        v.attributes(*attr)

        mae = np.mean(np.abs(collection[:, 5] - collection[:, 3]))
        print("MAE: ", mae)

        # center the view
        v.set(lookat=[0, 200, 0], r=4250, theta=np.pi / 2, phi=-np.pi / 2)

        # remove background, axis, info
        v.set(bg_color=[1, 1, 1, 1],
              show_grid=False,
              show_axis=False,
              show_info=False)

        # take photos
        v.set(curr_attribute_id=0)
        v.capture('gt.png')
        v.set(curr_attribute_id=1)
        v.capture('corruption.png')
        v.set(curr_attribute_id=2)
        v.capture('fix.png')

        code.interact(local=locals())

    return collection