def run(dataset_type, entry, save_outputs, small_run, input_file, output_dir):
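    """Evaluate a median-rescaled depth baseline.

    Loads precomputed depth outputs from ``input_file``, rescales each prediction
    so its median matches the ground-truth median, computes standard depth
    metrics, and writes per-entry and weighted-average results to ``output_dir``.
    Only the full-dataset path (``entry is None``) is implemented.
    """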
    dataset = load_data(channels_first=False, dataset_type=dataset_type)
    cnn_data = np.load(input_file)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        metrics = np.zeros(
            (len(dataset) if not small_run else small_run, len(metric_list)))
        entry_list = []
        outputs = []
        for i in range(len(dataset)):
            if small_run and i == small_run:
                break
            print("Running {}[{}]".format(dataset_type, i))
            entry_list.append(i)
            init = cnn_data[i, ...]
            gt = dataset[i]["depth_cropped"]
            pred = init * (torch.median(gt).item() / np.median(init))
            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).float(), gt, torch.ones_like(gt))
            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = torch.numel(gt)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)

        if save_outputs:
            np.save(
                os.path.join(
                    output_dir,
                    "dorn_median_{}_outputs.npy".format(dataset_type)),
                np.concatenate(outputs, axis=0))

        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "dorn_median_{}_metrics.pkl".format(dataset_type)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "dorn_median_{}_avg_metrics.csv".format(dataset_type)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))
    else:
        raise NotImplementedError
def run(dataset_type,
        output_dir,
        use_intensity,
        use_squared_falloff,
        dc_count,
        _config):  # The entire config dict for this experiment
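    """Simulate a SPAD histogram for every entry of the dataset.

    Converts each RGB crop to grayscale intensity, runs ``simulate_spad`` on the
    ground-truth depth, and saves all histograms, intensities, and the experiment
    config to a single .npy file in ``output_dir``.
    """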
    print("dataset_type: {}".format(dataset_type))
    dataset = load_data(dataset_type)
    all_spad_counts = []
    all_intensities = []
    for i in range(len(dataset)):
        print("Simulating SPAD for entry {}".format(i))
        data = default_collate([dataset[i]])
        intensity = rgb2gray(data["rgb_cropped"].numpy()/255.)
        spad_counts = simulate_spad(depth_truth=data["depth_cropped"].numpy(),
                                    intensity=intensity,
                                    mask=np.ones_like(intensity))
        all_spad_counts.append(spad_counts)
        all_intensities.append(intensity)

    output = {
        "config": _config,
        "spad": np.array(all_spad_counts),
        "intensity": np.concatenate(all_intensities),
    }

    print("saving {}_int_{}_fall_{}_dc_{}_spad.npy to {}".format(dataset_type,
                                                                 use_intensity, use_squared_falloff, dc_count,
                                                                 output_dir))
    np.save(os.path.join(output_dir, "{}_int_{}_fall_{}_dc_{}_spad.npy".format(dataset_type,
                                                                               use_intensity,
                                                                               use_squared_falloff,
                                                                               dc_count)), output)
def run(dataset_type, output_dir, use_intensity, use_squared_falloff, dc_count,
        use_jitter, lambertian, use_poisson,
        _config):  # The entire config dict for this experiment
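    """Simulate SPAD histograms with the passthrough simulator.

    Same pipeline as the variant above, but uses ``simulate_spad_passthrough``
    and encodes the jitter, Lambertian, and Poisson options in the output
    filename.
    """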
    print("dataset_type: {}".format(dataset_type))
    dataset = load_data(channels_first=True, dataset_type=dataset_type)
    all_spad_counts = []
    all_intensities = []
    for i in range(len(dataset)):
        print("Simulating SPAD for entry {}".format(i))
        data = default_collate([dataset[i]])

        intensity = rgb2gray(data["rgb_cropped"].numpy() / 255.)
        # print(intensity.shape)
        depth_truth = data["depth_cropped"].numpy()
        spad_counts = simulate_spad_passthrough(depth_truth=depth_truth,
                                                intensity=intensity,
                                                mask=np.ones_like(depth_truth))
        all_spad_counts.append(spad_counts)
        all_intensities.append(intensity)

    output = {
        "config": _config,
        "spad": np.array(all_spad_counts),
        "intensity": np.concatenate(all_intensities),
    }

    print(
        "saving {}_int_{}_fall_{}_lamb_{}_dc_{}_jit_{}_poiss_{}_spad.npy to {}"
        .format(dataset_type, use_intensity, use_squared_falloff, lambertian,
                dc_count, use_jitter, use_poisson, output_dir))
    np.save(
        os.path.join(
            output_dir,
            "{}_int_{}_fall_{}_lamb_{}_dc_{}_jit_{}_poiss_{}_spad.npy".format(
                dataset_type, use_intensity, use_squared_falloff, lambertian,
                dc_count, use_jitter, use_poisson)), output)
def run(dataset_type, spad_file, densedepth_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, intensity_ablation, vectorized, entry,
        save_outputs, small_run, output_dir):
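    """Refine DenseDepth predictions with simulated SPAD histograms.

    For each entry: removes the dark-count background, compensates for distance
    falloff, fits a per-image SID discretization to the occupied bins, and
    histogram-matches the depth map to its SPAD histogram before computing depth
    metrics. The single-entry path (``entry`` given) evaluates one image verbosely.
    """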
    print("output dir: {}".format(output_dir))
    safe_makedir(output_dir)

    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(densedepth_depth_file))
    depth_data = np.load(densedepth_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
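    # Expected ambient (dark-count) level per SPAD histogram bin.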
    ambient = spad_config["dc_count"] / spad_config["spad_bins"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    lambertian = spad_config["lambertian"]
    use_poisson = spad_config["use_poisson"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("ambient: ", ambient)
    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)
    print("lambertian:", lambertian)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj_init = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        metrics = np.zeros(
            (len(dataset) if not small_run else small_run, len(metric_list)))
        entry_list = []
        outputs = []
        times = []
        for i in range(depth_data.shape[0]):
            if small_run and i == small_run:
                break
            entry_list.append(i)
            print("Evaluating {}[{}]".format(dataset_type, i))
            spad = spad_data[i, ...]
            bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
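            # Bin centers of the uniformly spaced SPAD histogram.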
            bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
            # spad = preprocess_spad_ambient_estimate(spad, min_depth, max_depth,
            #                                             correct_falloff=use_squared_falloff,
            #                                             remove_dc= dc_count > 0.,
            #                                             global_min_depth=np.min(depth_data),
            #                                             n_std=1. if use_poisson else 0.01)
            # Rescale SPAD_data
            weights = np.ones_like(depth_data[i, 0, ...])
            # Ablation study: Turn off intensity, even if spad has been simulated with it.
            if use_intensity and not intensity_ablation:
                weights = intensity_data[i, 0, ...]

            if dc_count > 0.:
                spad = remove_dc_from_spad_edge(
                    spad,
                    ambient=ambient,
                    # grad_th=2*ambient)
                    grad_th=5 * np.sqrt(2 * ambient))
            # print(2*ambient)
            # print(5*np.sqrt(2*ambient))

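            # Compensate for the distance falloff applied during simulation:
            # quartic under the Lambertian option, quadratic otherwise.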
            if use_squared_falloff:
                if lambertian:
                    spad = spad * bin_values**4
                else:
                    spad = spad * bin_values**2
            # Scale SID object to maximize bin utilization
            nonzeros = np.nonzero(spad)[0]
            if nonzeros.size > 0:
                min_depth_bin = np.min(nonzeros)
                max_depth_bin = np.max(nonzeros) + 1
                if max_depth_bin > len(bin_edges) - 2:
                    max_depth_bin = len(bin_edges) - 2
            else:
                min_depth_bin = 0
                max_depth_bin = len(bin_edges) - 2
            min_depth_pred = np.clip(bin_edges[min_depth_bin],
                                     a_min=1e-2,
                                     a_max=None)
            max_depth_pred = np.clip(bin_edges[max_depth_bin + 1],
                                     a_min=1e-2,
                                     a_max=None)
            # print(min_depth_pred)
            # print(max_depth_pred)
            sid_obj_pred = SID(sid_bins=sid_obj_init.sid_bins,
                               alpha=min_depth_pred,
                               beta=max_depth_pred,
                               offset=0.)
            spad_rescaled = rescale_bins(spad[min_depth_bin:max_depth_bin + 1],
                                         min_depth_pred, max_depth_pred,
                                         sid_obj_pred)
            start = process_time()
            pred, t = image_histogram_match_variable_bin(
                depth_data[i, 0, ...], spad_rescaled, weights, sid_obj_init,
                sid_obj_pred, vectorized)
            times.append(process_time() - start)
            # break
            # Calculate metrics
            gt = dataset[i]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)
            # print(pred[20:30, 20:30])

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)
            print("\tAvg RMSE = {}".format(
                np.mean(metrics[:i + 1, metric_list.index("rmse")])))

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "densedepth_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))
        print("Avg Time: {}".format(np.mean(times)))
        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "densedepth_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "densedepth_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad = spad_data[i, ...]
        spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj_init)
        print("spad_rescaled", spad_rescaled)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
        # spad_rescaled = preprocess_spad_sid(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.
        #                                     )

        if dc_count > 0.:
            spad_rescaled = remove_dc_from_spad(
                spad_rescaled,
                sid_obj_init.sid_bin_edges,
                sid_obj_init.sid_bin_values[:-2]**2,
                lam=1e1 if use_poisson else 1e-1,
                eps_rel=1e-5)
        if use_squared_falloff:
            spad_rescaled = spad_rescaled * sid_obj_init.sid_bin_values[:-2]**2
        # print(spad_rescaled)
        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj_init)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)
        print(gt[:, :, 40, 60])
        print(depth_data[i, 0, 40, 60])
        print("before rmse: ",
              np.sqrt(np.mean((gt.numpy() - depth_data[i, 0, ...])**2)))

        before_metrics = get_depth_metrics(
            torch.from_numpy(
                depth_data[i, 0, ...]).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir, "densedepth_{}[{}]_{}_out.npy".format(
                        dataset_type, entry, hyper_string)), pred)

        print("before:")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(before_metrics["delta1"], before_metrics["delta2"],
                     before_metrics["delta3"], before_metrics["rel_abs_diff"],
                     before_metrics["rmse"], before_metrics["log10"]))
        print("after:")

        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))
def main(dataset_type, entry, save_outputs, output_dir, seed, small_run,
         device):
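    """Evaluate a DORN model on the whole dataset or on a single entry.

    Runs the model over a batch-size-1 DataLoader, accumulates per-entry depth
    metrics, optionally saves the raw outputs, and writes per-entry and
    weighted-average metrics to ``output_dir``.
    """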

    # Load the data
    dataset = load_data(dataset_type=dataset_type)

    # Load the model
    model = DORN()
    model.eval()
    model.to(device)

    init_randomness(seed)

    if entry is None:
        dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,  # needs to be 0 to not crash autograd profiler.
            pin_memory=True)
        # if eval_config["save_outputs"]:

        with torch.no_grad():
            metric_list = [
                "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
                "log10", "weight"
            ]
            metrics = np.zeros((len(dataset) if not small_run else small_run,
                                len(metric_list)))
            entry_list = []
            outputs = []
            for i, data in enumerate(dataloader):
                # TESTING
                if small_run and i == small_run:
                    break
                entry = data["entry"][0]
                entry = entry if isinstance(entry, str) else entry.item()
                entry_list.append(entry)
                print("Evaluating {}".format(data["entry"][0]))
                # pred, pred_metrics = model.evaluate(data, device)
                pred, pred_metrics, pred_weight = model.evaluate(
                    data["bgr"].to(device), data["bgr_orig"].to(device),
                    data["depth_cropped"].to(device),
                    torch.ones_like(data["depth_cropped"]).to(device))
                for j, metric_name in enumerate(metric_list[:-1]):
                    metrics[i, j] = pred_metrics[metric_name]

                metrics[i, -1] = pred_weight
                # Option to save outputs:
                if save_outputs:
                    outputs.append(pred.cpu().numpy())

            if save_outputs:
                np.save(
                    os.path.join(output_dir,
                                 "dorn_{}_outputs.npy".format(dataset_type)),
                    np.concatenate(outputs, axis=0))

            # Save metrics using pandas
            metrics_df = pd.DataFrame(data=metrics,
                                      index=entry_list,
                                      columns=metric_list)
            metrics_df.to_pickle(path=os.path.join(
                output_dir, "dorn_{}_metrics.pkl".format(dataset_type)))
            # Compute weighted averages:
            average_metrics = np.average(metrics_df.iloc[:, :-1],
                                         weights=metrics_df.weight,
                                         axis=0)
            average_df = pd.Series(data=average_metrics,
                                   index=metric_list[:-1])
            average_df.to_csv(os.path.join(
                output_dir, "dorn_{}_avg_metrics.csv".format(dataset_type)),
                              header=True)
            print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
                'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
            print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
                  format(average_metrics[0], average_metrics[1],
                         average_metrics[2], average_metrics[3],
                         average_metrics[4], average_metrics[6]))
        print("wrote results to {}".format(output_dir))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate
        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        entry = entry if isinstance(entry, str) else entry.item()
        print("Entry: {}".format(entry))
        # print("remove_dc: ", model.remove_dc)
        # print("use_intensity: ", model.use_intensity)
        # print("use_squared_falloff: ", model.use_squared_falloff)
        pred, pred_metrics, pred_weight = model.evaluate(
            data["bgr"].to(device), data["bgr_orig"].to(device),
            data["depth_cropped"].to(device),
            torch.ones_like(data["depth_cropped"]).to(device))
        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "{}_{}_out.npy".format(dataset_type, entry)),
                pred.cpu().numpy())
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))
def run(dataset_type, spad_file, densedepth_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, lam, eps_rel, n_std, entry, save_outputs,
        small_run, subsampling, output_dir):
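    """Refine DenseDepth predictions with SPAD histograms on a fixed SID grid.

    Rescales each SPAD histogram onto the SID bins, removes the dark-count
    background, compensates for squared falloff, and histogram-matches the depth
    map, optionally subsampling the dataset.
    """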
    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(densedepth_depth_file))
    depth_data = np.load(densedepth_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    use_poisson = spad_config["use_poisson"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        print(len(dataset) // subsampling)
        metrics = np.zeros(
            (len(dataset) // subsampling + 1 if not small_run else small_run,
             len(metric_list)))
        entry_list = []
        outputs = []
        for i in range(depth_data.shape[0]):
            idx = i * subsampling

            if idx >= depth_data.shape[0] or (small_run and i >= small_run):
                break
            entry_list.append(idx)

            print("Evaluating {}[{}]".format(dataset_type, idx))
            spad = spad_data[idx, ...]
            # spad = preprocess_spad_ambient_estimate(spad, min_depth, max_depth,
            #                                             correct_falloff=use_squared_falloff,
            #                                             remove_dc= dc_count > 0.,
            #                                             global_min_depth=np.min(depth_data),
            #                                             n_std=1. if use_poisson else 0.01)
            # Rescale SPAD_data
            spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            weights = np.ones_like(depth_data[idx, 0, ...])
            if use_intensity:
                weights = intensity_data[idx, 0, ...]
            # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
            if dc_count > 0.:
                spad_rescaled = remove_dc_from_spad(
                    spad_rescaled,
                    sid_obj.sid_bin_edges,
                    sid_obj.sid_bin_values[:-2]**2,
                    # NOTE: both branches are 1e-1 here; the single-entry path
                    # below uses 1e1 when Poisson noise is enabled.
                    lam=1e-1 if spad_config["use_poisson"] else 1e-1,
                    eps_rel=1e-5)
                # spad_rescaled = remove_dc_from_spad_poisson(spad_rescaled,
                #                                        sid_obj.sid_bin_edges,
                #                                        lam=lam)
                # spad = remove_dc_from_spad_ambient_estimate(spad,
                #                                             min_depth, max_depth,
                #                                             global_min_depth=np.min(depth_data),
                #                                             n_std=n_std)
                # print(spad[:10])
                # print(spad)

            if use_squared_falloff:
                spad_rescaled = spad_rescaled * sid_obj.sid_bin_values[:-2]**2
                # bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
                # bin_values = (bin_edges[1:] + bin_edges[:-1])/2
                # spad = spad * bin_values ** 2
            # spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            pred, _ = image_histogram_match(depth_data[idx, 0, ...],
                                            spad_rescaled, weights, sid_obj)
            # break
            # Calculate metrics
            gt = dataset[idx]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)
            # print(pred[20:30, 20:30])

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)
            print("\tAvg RMSE = {}".format(
                np.mean(metrics[:i + 1, metric_list.index("rmse")])))

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "densedepth_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))

        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "densedepth_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "densedepth_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad = spad_data[i, ...]
        spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
        print("spad_rescaled", spad_rescaled)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        # spad_rescaled = preprocess_spad_sid_gmm(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.)
        # spad_rescaled = preprocess_spad_sid(spad_rescaled, sid_obj, use_squared_falloff, dc_count > 0.
        #                                     )

        if dc_count > 0.:
            spad_rescaled = remove_dc_from_spad(
                spad_rescaled,
                sid_obj.sid_bin_edges,
                sid_obj.sid_bin_values[:-2]**2,
                lam=1e1 if use_poisson else 1e-1,
                eps_rel=1e-5)
        if use_squared_falloff:
            spad_rescaled = spad_rescaled * sid_obj.sid_bin_values[:-2]**2
        # print(spad_rescaled)
        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)
        print(gt[:, :, 40, 60])
        print(depth_data[i, 0, 40, 60])
        print("before rmse: ",
              np.sqrt(np.mean((gt.numpy() - depth_data[i, 0, ...])**2)))

        before_metrics = get_depth_metrics(
            torch.from_numpy(
                depth_data[i, 0, ...]).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir, "densedepth_{}[{}]_{}_out.npy".format(
                        dataset_type, entry, hyper_string)), pred)

        print("before:")
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(before_metrics["delta1"], before_metrics["delta2"],
                     before_metrics["delta3"], before_metrics["rel_abs_diff"],
                     before_metrics["rmse"], before_metrics["log10"]))
        print("after:")

        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rmse', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))
def run(dataset_type, spad_file, dorn_depth_file, hyper_string, sid_bins,
        alpha, beta, offset, lam, eps_rel, entry, save_outputs, small_run,
        output_dir):
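    """Refine DORN predictions with simulated SPAD histograms.

    Removes the dark-count background with an edge-based estimate, compensates
    for squared falloff, rescales the histogram onto the SID bins, and
    histogram-matches each DORN depth map before computing depth metrics.
    """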
    # Load all the data:
    print("Loading SPAD data from {}".format(spad_file))
    spad_dict = np.load(spad_file, allow_pickle=True).item()
    spad_data = spad_dict["spad"]
    intensity_data = spad_dict["intensity"]
    spad_config = spad_dict["config"]
    print("Loading depth data from {}".format(dorn_depth_file))
    depth_data = np.load(dorn_depth_file)
    dataset = load_data(channels_first=True, dataset_type=dataset_type)

    # Read SPAD config and determine proper course of action
    dc_count = spad_config["dc_count"]
    ambient = spad_config["dc_count"] / spad_config["spad_bins"]
    use_intensity = spad_config["use_intensity"]
    use_squared_falloff = spad_config["use_squared_falloff"]
    min_depth = spad_config["min_depth"]
    max_depth = spad_config["max_depth"]

    print("dc_count: ", dc_count)
    print("use_intensity: ", use_intensity)
    print("use_squared_falloff:", use_squared_falloff)

    print("spad_data.shape", spad_data.shape)
    print("depth_data.shape", depth_data.shape)
    print("intensity_data.shape", intensity_data.shape)

    sid_obj = SID(sid_bins, alpha, beta, offset)

    if entry is None:
        metric_list = [
            "delta1", "delta2", "delta3", "rel_abs_diff", "rmse", "mse",
            "log10", "weight"
        ]
        metrics = np.zeros(
            (len(dataset) if not small_run else small_run, len(metric_list)))
        entry_list = []
        outputs = []
        for i in range(depth_data.shape[0]):
            if small_run and i == small_run:
                break
            entry_list.append(i)

            print("Evaluating {}[{}]".format(dataset_type, i))
            spad = spad_data[i, ...]
            weights = np.ones_like(depth_data[i, 0, ...])
            if use_intensity:
                weights = intensity_data[i, 0, ...]
            if dc_count > 0.:
                spad = remove_dc_from_spad_edge(spad,
                                                ambient=ambient,
                                                grad_th=3 * ambient)
            if use_squared_falloff:
                bin_edges = np.linspace(min_depth, max_depth, len(spad) + 1)
                bin_values = (bin_edges[1:] + bin_edges[:-1]) / 2
                spad = spad * bin_values**2
            spad_rescaled = rescale_bins(spad, min_depth, max_depth, sid_obj)
            pred, _ = image_histogram_match(depth_data[i, 0, ...],
                                            spad_rescaled, weights, sid_obj)
            # break
            # Calculate metrics
            gt = dataset[i]["depth_cropped"].unsqueeze(0)
            # print(gt.dtype)
            # print(pred.shape)

            pred_metrics = get_depth_metrics(
                torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
                torch.ones_like(gt))

            for j, metric_name in enumerate(metric_list[:-1]):
                metrics[i, j] = pred_metrics[metric_name]

            metrics[i, -1] = np.size(pred)
            # Option to save outputs:
            if save_outputs:
                outputs.append(pred)

        if save_outputs:
            np.save(
                os.path.join(output_dir,
                             "dorn_{}_outputs.npy".format(hyper_string)),
                np.array(outputs))

        # Save metrics using pandas
        metrics_df = pd.DataFrame(data=metrics,
                                  index=entry_list,
                                  columns=metric_list)
        metrics_df.to_pickle(path=os.path.join(
            output_dir, "dorn_{}_metrics.pkl".format(hyper_string)))
        # Compute weighted averages:
        average_metrics = np.average(metrics_df.iloc[:, :-1],
                                     weights=metrics_df.weight,
                                     axis=0)
        average_df = pd.Series(data=average_metrics, index=metric_list[:-1])
        average_df.to_csv(os.path.join(
            output_dir, "dorn_{}_avg_metrics.csv".format(hyper_string)),
                          header=True)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(average_metrics[0], average_metrics[1],
                     average_metrics[2], average_metrics[3],
                     average_metrics[4], average_metrics[6]))

        print("wrote results to {} ({})".format(output_dir, hyper_string))

    else:
        input_unbatched = dataset.get_item_by_id(entry)
        # for key in ["rgb", "albedo", "rawdepth", "spad", "mask", "rawdepth_orig", "mask_orig", "albedo_orig"]:
        #     input_[key] = input_[key].unsqueeze(0)
        from torch.utils.data._utils.collate import default_collate

        data = default_collate([input_unbatched])

        # Checks
        entry = data["entry"][0]
        i = int(entry)
        entry = entry if isinstance(entry, str) else entry.item()
        print("Evaluating {}[{}]".format(dataset_type, i))
        # Rescale SPAD
        spad_rescaled = rescale_bins(spad_data[i, ...], min_depth, max_depth,
                                     sid_obj)
        weights = np.ones_like(depth_data[i, 0, ...])
        if use_intensity:
            weights = intensity_data[i, 0, ...]
        spad_rescaled = preprocess_spad(spad_rescaled,
                                        sid_obj,
                                        use_squared_falloff,
                                        dc_count > 0.,
                                        lam=lam,
                                        eps_rel=eps_rel)

        pred, _ = image_histogram_match(depth_data[i, 0, ...], spad_rescaled,
                                        weights, sid_obj)
        # break
        # Calculate metrics
        gt = data["depth_cropped"]
        print(gt.shape)
        print(pred.shape)

        pred_metrics = get_depth_metrics(
            torch.from_numpy(pred).unsqueeze(0).unsqueeze(0).float(), gt,
            torch.ones_like(gt))
        if save_outputs:
            np.save(
                os.path.join(
                    output_dir,
                    "dorn_{}[{}]_{}_out.npy".format(dataset_type, entry,
                                                    hyper_string)), pred)
        print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(
            'd1', 'd2', 'd3', 'rel', 'rms', 'log_10'))
        print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".
              format(pred_metrics["delta1"], pred_metrics["delta2"],
                     pred_metrics["delta3"], pred_metrics["rel_abs_diff"],
                     pred_metrics["rmse"], pred_metrics["log10"]))