Example #1
def preprocess_spad(spad_single, ambient_estimate, min_depth, max_depth, sid_obj):
    print("Processing SPAD data...")
    # Remove DC
    spad_denoised = remove_dc_from_spad_edge(spad_single,
                                             ambient=ambient_estimate,
                                             grad_th=3*np.sqrt(2*ambient_estimate))

    # Correct Falloff
    bin_edges = np.linspace(min_depth, max_depth, len(spad_denoised) + 1)
    bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
    spad_corrected = spad_denoised * bin_values ** 2

    # Scale SID object to maximize bin utilization
    min_depth_bin = np.min(np.nonzero(spad_corrected))
    max_depth_bin = np.max(np.nonzero(spad_corrected))
    min_depth_pred = bin_values[min_depth_bin]
    max_depth_pred = bin_values[max_depth_bin]
    sid_obj_pred = SID(sid_bins=sid_obj.sid_bins,
                       alpha=min_depth_pred,
                       beta=max_depth_pred,
                       offset=0.)

    # Convert to SID
    spad_sid = rescale_bins(spad_corrected[min_depth_bin:max_depth_bin+1],
                            min_depth_pred, max_depth_pred, sid_obj_pred)
    return spad_sid, sid_obj_pred, spad_denoised, spad_corrected
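
The pipeline above removes the ambient (DC) floor, undoes the squared-distance falloff, and rebins the histogram into a SID whose range is adapted to the nonzero bins. Below is a minimal usage sketch with a toy histogram; the SID import path is taken from elsewhere in this listing, and all values are illustrative rather than from the original code.

# Hypothetical usage sketch (toy data; preprocess_spad and its helpers as defined above).
import numpy as np
from models.data.data_utils.sid_utils import SID  # import path as used elsewhere in this listing

n_bins, min_depth, max_depth = 1024, 0.4, 9.0
bin_edges = np.linspace(min_depth, max_depth, n_bins + 1)
bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
ambient_estimate = 5.0
# Toy SPAD histogram: a narrow return at ~3 m sitting on a flat ambient floor.
spad_single = ambient_estimate + 400. * np.exp(-(bin_values - 3.) ** 2 / (2 * 0.05 ** 2))

sid_obj = SID(sid_bins=68, alpha=min_depth, beta=max_depth, offset=0.)
spad_sid, sid_obj_pred, spad_denoised, spad_corrected = preprocess_spad(
    spad_single, ambient_estimate, min_depth, max_depth, sid_obj)
print(spad_sid.shape, sid_obj_pred.sid_bin_edges[0], sid_obj_pred.sid_bin_edges[-1])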
Example #2
def load_data(root_dir, min_depth, max_depth, dorn_mode, sid_bins, alpha, beta,
              offset, spad_config):
    """

    :param root_dir:  The root directory from which to load the dataset
    :param use_dorn_normalization: Whether or not to normalize the rgb images according to DORN statistics.
    :return: test: a NYUDepthv2TestDataset object.
    """
    test = NYUDepthv2TestDataset(root_dir, transform=None, dorn_mode=dorn_mode)

    if dorn_mode:
        transform_mean = np.array([[[103.0626, 115.9029,
                                     123.1516]]]).astype(np.float32)
    else:
        transform_mean = np.zeros((1, 1, 3)).astype(np.float32)
    transform_var = np.ones((1, 1, 3)).astype(np.float32)

    transform_list = [
        AddDepthMask(min_depth, max_depth, "depth_cropped"),
        Save(["rgb_cropped", "mask", "depth_cropped"], "_orig"),
    ]
    if dorn_mode:
        transform_list.append(
            ResizeAll((353, 257), keys=["rgb_cropped", "depth_cropped"]))
    transform_list += [
        Normalize(transform_mean, transform_var, key="rgb_cropped"),
        AddDepthMask(min_depth, max_depth, "depth_cropped"),
        SimulateSpadIntensity("depth_cropped_orig",
                              "rgb_cropped_orig",
                              "mask_orig",
                              "spad",
                              min_depth,
                              max_depth,
                              spad_config["spad_bins"],
                              spad_config["photon_count"],
                              spad_config["dc_count"],
                              spad_config["fwhm_ps"],
                              spad_config["use_intensity"],
                              spad_config["use_squared_falloff"],
                              sid_obj=SID(sid_bins, alpha, beta, offset)),
    ]
    # TODO: Determine if BGR or RGB
    # Answer: RGB - need to flip for DORN.
    if dorn_mode:
        print("Using dataset in DORN mode.")
        transform_list.append(
            ToTensorAll(keys=[
                "rgb_cropped", "rgb_cropped_orig", "depth_cropped",
                "depth_cropped_orig", "mask", "mask_orig", "spad"
            ]))
    else:
        print("Using dataset in Wonka mode.")
        # Don't flip channels of RGB input
        # Flip channels of rgb_cropped used for intensity
        transform_list.append(
            ToTensorAll(keys=["rgb_cropped", "depth_cropped", "mask", "spad"]))
    test.transform = transforms.Compose(transform_list)
    return test
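
A hedged sketch of calling the loader above follows; the root directory, SID parameters, and spad_config values are placeholders (only the keys actually read in this example are filled in), not values confirmed for this repository.

# Hypothetical call (placeholder paths and values).
spad_config = {
    "spad_bins": 1024,
    "photon_count": 1e6,
    "dc_count": 1e5,
    "fwhm_ps": 70.,
    "use_intensity": True,
    "use_squared_falloff": True,
}
test = load_data(root_dir="data/nyu_depth_v2", min_depth=0., max_depth=10.,
                 dorn_mode=True, sid_bins=68, alpha=0.66, beta=10., offset=0.,
                 spad_config=spad_config)
sample = test[0]
print(len(test), sample["rgb_cropped"].shape, sample["spad"].shape)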
def run(initializer, opt_eps, sid_bins, alpha, beta, offset):
    # Load Data
    dataset = np.load(os.path.join("data", initializer, "all_outputs.npy"))

    # Create SID object
    sid_obj = SID(sid_bins, alpha, beta, offset)
    outputs = []
    for i in range(len(dataset)):
        depth_init = dataset[i]
        print(depth_init.shape)
def cfg():
    data_dir = "data"
    calibration_file = os.path.join(data_dir, "calibration", "camera_params.mat")
    scenes = [
        # "8_29_lab_scene",
        # "8_29_kitchen_scene",
        # "8_29_conference_room_scene",
        # "8_30_conference_room2_scene",
        # "8_30_Hallway",
        # "8_30_poster_scene",
        "8_30_small_lab_scene",
    ]
    # Relative shift of projected depth to rgb (found empirically)
    offsets = [
        # (0, 0),
        # (-10, -8),
        # (-16, -12),
        # (-16, -12),
        # (0, 0),
        # (0, 0),
        (0, 0)
    ]


    output_dir = os.path.join("figures", "midas")

    bin_width_ps = 16
    bin_width_m = bin_width_ps*3e8/(2*1e12)
    min_depth_bin = np.floor(0.4/bin_width_m).astype('int')
    max_depth_bin = np.floor(9./bin_width_m).astype('int')
    min_depth = min_depth_bin * bin_width_m
    max_depth = (max_depth_bin + 1) * bin_width_m
    sid_obj_init = SID(sid_bins=140, alpha=min_depth, beta=max_depth, offset=0)
    ambient_max_depth_bin = 100

    cuda_device = "0"                       # The gpu index to run on. Should be a string
    os.environ["CUDA_VISIBLE_DEVICES"] = cuda_device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using device: {} (CUDA_VISIBLE_DEVICES = {})".format(device,
                                                                os.environ["CUDA_VISIBLE_DEVICES"]))
def run(dataset_type, spad_file, densedepth_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, intensity_ablation, vectorized, entry,
        save_outputs, small_run, output_dir):
    print("output dir: {}".format(output_dir))
    safe_makedir(output_dir)

    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(densedepth_depth_file))
    depth_data = np.load(densedepth_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
    ambient = spad_config["dc_count"] / spad_config["spad_bins"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    lambertian = spad_config["lambertian"]
    use_poisson = spad_config["use_poisson"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("ambient: ", ambient)
    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)
    print("lambertian:", lambertian)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj_init = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        metrics = np.zeros(
            (len(dataset) if not small_run else small_run, len(metric_list)))
        entry_list = []
        outputs = []
        times = []
        for i in range(depth_data.shape[0]):
            if small_run and i == small_run:
                break
            entry_list.append(i)
            print("Evaluating {}[{}]".format(dataset_type, i))
            spad = spad_data[i, ...]
            bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
            bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
            # spad = preprocess_spad_ambient_estimate(spad, min_depth, max_depth,
            #                                             correct_falloff=use_squared_falloff,
            #                                             remove_dc= dc_count > 0.,
            #                                             global_min_depth=np.min(depth_data),
            #                                             n_std=1. if use_poisson else 0.01)
            # Rescale SPAD_data
            weights = np.ones_like(depth_data[i, 0, ...])
            # Ablation study: Turn off intensity, even if spad has been simulated with it.
            if use_intensity and not intensity_ablation:
                weights = intensity_data[i, 0, ...]

            if dc_count > 0.:
                spad = remove_dc_from_spad_edge(
                    spad,
                    ambient=ambient,
                    # grad_th=2*ambient)
                    grad_th=5 * np.sqrt(2 * ambient))
            # print(2*ambient)
            # print(5*np.sqrt(2*ambient))

            if use_squared_falloff:
                if lambertian:
                    spad = spad * bin_values**4
                else:
                    spad = spad * bin_values**2
            # Scale SID object to maximize bin utilization
            nonzeros = np.nonzero(spad)[0]
            if nonzeros.size > 0:
                min_depth_bin = np.min(nonzeros)
                max_depth_bin = np.max(nonzeros) + 1
                if max_depth_bin > len(bin_edges) - 2:
                    max_depth_bin = len(bin_edges) - 2
            else:
                min_depth_bin = 0
                max_depth_bin = len(bin_edges) - 2
            min_depth_pred = np.clip(bin_edges[min_depth_bin],
                                     a_min=1e-2,
                                     a_max=None)
            max_depth_pred = np.clip(bin_edges[max_depth_bin + 1],
                                     a_min=1e-2,
                                     a_max=None)
            # print(min_depth_pred)
            # print(max_depth_pred)
            sid_obj_pred = SID(sid_bins=sid_obj_init.sid_bins,
                               alpha=min_depth_pred,
                               beta=max_depth_pred,
                               offset=0.)
            spad_rescaled = rescale_bins(spad[min_depth_bin:max_depth_bin + 1],
                                         min_depth_pred, max_depth_pred,
                                         sid_obj_pred)
            start = process_time()
            pred, t = image_histogram_match_variable_bin(
                depth_data[i, 0, ...], spad_rescaled, weights, sid_obj_init,
                sid_obj_pred, vectorized)
            times.append(process_time() - start)
            # break
            # Calculate metrics
            gt = dataset[i]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)
            # print(pred[20:30, 20:30])

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)
            print("\tAvg RMSE = {}".format(
                np.mean(metrics[:i + 1, metric_list.index("rmse")])))

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "densedepth_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))
        print("Avg Time: {}".format(np.mean(times)))
        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "densedepth_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "densedepth_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad = spad_data[i, ...]
        spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj_init)
        print("spad_rescaled", spad_rescaled)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
        # spad_rescaled = preprocess_spad_sid(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.
        #                                     )

        if dc_count > 0.:
            spad_rescaled = remove_dc_from_spad(
                spad_rescaled,
                sid_obj_init.sid_bin_edges,
                sid_obj_init.sid_bin_values[:-2]**2,
                lam=1e1 if use_poisson else 1e-1,
                eps_rel=1e-5)
        if use_squared_falloff:
            spad_rescaled = spad_rescaled * sid_obj_init.sid_bin_values[:-2]**2
        # print(spad_rescaled)
        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj_init)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)
        print(gt[:, :, 40, 60])
        print(depth_data[i, 0, 40, 60])
        print("before rmse: ",
              np.sqrt(np.mean((gt.numpy() - depth_data[i, 0, ...])**2)))

        before_metrics = get_depth_metrics(
            torch.from_numpy(
                depth_data[i, 0, ...]).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir, "densedepth_{}[{}]_{}_out.npy".format(
                        dataset_type, entry, hyper_string)), pred)

        print("before:")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(before_metrics["delta1"], before_metrics["delta2"],
                     before_metrics["delta3"], before_metrics["rel_abs_diff"],
                     before_metrics["rmse"], before_metrics["log10"]))
        print("after:")

        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))
def run(dataset_type, spad_file, densedepth_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, lam, eps_rel, n_std, entry, save_outputs,
        small_run, subsampling, output_dir):
    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(densedepth_depth_file))
    depth_data = np.load(densedepth_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    use_poisson = spad_config["use_poisson"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        print(len(dataset) // subsampling)
        metrics = np.zeros(
            (len(dataset) // subsampling + 1 if not small_run else small_run,
             len(metric_list)))
        entry_list = []
        outputs = []
        for i in range(depth_data.shape[0]):
            idx = i * subsampling

            if idx >= depth_data.shape[0] or (small_run and i >= small_run):
                break
            entry_list.append(idx)

            print("Evaluating {}[{}]".format(dataset_type, idx))
            spad = spad_data[idx, ...]
            # spad = preprocess_spad_ambient_estimate(spad, min_depth, max_depth,
            #                                             correct_falloff=use_squared_falloff,
            #                                             remove_dc= dc_count > 0.,
            #                                             global_min_depth=np.min(depth_data),
            #                                             n_std=1. if use_poisson else 0.01)
            # Rescale SPAD_data
            spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            weights = np.ones_like(depth_data[idx, 0, ...])
            if use_intensity:
                weights = intensity_data[idx, 0, ...]
            # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
            if dc_count > 0.:
                spad_rescaled = remove_dc_from_spad(
                    spad_rescaled,
                    sid_obj.sid_bin_edges,
                    sid_obj.sid_bin_values[:-2]**2,
                    lam=1e1 if use_poisson else 1e-1,
                    eps_rel=1e-5)
                # spad_rescaled = remove_dc_from_spad_poisson(spad_rescaled,
                #                                        sid_obj.sid_bin_edges,
                #                                        lam=lam)
                # spad = remove_dc_from_spad_ambient_estimate(spad,
                #                                             min_depth, max_depth,
                #                                             global_min_depth=np.min(depth_data),
                #                                             n_std=n_std)
                # print(spad[:10])
                # print(spad)

            if use_squared_falloff:
                spad_rescaled = spad_rescaled * sid_obj.sid_bin_values[:-2]**2
                # bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
                # bin_values = (bin_edges[1:] + bin_edges[:-1])/2
                # spad = spad * bin_values ** 2
            # spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            pred, _ = image_histogram_match(depth_data[idx, 0, ...],
                                            spad_rescaled, weights, sid_obj)
            # break
            # Calculate metrics
            gt = dataset[idx]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)
            # print(pred[20:30, 20:30])

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)
            print("\tAvg RMSE = {}".format(
                np.mean(metrics[:i + 1, metric_list.index("rmse")])))

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "densedepth_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))

        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "densedepth_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "densedepth_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad = spad_data[i, ...]
        spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
        print("spad_rescaled", spad_rescaled)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
        # spad_rescaled = preprocess_spad_sid(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.
        #                                     )

        if dc_count > 0.:
            spad_rescaled = remove_dc_from_spad(
                spad_rescaled,
                sid_obj.sid_bin_edges,
                sid_obj.sid_bin_values[:-2]**2,
                lam=1e1 if use_poisson else 1e-1,
                eps_rel=1e-5)
        if use_squared_falloff:
            spad_rescaled = spad_rescaled * sid_obj.sid_bin_values[:-2]**2
        # print(spad_rescaled)
        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)
        print(gt[:, :, 40, 60])
        print(depth_data[i, 0, 40, 60])
        print("before rmse: ",
              np.sqrt(np.mean((gt.numpy() - depth_data[i, 0, ...])**2)))

        before_metrics = get_depth_metrics(
            torch.from_numpy(
                depth_data[i, 0, ...]).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir, "densedepth_{}[{}]_{}_out.npy".format(
                        dataset_type, entry, hyper_string)), pred)

        print("before:")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(before_metrics["delta1"], before_metrics["delta2"],
                     before_metrics["delta3"], before_metrics["rel_abs_diff"],
                     before_metrics["rmse"], before_metrics["log10"]))
        print("after:")

        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))
def analyze(figures_dir, data_dir, calibration_file, models, scenes,
            use_intensity, vectorized, device):
    kinect_intrinsics, spad_intrinsics, RotationOfSpad, TranslationOfSpad = extract_camera_params(
        calibration_file)
    RotationOfKinect = RotationOfSpad.T
    TranslationOfKinect = -TranslationOfSpad.dot(RotationOfSpad.T)
    for model_str, load_run in models.items():
        model = load_run["load"](None, device)
        run_model = load_run["run"]
        output_dir = os.path.join(figures_dir, model_str)
        for scene, meta in scenes.items():
            print("Running {}...".format(scene))
            offset = meta["offset"]
            bin_width_ps = meta["bin_width_ps"]
            min_r = meta["min_r"]
            max_r = meta["max_r"]
            rootdir = os.path.join(data_dir, scene)
            scenedir = os.path.join(output_dir, scene)
            safe_makedir(os.path.join(scenedir))

            bin_width_m = bin_width_ps * 3e8 / (2 * 1e12)
            min_depth_bin = np.floor(min_r / bin_width_m).astype('int')
            max_depth_bin = np.floor(max_r / bin_width_m).astype('int')
            # Compensate for z translation only
            min_depth = min_depth_bin * bin_width_m - TranslationOfSpad[2] / 1e3
            # print(TranslationOfSpad)
            max_depth = (max_depth_bin +
                         1) * bin_width_m - TranslationOfSpad[2] / 1e3
            sid_obj_init = SID(sid_bins=600,
                               alpha=min_depth,
                               beta=max_depth,
                               offset=0)
            ambient_max_depth_bin = 100

            # RGB from Kinect
            rgb, rgb_cropped, intensity, crop = load_and_crop_kinect(
                rootdir, kinect_file="kinect.mat")
            if not use_intensity:
                intensity = np.ones_like(intensity)

            # Load all the SPAD and kinect data
            spad = load_spad(os.path.join(rootdir, "spad", "data_accum.mat"))
            spad_relevant = spad[..., min_depth_bin:max_depth_bin]
            spad_single_relevant = np.sum(spad_relevant, axis=(0, 1))
            ambient_estimate = np.mean(
                spad_single_relevant[:ambient_max_depth_bin])
            np.save(os.path.join(scenedir, "spad_single_relevant.npy"),
                    spad_single_relevant)

            # Get ground truth depth
            gt_idx = np.argmax(spad[..., :max_depth_bin], axis=2)
            gt_r = signal.medfilt(np.fliplr(np.flipud(
                (gt_idx * bin_width_m).T)),
                                  kernel_size=5)
            mask = (gt_r >= min_depth).astype('float').squeeze()
            gt_z = r_to_z(gt_r, spad_intrinsics["FocalLength"])
            gt_z = undistort_img(gt_z, **spad_intrinsics)
            mask = np.round(undistort_img(mask, **spad_intrinsics))
            # Nearest neighbor upsampling to reduce holes in output
            scale_factor = 2
            gt_z_up = cv2.resize(gt_z,
                                 dsize=(scale_factor * gt_z.shape[0],
                                        scale_factor * gt_z.shape[1]),
                                 interpolation=cv2.INTER_NEAREST)
            mask_up = cv2.resize(mask,
                                 dsize=(scale_factor * mask.shape[0],
                                        scale_factor * mask.shape[1]),
                                 interpolation=cv2.INTER_NEAREST)

            # Project GT depth and mask to RGB image coordinates and crop it.
            gt_z_proj, mask_proj = project_depth(
                gt_z_up, mask_up, (rgb.shape[0], rgb.shape[1]),
                spad_intrinsics["FocalLength"] * scale_factor,
                kinect_intrinsics["FocalLength"],
                spad_intrinsics["PrincipalPoint"] * scale_factor,
                kinect_intrinsics["PrincipalPoint"], RotationOfKinect,
                TranslationOfKinect / 1e3)
            gt_z_proj_crop = gt_z_proj[crop[0] + offset[0]:crop[1] + offset[0],
                                       crop[2] + offset[1]:crop[3] + offset[1]]
            gt_z_proj_crop = signal.medfilt(gt_z_proj_crop, kernel_size=5)
            mask_proj_crop = (gt_z_proj_crop >=
                              min_depth).astype('float').squeeze()

            # print("gt_z_proj_crop range:")
            # print(np.min(gt_z_proj_crop))
            # print(np.max(gt_z_proj_crop))

            # Process SPAD
            spad_sid, sid_obj_pred, spad_denoised, spad_corrected = \
                preprocess_spad(spad_single_relevant, ambient_estimate, min_depth, max_depth, sid_obj_init)
            np.save(os.path.join(scenedir, "spad_denoised.npy"), spad_denoised)
            np.save(os.path.join(scenedir, "spad_corrected.npy"),
                    spad_corrected)
            np.save(os.path.join(scenedir, "spad_sid.npy"), spad_sid)

            # Initialize with CNN
            z_init = run_model(model,
                               rgb_cropped,
                               depth_range=(min_depth, max_depth),
                               device=device)
            print("min(z_init):", np.min(z_init))
            print("max(z_init):", np.max(z_init))
            r_init = z_to_r(z_init, kinect_intrinsics["FocalLength"])

            print("min(r_init):", np.min(r_init))
            print("max(r_init):", np.max(r_init))

            # Histogram Match
            weights = intensity
            # weights = np.ones_like(r_init)
            # r_pred, t = image_histogram_match(r_init, spad_sid, weights, sid_obj)
            r_pred, t = image_histogram_match_variable_bin(
                r_init, spad_sid, weights, sid_obj_init, sid_obj_pred,
                vectorized)
            z_pred = r_to_z(r_pred, kinect_intrinsics["FocalLength"])

            # print("z_pred range")
            # print(np.min(z_pred))
            # print(np.max(z_pred))

            # Save histograms for later inspection
            intermediates = {
                "init_index": t[0],
                "init_hist": t[1],
                "pred_index": t[2],
                "pred_hist": t[3],
                "T_count": t[4]
            }
            np.save(os.path.join(scenedir, "intermediates.npy"), intermediates)
            # Save processed SPAD data
            bin_edges = np.linspace(min_depth, max_depth,
                                    len(spad_single_relevant) + 1)
            bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
            spad_metadata = {
                "ambient_estimate": ambient_estimate,
                "init_bin_edges": bin_edges,
                "init_bin_values": bin_values,
                "init_sid_bin_edges": sid_obj_init.sid_bin_edges,
                "init_sid_bin_values": sid_obj_init.sid_bin_values,
                "pred_sid_bin_edges": sid_obj_pred.sid_bin_edges,
                "pred_sid_bin_values": sid_obj_pred.sid_bin_values
            }
            np.save(os.path.join(scenedir, "spad_metadata.npy"), spad_metadata)

            # Mean Match
            med_bin = get_hist_med(spad_sid)
            hist_med = sid_obj_init.sid_bin_values[med_bin.astype('int')]
            r_med_scaled = np.clip(r_init * hist_med / np.median(r_init),
                                   a_min=min_depth,
                                   a_max=max_depth)
            z_med_scaled = r_to_z(r_med_scaled,
                                  kinect_intrinsics["FocalLength"])

            min_max = {}
            for k, img in zip(["gt_r", "r_init", "r_pred", "r_med_scaled"],
                              [gt_r, r_init, r_pred, r_med_scaled]):
                min_max[k] = (np.min(img), np.max(img))
            for k, img in zip([
                    "gt_z", "z_init", "z_pred", "z_med_scaled", "gt_z_proj",
                    "gt_z_proj_crop"
            ], [gt_z, z_init, z_pred, z_med_scaled, gt_z_proj, gt_z_proj_crop
                ]):
                min_max[k] = (np.min(img), np.max(img))
            np.save(os.path.join(scenedir, "mins_and_maxes.npy"), min_max)

            # Save to figures
            print("Saving figures...")
            # spad_single_relevant w/ ambient estimate
            plt.figure()
            plt.bar(range(len(spad_single_relevant)),
                    spad_single_relevant,
                    log=True)
            plt.title("spad_single_relevant".format(scene))
            plt.axhline(y=ambient_estimate, color='r', linewidth=0.5)
            plt.tight_layout()
            plt.savefig(os.path.join(scenedir, "spad_single_relevant.pdf"))
            # gt_r and gt_z and gt_z_proj and gt_z_proj_crop and masks
            depth_imwrite(gt_r, os.path.join(scenedir, "gt_r"))
            depth_imwrite(gt_z, os.path.join(scenedir, "gt_z"))
            depth_imwrite(gt_z_proj, os.path.join(scenedir, "gt_z_proj"))
            depth_imwrite(gt_z_proj_crop,
                          os.path.join(scenedir, "gt_z_proj_crop"))
            depth_imwrite(mask, os.path.join(scenedir, "mask"))
            depth_imwrite(mask_proj, os.path.join(scenedir, "mask_proj"))
            depth_imwrite(mask_proj_crop,
                          os.path.join(scenedir, "mask_proj_crop"))
            depth_imwrite(intensity, os.path.join(scenedir, "intensity"))
            np.save(os.path.join(scenedir, "crop.npy"), crop)
            # spad_sid after preprocessing
            plt.figure()
            plt.bar(range(len(spad_sid)), spad_sid, log=True)
            plt.title("spad_sid")
            plt.tight_layout()
            plt.savefig(os.path.join(scenedir, "spad_sid.pdf"))
            # rgb, rgb_cropped, intensity
            cv2.imwrite(os.path.join(scenedir, "rgb.png"),
                        cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(scenedir, "rgb_cropped.png"),
                        cv2.cvtColor(rgb_cropped, cv2.COLOR_RGB2BGR))
            # r_init, z_init, diff_maps
            depth_imwrite(r_init, os.path.join(scenedir, "r_init"))
            depth_imwrite(z_init, os.path.join(scenedir, "z_init"))
            # r_pred, z_pred, diff_maps
            depth_imwrite(r_pred, os.path.join(scenedir, "r_pred"))
            depth_imwrite(z_pred, os.path.join(scenedir, "z_pred"))
            # r_med_scaled, z_med_scaled, diff_maps
            depth_imwrite(r_med_scaled, os.path.join(scenedir, "r_med_scaled"))
            depth_imwrite(z_med_scaled, os.path.join(scenedir, "z_med_scaled"))
            plt.close('all')

            # Compute metrics
            print("Computing error metrics...")
            # z_init
            # z_init_resized = cv2.resize(z_init, gt_z.shape)
            init_metrics = get_depth_metrics(
                torch.from_numpy(z_init).float(),
                torch.from_numpy(gt_z_proj_crop).float(),
                torch.from_numpy(mask_proj_crop).float())
            np.save(os.path.join(scenedir, "init_metrics.npy"), init_metrics)
            # z_pred
            # z_pred_resized = cv2.resize(z_pred, gt_z.shape)
            pred_metrics = get_depth_metrics(
                torch.from_numpy(z_pred).float(),
                torch.from_numpy(gt_z_proj_crop).float(),
                torch.from_numpy(mask_proj_crop).float())
            np.save(os.path.join(scenedir, "pred_metrics.npy"), pred_metrics)

            # z_med_scaled
            # z_med_scaled_resized = cv2.resize(z_med_scaled, gt_z.shape)
            med_scaled_metrics = get_depth_metrics(
                torch.from_numpy(z_med_scaled).float(),
                torch.from_numpy(gt_z_proj_crop).float(),
                torch.from_numpy(mask_proj_crop).float())
            np.save(os.path.join(scenedir, "med_scaled_metrics.npy"),
                    med_scaled_metrics)
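
get_hist_med above selects the SID bin containing the median of the processed SPAD histogram, which then drives the median-matching baseline; its implementation is not shown in this listing. A plausible sketch under that assumption (hypothetical, not the repository's actual function):

# Hypothetical median-bin helper: index where the cumulative mass first reaches half the total.
import numpy as np

def get_hist_med_sketch(hist):
    cdf = np.cumsum(hist)
    return np.searchsorted(cdf, cdf[-1] / 2.)

print(get_hist_med_sketch(np.array([0., 1., 2., 5., 1., 1.])))  # 3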
def run(dataset_type, spad_file, dorn_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, lam, eps_rel, entry, save_outputs, small_run,
        output_dir):
    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(dorn_depth_file))
    depth_data = np.load(dorn_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
    ambient = spad_config["dc_count"] / spad_config["spad_bins"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        metrics = np.zeros(
            (len(dataset) if not small_run else small_run, len(metric_list)))
        entry_list = []
        outputs = []
        for i in range(depth_data.shape[0]):
            if small_run and i == small_run:
                break
            entry_list.append(i)

            print("Evaluating {}[{}]".format(dataset_type, i))
            spad = spad_data[i, ...]
            weights = np.ones_like(depth_data[i, 0, ...])
            if use_intensity:
                weights = intensity_data[i, 0, ...]
            if dc_count > 0.:
                spad = remove_dc_from_spad_edge(spad,
                                                ambient=ambient,
                                                grad_th=3 * ambient)
            if use_squared_falloff:
                bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
                bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
                spad = spad * bin_values**2
            spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            pred, _ = image_histogram_match(depth_data[i, 0, ...],
                                            spad_rescaled, weights, sid_obj)
            # break
            # Calculate metrics
            gt = dataset[i]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "dorn_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))

        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "dorn_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "dorn_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad_rescaled = rescale_bins(spad_data[i, ...], min_depth, max_depth,
                                     sid_obj)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        spad_rescaled = preprocess_spad(spad_rescaled,
                                        sid_obj,
                                        use_squared_falloff,
                                        dc_count > 0.,
                                        lam=lam,
                                        eps_rel=eps_rel)

        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)

        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir,
                    "dorn_{}[{}]_{}_out.npy".format(dataset_type, entry,
                                                    hyper_string)), pred)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rms"], pred_metrics["log10"]))
        obj = cp.Minimize(
            cp.sum_squares(spad_equalized[i, :] - (x + z)) +
            lam * cp.norm(x, 1))
        constr = [x >= 0, z >= 0]
        prob = cp.Problem(obj, constr)
        prob.solve(solver=cp.OSQP, eps_abs=eps)
        #         signal_hist = x.value
        denoised_spad[i, :] = x.value * bin_widths
    return denoised_spad
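
The fragment above is the tail of a convex DC-removal routine: each histogram is split into a sparse signal component x and a non-negative background z via a LASSO-style objective. The following self-contained sketch shows that kind of optimization on a toy histogram; it assumes the background is modeled as a single scalar DC level, which may differ from the repository's actual formulation.

# Sparse-signal / DC split on a toy histogram (illustrative, not the original routine).
import cvxpy as cp
import numpy as np

n_bins = 64
hist = np.full(n_bins, 2.0)                  # flat ambient (DC) floor
hist[20:23] += np.array([5., 20., 5.])       # narrow laser return around bin 21

x = cp.Variable(n_bins)                      # sparse signal component
z = cp.Variable()                            # scalar DC level (assumption)
lam = 1.0
obj = cp.Minimize(cp.sum_squares(hist - (x + z)) + lam * cp.norm(x, 1))
prob = cp.Problem(obj, [x >= 0, z >= 0])
prob.solve(solver=cp.OSQP)
print(np.round(x.value, 2))                  # mass concentrates in bins 20-22; z absorbs the floor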


if __name__ == "__main__":
    min_depth = 0.
    max_depth = 10.

    sid_bins = 68  # Number of bins (network outputs 2x this number of channels)
    bin_edges = np.array(range(sid_bins + 1)).astype(np.float32)
    dorn_decode = np.exp((bin_edges - 1) / 25 - 0.36)
    d0 = dorn_decode[0]
    d1 = dorn_decode[1]
    # Algebra stuff to make the depth bins work out exactly like in the
    # original DORN code.
    alpha = (2 * d0**2) / (d1 + d0)
    beta = alpha * np.exp(sid_bins * np.log(2 * d0 / alpha - 1))
    del bin_edges, dorn_decode, d0, d1
    offset = 0.

    from models.data.data_utils.sid_utils import SID
    sid_obj = SID(sid_bins, alpha, beta, offset)
    layer = get_rescale_layer(1024, min_depth, max_depth, sid_obj)
    print(layer.weight.data[:, 0])
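
The alpha/beta algebra above makes the SID discretization line up with DORN's log-space decoding. A hedged numerical check follows, assuming the common SID edge formula t_i = exp(log(alpha) + i * log(beta/alpha) / K); the repository's SID class is not shown here, so that formula is an assumption. Under it, the arithmetic midpoints of consecutive edges reproduce the DORN decode values.

# Check that midpoints of the assumed SID edges match DORN's decode table.
import numpy as np

sid_bins = 68
idx = np.arange(sid_bins + 1, dtype=np.float64)
dorn_decode = np.exp((idx - 1) / 25 - 0.36)
d0, d1 = dorn_decode[0], dorn_decode[1]
alpha = (2 * d0 ** 2) / (d1 + d0)
beta = alpha * np.exp(sid_bins * np.log(2 * d0 / alpha - 1))

t = np.exp(np.log(alpha) + idx * np.log(beta / alpha) / sid_bins)  # assumed SID edges
print(np.allclose((t[:-1] + t[1:]) / 2, dorn_decode[:-1]))         # True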
Example #10
def load_data(train_file, train_dir,
              val_file, val_dir,
              test_file, test_dir,
              min_depth, max_depth, normalization,
              sid_bins, alpha, beta, offset,
              blacklist_file,
              spad_config):
    """Generates training and validation datasets from
    text files and directories. Sets up datasets with transforms.py.
    *_file - string - a text file containing info for DepthDataset to load the images
    *_dir - string - the folder containing the images to load
    min_depth - the minimum depth for this dataset
    max_depth - the maximum depth for this dataset
    normalization - The type of normalization to use.
    sid_bins - the number of Spacing Increasing Discretization bins to add.
    blacklist_file - string - a text file listing, on each line, an image_id of an image to exclude
                              from the dataset.

    test_loader - bool - whether or not to test the loader and not set the dataset-wide mean and
                         variance.

    Returns
    -------
    train, val, test - torch.data_utils.data.Dataset objects containing the relevant splits
    """
    # print(spad_config)
    train = NYUDepthv2Dataset(train_file, train_dir, transform=None,
                              file_types=["rgb", "rawdepth"],
                              min_depth=min_depth, max_depth=max_depth,
                              blacklist_file=blacklist_file)

    train.rgb_mean, train.rgb_var = train.get_mean_and_var()

    # Transform:
    # Size is set to (353, 257) to conform to DORN conventions
    # If normalization == "dorn":
    # Mean is set to np.array([[[103.0626, 115.9029, 123.1516]]]).astype(np.float32) to conform to DORN conventions
    # Var is set to np.ones((1,1,3)) to conform to DORN conventions
    if normalization == "dorn":
        # Use normalization as in the github code for DORN.
        print("Using DORN normalization.")
        transform_mean = np.array([[[103.0626, 115.9029, 123.1516]]]).astype(np.float32)
        transform_var = np.ones((1, 1, 3))
    elif normalization == "none":
        print("No normalization.")
        transform_mean = np.zeros((1, 1, 3))
        transform_var = np.ones((1, 1, 3))
    else:
        transform_mean = train.rgb_mean
        transform_var = train.rgb_var

    train_transform = transforms.Compose([
        AddDepthMask(min_depth, max_depth, "rawdepth"),
        Save(["rgb", "mask", "rawdepth"], "_orig"),
        Normalize(transform_mean, transform_var, key="rgb"),
        ResizeAll((353, 257), keys=["rgb", "rawdepth"]),
        RandomHorizontalFlipAll(flip_prob=0.5, keys=["rgb", "rawdepth"]),
        AddDepthMask(min_depth, max_depth, "rawdepth"), # "mask"
        AddSIDDepth(sid_bins, alpha, beta, offset, "rawdepth"), # "rawdepth_sid"  "rawdepth_sid_index"
        SimulateSpadIntensity("rawdepth", "rgb", "mask", "spad", min_depth, max_depth,
                     spad_config["spad_bins"],
                     spad_config["photon_count"],
                     spad_config["dc_count"],
                     spad_config["fwhm_ps"],
                     spad_config["use_intensity"],
                     spad_config["use_squared_falloff"],
                     sid_obj=SID(sid_bins, alpha, beta, offset)),
        ToTensorAll(keys=["rgb", "rgb_orig", "rawdepth", "rawdepth_orig",
                          "rawdepth_sid", "rawdepth_sid_index", "mask", "mask_orig", "spad"])
        ]
    )

    val_transform = transforms.Compose([
        AddDepthMask(min_depth, max_depth, "rawdepth"),
        Save(["rgb", "depth", "mask", "rawdepth"], "_orig"),
        Normalize(transform_mean, transform_var, key="rgb"),
        ResizeAll((353, 257), keys=["rgb", "depth", "rawdepth"]),
        AddDepthMask(min_depth, max_depth, "rawdepth"),
        AddSIDDepth(sid_bins, alpha, beta, offset, "rawdepth"),
        SimulateSpadIntensity("depth", "rgb", "mask", "spad", min_depth, max_depth,
                     spad_config["spad_bins"],
                     spad_config["photon_count"],
                     spad_config["dc_count"],
                     spad_config["fwhm_ps"],
                     spad_config["use_intensity"],
                     spad_config["use_squared_falloff"],
                     sid_obj=SID(sid_bins, alpha, beta, offset)),
        ToTensorAll(keys=["rgb", "rgb_orig", "depth", "depth_orig", "rawdepth", "rawdepth_orig",
                          "rawdepth_sid", "rawdepth_sid_index", "mask", "mask_orig", "spad"])
        ]
    )

    test_transform = transforms.Compose([
        AddDepthMask(min_depth, max_depth, "rawdepth"),
        Save(["rgb", "depth", "mask", "rawdepth"], "_orig"),
        Normalize(transform_mean, transform_var, key="rgb"),
        ResizeAll((353, 257), keys=["rgb", "depth", "rawdepth"]),
        AddDepthMask(min_depth, max_depth, "rawdepth"),
        AddSIDDepth(sid_bins, alpha, beta, offset, "rawdepth"),
        SimulateSpadIntensity("depth", "rgb", "mask", "spad", min_depth, max_depth,
                     spad_config["spad_bins"],
                     spad_config["photon_count"],
                     spad_config["dc_count"],
                     spad_config["fwhm_ps"],
                     spad_config["use_intensity"],
                     spad_config["use_squared_falloff"],
                     sid_obj=SID(sid_bins, alpha, beta, offset)),
        ToTensorAll(keys=["rgb", "rgb_orig", "depth", "depth_orig", "rawdepth", "rawdepth_orig",
                          "rawdepth_sid", "rawdepth_sid_index", "mask", "mask_orig", "spad"])
        ]
    )
    train.transform = train_transform
    print("Loaded training dataset from {} with size {}.".format(train_file, len(train)))
    val = None
    if val_file is not None:
        val = NYUDepthv2Dataset(val_file, val_dir, transform=val_transform,
                                file_types = ["rgb", "depth", "rawdepth"],
                                min_depth=min_depth, max_depth=max_depth)
        val.rgb_mean, val.rgb_var = train.rgb_mean, train.rgb_var
        print("Loaded val dataset from {} with size {}.".format(val_file, len(val)))
    test = None
    if test_file is not None:
        test = NYUDepthv2Dataset(test_file, test_dir, transform=test_transform,
                                 file_types = ["rgb", "depth", "rawdepth"],
                                 min_depth=min_depth, max_depth=max_depth)
        test.rgb_mean, test.rgb_var = train.rgb_mean, train.rgb_var
        print("Loaded test dataset from {} with size {}.".format(test_file, len(test)))

    return train, val, test
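
A hedged sketch of how this loader might be invoked; the file paths, SID parameters, and spad_config values below are placeholders (only keys actually consumed in this example are filled in), not values taken from the original configuration.

# Hypothetical call to load_data (placeholder paths and values).
spad_config = {"spad_bins": 1024, "photon_count": 1e6, "dc_count": 1e5, "fwhm_ps": 70.,
               "use_intensity": True, "use_squared_falloff": True}
train, val, test = load_data(
    train_file="data/train.txt", train_dir="data/train",
    val_file="data/val.txt", val_dir="data/val",
    test_file="data/test.txt", test_dir="data/test",
    min_depth=0., max_depth=10., normalization="dorn",
    sid_bins=68, alpha=0.66, beta=10., offset=0.,
    blacklist_file="data/blacklist.txt", spad_config=spad_config)
print(len(train), len(val), len(test))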
 # print(test.transform[6])
 depth_img_np = depth_img.numpy()
 simulate_spad_transform = SimulateSpad("rawdepth",
                                        "albedo",
                                        "mask",
                                        "spad",
                                        data_config["min_depth"],
                                        data_config["max_depth"],
                                        spad_config["spad_bins"],
                                        spad_config["photon_count"],
                                        spad_config["dc_count"],
                                        spad_config["fwhm_ps"],
                                        spad_config["use_albedo"],
                                        spad_config["use_squared_falloff"],
                                        sid_obj=SID(data_config["sid_bins"],
                                                    data_config["alpha"],
                                                    data_config["beta"],
                                                    data_config["offset"]))
 print(depth_pred.shape)
 print(input_["albedo"].shape)
 print(input_["mask"].shape)
 sample = {
     "rawdepth": depth_pred.numpy().squeeze(0).transpose(1, 2, 0),
     "albedo": input_["albedo"].numpy().squeeze(0).transpose(1, 2, 0),
     "mask": np.ones_like(depth_pred.numpy()).squeeze(0).transpose(1, 2, 0)
 }
 depth_img_spad = simulate_spad_transform(sample)
 print(depth_img_spad["spad"])
 plt.figure()
 plt.plot(depth_img_spad["spad"])
 plt.title("Predicted depth histogram")
 plt.draw()