Example #1
def get_overlap_points(pc, hist_info, config, c=1, pb_pos=1):
    # Pull points out of `pc` using overlap information, to be used in dataset
    # creation.
    #   `hist_info`: tuple (hist, (xedges, yedges, zedges))
    #   `c`: bins containing more than `c` points count as being "in the
    #       overlap." Higher values of `c` grab points more likely to be in
    #       the overlap.

    
    indices = np.full(pc.shape[0], False, dtype=bool)
    process_list = []
    hist, (xedges, yedges, zedges) = hist_info
    my_func = partial(get_indices, pc)
    workers = config['workers']
    
    # if str(type(pc)) != "<class 'sharedmem.sharedmem.anonymousmemmap'>":
    #     exit("ERROR : PC isn't in shared memory!")
    
    h_iter = np.array(np.meshgrid(
        np.arange(hist.shape[0]), 
        np.arange(hist.shape[1]),
        np.arange(hist.shape[2])
    )).T.reshape(-1, 3)
    
    pbar = get_pbar(
        h_iter,
        len(h_iter),
        "Building Processes", 
        pb_pos, disable=config['tqdm'])

    for t in pbar:
        i, j, k = t
        if hist[i][j][k] > c:
            x1, x2 = xedges[i], xedges[i+1]
            y1, y2 = yedges[j], yedges[j+1]
            z1, z2 = zedges[k], zedges[k+1]
            process_list.append((x1, x2, y1, y2, z1, z2))

            
    # multiprocessing - this is maybe 2x as fast with 8 workers?
    with Pool(workers) as p: 
        sub_pbar = get_pbar(
            p.imap_unordered(my_func, process_list),
            len(process_list),
            "Querying AOI",
            pb_pos, disable=config['tqdm']
            )

        for new_indices in sub_pbar:
            indices = indices | new_indices

    # single threaded
    # for t in tqdm(process_list, desc="  "*pb_pos + "Querying AOI", position=pb_pos, leave=False):
    #     indices = indices | get_indices(t)
    
    return indices
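
A minimal sketch of how `hist_info` could be produced and the function invoked, assuming `get_overlap_points` and its helpers are importable from the surrounding module. The bin counts, the point-cloud array, and the `config` keys (`workers`, `tqdm`) are assumptions; `np.histogramdd` is used here only because its return value matches the `(hist, (xedges, yedges, zedges))` unpacking above.

import numpy as np

# Hypothetical N x 9 point cloud with XYZ in the first three columns.
pc = np.random.rand(100_000, 9)

# np.histogramdd returns (hist, [xedges, yedges, zedges]), matching the
# unpacking of `hist_info` inside get_overlap_points.
hist, edges = np.histogramdd(pc[:, :3], bins=(32, 32, 32))
hist_info = (hist, tuple(edges))

config = {'workers': 8, 'tqdm': False}   # assumed config keys
overlap_mask = get_overlap_points(pc, hist_info, config, c=1)
overlap_points = pc[overlap_mask]
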
Example #2
def save_neighborhoods_hdf5(aoi, query, source_scan, config, chunk_size=5000, pb_pos=2):
    with h5py.File(config['dataset_path'], "a") as f:
        # The goal is to load as little of this into memory at once as possible.
        aoi_ = {}; query_ = {}
        train_idx = np.full(len(aoi), False, dtype=bool)
        train_size = int(config['splits']['train']*aoi.shape[0])

        # set some percentage for training, shuffle.
        train_idx[:train_size] = True; np.random.shuffle(train_idx)

        aoi_['train'] = aoi[train_idx]
        aoi_['test'] = aoi[~train_idx]
        query_['train'] = query[train_idx]
        query_['test'] = query[~train_idx]

        chunk_size = f['train'].chunks[0]  # same for train/test; overrides the chunk_size argument
        sub_pbar = get_pbar(
            list(f.keys()),
            len(list(f.keys())),
            "Saving Neighborhoods []",
            pb_pos, disable=config['tqdm'])

        for split in sub_pbar:
            start = f[split].shape[0]
            end = start + aoi_[split].shape[0]
            slices = (slice(start, end), slice(0, config['max_n_size']+1), slice(0, 9))

            f[split].resize(end, axis=0)
            sub_pbar2 = get_pbar(
                f[split].iter_chunks(sel=slices),
                int(np.ceil(aoi_[split].shape[0] / chunk_size)),  # number of chunks along axis 0
                "Saving chunks",
                pb_pos+1, disable=config['tqdm'])
            
            # indices for the h5 chunks begin at start
            # indices for aoi and query begin at 0
            for idx, chunk in enumerate(sub_pbar2):
                # chunk is a tuple of slices with each element corresponding to
                #   a dimension of the dataset. 0 is the first axis. 
                aoi_chunk = aoi_[split][chunk[0].start-start:chunk[0].stop-start]
                aoi_chunk = np.expand_dims(aoi_chunk, 1)

                query_chunk = query_[split][chunk[0].start-start:chunk[0].stop-start]
                neighborhoods = np.concatenate((
                    aoi_chunk, source_scan[query_chunk]),
                    axis=1)

                f[split][chunk] = neighborhoods
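
`save_neighborhoods_hdf5` appears to expect that the file already holds resizable, chunked `train` and `test` datasets whose rows are neighborhoods of shape `(max_n_size + 1, 9)`. Below is a hedged sketch of creating that layout with h5py; the file name, chunk size, neighborhood size, and dtype are all assumptions.

import h5py

max_n_size = 150  # hypothetical neighborhood size

with h5py.File("dataset.h5", "w") as f:
    for split in ("train", "test"):
        # Start empty, allow unlimited growth along axis 0, and chunk rows
        # so iter_chunks() in the saver walks the data in fixed-size blocks.
        f.create_dataset(split,
                         shape=(0, max_n_size + 1, 9),
                         maxshape=(None, max_n_size + 1, 9),
                         chunks=(5000, max_n_size + 1, 9),
                         dtype="f8")
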
Example #3
def filter_aoi(kd, aoi, config, pb_pos=1):
    # Querying uses a large amount of memory, so process the AOI in chunks to
    #   keep the memory footprint small.
    keep = []
    max_chunk_size = config['max_chunk_size']

    max_idx = int(np.ceil(aoi.shape[0] / max_chunk_size))
    sub_pbar = get_pbar(
            range(0, aoi.shape[0], max_chunk_size),
            max_idx,
            "Filtering AOI",
            pb_pos, disable=config['tqdm']
            )

    for i in sub_pbar:
        current_chunk = aoi[i:i+max_chunk_size]
        query = kdtree._query(kd, 
                              current_chunk[:, :3], 
                              k=config['max_n_size'], dmax=1)


        for j in range(len(query)):
            if len(query[j]) == config['max_n_size']:
                keep.append(i+j)

    return aoi[keep]
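
A hedged usage sketch for `filter_aoi`. The `kdtree` module is assumed to come from pptk (consistent with the `viewer` call in Example #7), and the array shapes and config keys are assumptions.

import numpy as np
from pptk import kdtree  # assumption: same kdtree module used elsewhere in this file

# Hypothetical scans with XYZ in the first three columns.
source_scan = np.random.rand(500_000, 9)
aoi = np.random.rand(20_000, 9)

config = {'max_chunk_size': 5000, 'max_n_size': 150, 'tqdm': False}  # assumed keys
kd = kdtree._build(source_scan[:, :3])
aoi_filtered = filter_aoi(kd, aoi, config)
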
Example #4
def save_neighborhoods_hdf5_eval(aoi, query, source_scan, config, chunk_size=5000, pb_pos=2):
    with h5py.File(config['eval_dataset'], "a") as f:
        workers = config['workers']
        max_idx = int(np.ceil(aoi.shape[0] / chunk_size))
        sub_pbar = get_pbar(
            range(0, aoi.shape[0], chunk_size),
            max_idx,
            "Saving Neighborhoods",
            pb_pos, disable=config['tqdm'])

        for i in sub_pbar:
            aoi_chunk = aoi[i:i+chunk_size, :]
            query_chunk = query[i:i+chunk_size, :]

            aoi_chunk = np.expand_dims(aoi_chunk, 1)
            neighborhoods = np.concatenate(
                (aoi_chunk, source_scan[query_chunk]),
                axis=1)

            f['eval'][i:i+chunk_size] = neighborhoods
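
Unlike the train/test writer above, `save_neighborhoods_hdf5_eval` never resizes `f['eval']`, so that dataset must already be allocated to hold every AOI point. A sketch of that precondition; the file name, sizes, and dtype are assumptions.

import h5py

n_points = 1_000_000   # hypothetical total number of evaluation points (len(aoi))
max_n_size = 150       # hypothetical neighborhood size

with h5py.File("eval_dataset.h5", "w") as f:
    # Preallocate the full 'eval' dataset; the saver writes into fixed slices.
    f.create_dataset("eval",
                     shape=(n_points, max_n_size + 1, 9),
                     chunks=(5000, max_n_size + 1, 9),
                     dtype="f8")
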
Example #5
def resample_aoi(aoi, igroup_bounds, max_size, config, pb_pos=2):
    # Resample the AOI so that intensities are balanced across the range of
    #   intensity groups: each group contributes at most `max_size` points.
    aoi_resampled = np.empty((0, aoi.shape[1]))
    sub_pbar = get_pbar(
        igroup_bounds,
        len(igroup_bounds),
        "Resampling AOI",
        pb_pos, disable=config['tqdm'])

    for (l, h) in sub_pbar:
        strata = aoi[(l <= aoi[:, 3]) & (aoi[:, 3] < h)]
        # keep the whole stratum if it is already small enough
        if strata.shape[0] <= max_size:
            aoi_resampled = np.concatenate((
                aoi_resampled, strata))

        # randomly sample (with replacement) if it is large
        else:
            sample = np.random.choice(len(strata), max_size)
            aoi_resampled = np.concatenate((
                aoi_resampled, strata[sample]))

    return aoi_resampled
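
A sketch of one way `igroup_bounds` could be built as consecutive `(low, high)` intensity bounds. The number of groups, the 0-1 intensity range, and the config key are assumptions.

import numpy as np

edges = np.linspace(0.0, 1.0, num=11)              # 10 hypothetical intensity groups
igroup_bounds = list(zip(edges[:-1], edges[1:]))   # [(0.0, 0.1), (0.1, 0.2), ...]

aoi = np.random.rand(50_000, 9)                    # intensity assumed in column 3
config = {'tqdm': False}                           # assumed key
aoi_balanced = resample_aoi(aoi, igroup_bounds, max_size=2000, config=config)
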
Example #6
def harmonize(model,
              source_scan_path,
              target_scan_num,
              config,
              save=False,
              sample_size=None):

    harmonized_path = Path(config['dataset']['harmonized_path'])
    plots_path = harmonized_path / "plots"
    plots_path.mkdir(exist_ok=True, parents=True)

    n_size = config['train']['neighborhood_size']
    b_size = config['train']['batch_size']
    chunk_size = config['dataset']['dataloader_size']
    transforms = get_transforms(config)
    G = GlobalShift(**config["dataset"])

    source_scan = np.load(source_scan_path)

    if config['dataset']['shift']:
        source_scan = G(source_scan)

    source_scan_num = int(source_scan[0, 8])

    if sample_size is not None:
        sample = np.random.choice(source_scan.shape[0], sample_size)
    else:
        sample = np.arange(source_scan.shape[0])

    model = model.to(config['train']['device'])
    model.eval()

    kd = kdtree._build(source_scan[:, :3])

    query = kdtree._query(kd, source_scan[sample, :3], k=n_size)

    query = np.array(query)
    size = len(query)

    hz = torch.empty(size).double()
    ip = torch.empty(size).double()
    cr = torch.empty(size).double()

    running_loss = 0

    pbar1 = get_pbar(range(0, len(query), chunk_size),
                     int(np.ceil(len(query) / chunk_size)),
                     f"Hzing Scan {source_scan_num}-->{target_scan_num}",
                     0,
                     leave=True,
                     disable=config['dataset']['tqdm'])

    for i in pbar1:
        query_chunk = query[i:i + chunk_size, :]
        source_chunk = source_scan[i:i + chunk_size, :]
        source_chunk = np.expand_dims(source_chunk, 1)

        neighborhoods = np.concatenate(
            (source_chunk, source_scan[query_chunk]), axis=1)

        dataset = LidarDatasetNP(neighborhoods, transform=transforms)

        dataloader = DataLoader(dataset,
                                batch_size=b_size,
                                num_workers=config['train']['num_workers'])

        pbar2 = get_pbar(dataloader,
                         len(dataloader),
                         "  Processing Chunk",
                         1,
                         disable=config['dataset']['tqdm'])

        with torch.no_grad():
            for j, batch in enumerate(pbar2):
                batch[:, 0, -1] = target_scan_num  # specify that we wish to harmonize

                batch = batch.to(config['train']['device'])

                # ground-truth targets (dublin-specific column layout?)
                h_target = batch[:, 0, 3].clone()  # harmonization target
                i_target = batch[:, 1, 3].clone()  # interpolation / corruption target

                harmonization, interpolation, _ = model(batch)

                ldx = i + (j * b_size)
                hdx = i + (j + 1) * b_size

                hz[ldx:hdx] = harmonization.cpu().squeeze()
                ip[ldx:hdx] = interpolation.cpu().squeeze()
                cr[ldx:hdx] = i_target.cpu()  # corruption

                loss = torch.mean(torch.abs(harmonization.squeeze() -
                                            h_target))
                running_loss += loss.item()
                pbar2.set_postfix({"loss": f"{running_loss/(i+j+1):.3f}"})

    # visualize results
    hz = hz.numpy()
    hz = np.clip(hz, 0, 1)
    ip = ip.numpy()
    ip = np.clip(ip, 0, 1)
    cr = cr.numpy()
    cr = np.expand_dims(cr, 1)

    if config['dataset']['name'] == "dublin" and sample_size is None:
        create_kde(source_scan[sample, 3],
                   hz.squeeze(),
                   xlabel="ground truth harmonization",
                   ylabel="predicted harmonization",
                   output_path=plots_path /
                   f"{source_scan_num}-{target_scan_num}_harmonization.png")

        create_kde(cr.squeeze(),
                   ip.squeeze(),
                   xlabel="ground truth interpolation",
                   ylabel="predicted interpolation",
                   output_path=plots_path /
                   f"{source_scan_num}-{target_scan_num}_interpolation.png")

        create_kde(source_scan[sample, 3],
                   cr.squeeze(),
                   xlabel="ground truth",
                   ylabel="corruption",
                   output_path=plots_path /
                   f"{source_scan_num}-{target_scan_num}_corruption.png")

    # insert results into original scan
    harmonized_scan = np.hstack((
        source_scan[sample, :3],
        np.expand_dims(hz, 1),
        source_scan[sample, 4:]))

    if config['dataset']['name'] == "dublin":
        scan_error = np.mean(np.abs((source_scan[sample, 3]) - hz.squeeze()))
        print(f"Scan {source_scan_num} Harmonize MAE: {scan_error}")

    if save:
        np.save((Path(config['dataset']['harmonized_path']) /
                 (str(source_scan_num) + ".npy")), harmonized_scan)

    return harmonized_scan
Example #7
def train(dataloaders, config):
    # Dataloaders as dict {'train': dataloader, 'val': dataloader, 'test':...}

    ckpt_path = None  # resuming from a checkpoint is not implemented yet
    epochs = config['train']['epochs']
    n_size = config['train']['neighborhood_size']
    b_size = config['train']['batch_size']

    results_path = Path(
        f"{config['train']['results_path']}{config['dataset']['use_ss_str']}{config['dataset']['shift_str']}"
    )
    results_path.mkdir(parents=True, exist_ok=True)
    print(results_path)

    phases = [k for k in dataloaders]
    [print(f"{k}: {len(v)}") for k, v in dataloaders.items()]

    device = config['train']['device']
    model = IntensityNet(n_size,
                         interpolation_method="pointnet").double().to(device)

    criterions = {
        'harmonization': nn.SmoothL1Loss(),
        'interpolation': nn.SmoothL1Loss()
    }

    optimizer = Adam(model.parameters())
    scheduler = CyclicLR(
        optimizer,
        config['train']['min_lr'],
        config['train']['max_lr'],
        step_size_up=len(dataloaders['train']) // 2,
        # mode='triangular2',
        scale_fn=lambda x: 1 / ((5 / 4.)**(x - 1)),
        cycle_momentum=False)

    best_loss = 1000
    # pbar1 = tqdm(range(epochs), total=epochs, desc=f"Best Loss: {best_loss}", disable=config['train']['tqdm'])
    pbar1 = get_pbar(range(epochs),
                     epochs,
                     f"Best Loss: {best_loss}",
                     0,
                     disable=config['train']['tqdm'],
                     leave=True)

    loss_history = {"train": [], "test": []}

    for epoch in pbar1:
        for phase in phases:
            if phase == "train":
                model.train()
            else:
                model.eval()

            data = []

            running_loss = 0.0
            total = 0.0
            pbar2 = get_pbar(dataloaders[phase],
                             len(dataloaders[phase]),
                             f"{phase.capitalize()}: {epoch+1}/{epochs}",
                             1,
                             disable=config['train']['tqdm'])

            for idx, batch in enumerate(pbar2):
                output = forward_pass(model, phase, batch, criterions,
                                      optimizer, scheduler, device)

                data.append(output)
                running_loss += output['loss'].item()
                # total += config['train']['batch_size']

                pbar2.set_postfix({
                    "loss":
                    f"{running_loss/(idx+1):.3f}",
                    "lr":
                    f"{optimizer.param_groups[0]['lr']:.2E}"
                })

            running_loss /= len(dataloaders[phase])
            loss_history[phase].append(running_loss)

            if phase in ['val', 'test']:
                if running_loss < best_loss:
                    new_ckpt_path = results_path / f"{n_size}_epoch={epoch}.pt"
                    torch.save(model.state_dict(), new_ckpt_path)

                    if ckpt_path:
                        ckpt_path.unlink()  # delete previous checkpoint
                    ckpt_path = new_ckpt_path

                    best_loss = running_loss
                    pbar1.set_description(f"Best Loss: {best_loss:.3f}")
                    pbar1.set_postfix({"ckpt": f"{epoch}"})

                    # create kde for val
                    h_target, h_preds, h_mae, i_target, i_preds, i_mae = get_metrics(
                        data)
                    create_kde(h_target, h_preds, h_mae, i_target, i_preds,
                               i_mae, phase,
                               results_path / f"val_kde_{n_size}.png")

                create_loss_plot(loss_history, results_path / "loss.png")

    return model, ckpt_path
def harmonization_visualization(gt_collection_path,
                                hm_collection_path,
                                sample_size=100000,
                                shift=False,
                                view=True):
    gt_files = {f.stem: f for f in gt_collection_path.glob("*.npy")}
    hm_files = {f.stem: f for f in hm_collection_path.glob("*.npy")}

    collection = np.empty((len(hm_files) * sample_size, 11))
    C = Corruption(**config)
    G = GlobalShift(**config)

    pbar = get_pbar(enumerate(hm_files),
                    len(hm_files),
                    "loading collection",
                    0,
                    leave=True)

    for idx, scan_num in pbar:
        scan = np.load(gt_files[scan_num])
        if shift:
            scan = G(scan)
        if scan_num != config['target_scan']:
            corruption = C(scan)[1:, 3]  # no copy of first point
            harmonization = np.load(hm_files[scan_num])[:, 3]  # intensity only

            scan = np.concatenate((
                scan[:, :4],
                np.expand_dims(corruption, 1),
                np.expand_dims(harmonization, 1),
                scan[:, 4:]), axis=1)
        else:
            scan = np.concatenate((
                scan[:, :4],
                np.expand_dims(scan[:, 3], 1),
                np.expand_dims(scan[:, 3], 1),
                scan[:, 4:]), axis=1)

        sample = np.random.choice(len(scan), size=sample_size)
        collection[idx * sample_size:(idx + 1) * (sample_size)] = scan[sample]

    if view:
        # view the collection

        # Center
        collection[:, :3] -= np.mean(collection[:, :3], axis=0)

        v = viewer(collection[:, :3])

        # show gt, corruption, and harmonization
        attr = [collection[:, 3], collection[:, 4], collection[:, 5]]
        v.color_map('jet', scale=[0, 1])
        v.attributes(*attr)
        mae = np.mean(np.abs(collection[:, 5] - collection[:, 3]))
        print("MAE: ", mae)

        # center the view
        v.set(lookat=[0, 200, 0], r=4250, theta=np.pi / 2, phi=-np.pi / 2)

        # remove background, axis, info
        v.set(bg_color=[1, 1, 1, 1],
              show_grid=False,
              show_axis=False,
              show_info=False)

        # take photos
        v.set(curr_attribute_id=0)
        v.capture('gt.png')
        v.set(curr_attribute_id=1)
        v.capture('corruption.png')
        v.set(curr_attribute_id=2)
        v.capture('fix.png')

        code.interact(local=locals())

    return collection
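
A hedged invocation of `harmonization_visualization`. The directory paths are hypothetical, and a module-level `config` with the keys used above (e.g. `target_scan` and the `Corruption`/`GlobalShift` arguments) is assumed to exist.

from pathlib import Path

gt_dir = Path("data/dublin/npy")          # hypothetical ground-truth scans (*.npy)
hm_dir = Path("data/dublin/harmonized")   # hypothetical harmonized scans (*.npy)

# view=False skips the pptk viewer and just assembles the sampled collection.
collection = harmonization_visualization(gt_dir,
                                          hm_dir,
                                          sample_size=100_000,
                                          shift=False,
                                          view=False)
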