Example #1
0
def get_dataset_item_dimensions_3D(dataset_base_dir,
                                   dataset_item,
                                   test_only=False):
    """Report the number of active voxels in a 3D dataset item's render.

    The "rendering" grid is read from the path returned by
    ``get_dataset_item_render_path(dataset_base_dir, dataset_item)``.

    Args:
        dataset_base_dir: dataset root directory, forwarded to
            ``get_dataset_item_render_path``.
        dataset_item: dataset item description, forwarded likewise.
        test_only: unused; kept for interface compatibility.

    Returns:
        ``{"num_surface_voxels": <active voxel count of the rendering grid>}``.
    """
    render_file = get_dataset_item_render_path(dataset_base_dir, dataset_item)
    rendering = pyopenvdb.read(render_file, "rendering")
    return dict(num_surface_voxels=rendering.activeVoxelCount())
def predict_one_3d(
    discrete_voxels_file: str,
    filled_vdb_path: str,
    materials_filename: str,
    model_filename: str,
    model_config_filename: str,
):
    """Predict the appearance of a single 3D volume.

    Wraps ``predict_volume_appearance`` with a one-item dataset (key
    ``"single_volume"``) whose base directory is the folder containing
    ``discrete_voxels_file``.

    Args:
        discrete_voxels_file: path to the discrete voxel volume file.
        filled_vdb_path: path to the filled proxy-object VDB; the active voxel
            count of its "normalGrid" is used as ``num_surface_voxels``.
        materials_filename: materials file forwarded to the predictor.
        model_filename: model weights file forwarded to the predictor.
        model_config_filename: JSON file holding the model config.

    Returns:
        The prediction stored under the single dataset key by
        ``predict_volume_appearance``.
    """
    # Normalise both inputs to absolute paths: the folder of
    # discrete_voxels_file becomes dataset_base_dir below. (``isabs`` is the
    # test; ``abspath`` returns a truthy string and cannot serve as one.)
    if not os.path.isabs(discrete_voxels_file):
        discrete_voxels_file = os.path.abspath(discrete_voxels_file)

    if not os.path.isabs(filled_vdb_path):
        filled_vdb_path = os.path.abspath(filled_vdb_path)

    input_folder, input_filename = os.path.split(discrete_voxels_file)
    dataset_key = "single_volume"

    with open(model_config_filename, "r") as json_file:
        model_config = json.load(json_file)

    num_surface_voxels = pyopenvdb.read(filled_vdb_path,
                                        "normalGrid").activeVoxelCount()

    predictions, _, _ = predict_volume_appearance(
        dataset_base_dir=input_folder,
        dataset={
            dataset_key: {
                "filename": input_filename,
                "proxy_object_filled_path": filled_vdb_path,
                "metadata": {
                    "num_surface_voxels": num_surface_voxels,
                    "tile_size": 24,
                },
            }
        },
        ignore_md5_checks=True,
        model_filename=model_filename,
        model_config=model_config,
        materials_filename=materials_filename,
        stencils_only=True,
    )
    return predictions[dataset_key]
def main():
    """Predict one 3D volume from CLI arguments and optionally report its error.

    The predicted grid is written to ``args.output_vdb_path``. When
    ``args.gt_rendering_path`` is given, its "rendering" grid is compared with
    the prediction at every single-voxel value that is also active in the
    prediction, and the linear and sRGB RMSE are printed.
    """
    args = get_args()

    discrete_voxels_file = args.discrete_voxels_file
    gt_rendering_path = args.gt_rendering_path
    output_vdb_path = args.output_vdb_path

    if gt_rendering_path:
        gt_rendering = pyopenvdb.read(gt_rendering_path, "rendering")
    else:
        gt_rendering = None

    prediction_grid = predict_one_3d(
        discrete_voxels_file=discrete_voxels_file,
        filled_vdb_path=args.filled_vdb_path,
        materials_filename=args.materials_file,
        model_filename=os.path.join(args.model_folder, args.model_weights),
        model_config_filename=os.path.join(args.model_folder,
                                           args.model_config),
    )
    pyopenvdb.write(output_vdb_path, [prediction_grid])

    if gt_rendering is None:
        return

    prediction_acc = prediction_grid.getConstAccessor()
    diff_list = []
    for item in gt_rendering.citerOnValues():
        # count == 1: a single voxel value (not a tile value).
        if item.count != 1:
            continue
        target = item.value
        # Distinct name so the predicted grid is not shadowed by a voxel value.
        predicted_value, prediction_active = prediction_acc.probeValue(
            item.min)
        if prediction_active:
            diff_list.append([target, predicted_value])

    # With no overlapping voxels, fall back to one all-zero (target,
    # prediction) pair so the column slicing below still works.
    array = numpy.array(diff_list) if diff_list else numpy.zeros((1, 2, 3))
    gt_values = array[:, 0, :]
    predicted_values = array[:, 1, :]

    print("RMSE:", rms(gt_values, predicted_values))
    print("RMSE SRGB:",
          rms(linear_to_sRGB(gt_values), linear_to_sRGB(predicted_values)))
    def get_render(self, dataset_item_key=None, data_object=None, make_index=False):
        """Load a dataset item's grids and, optionally, its per-voxel indexes.

        Paths come from ``self.data_items[dataset_item_key]``. The "rendering"
        grid is read only when ``render_path`` is set. The "normalGrid" is read
        from ``render_path`` if set, otherwise from ``proxy_object_filled_path``.

        With ``make_index`` true, the index arrays are loaded from an ``.npz``
        cache derived from the input path. If that cache does not exist, they
        are built from the grids and then written to it. A cache that lacks
        ``distances`` gets them computed and is rewritten.

        Args:
            dataset_item_key: key into ``self.data_items``.
            data_object: label grid used to map surface voxels to discrete
                volume voxels; ``None`` skips that mapping. NOTE(review): from
                usage it must expose ``bbox_min`` and ``voxel_size``.
            make_index: load/build the index arrays; otherwise they are ``None``.

        Returns:
            Tuple ``(rendering_grid, render_index, discrete_voxels_index,
            material_voxel_found_index, distances)``.
            - ``rendering_grid`` is ``None`` without a ``render_path``.
            - The index entries are ``None`` when ``make_index`` is false.
            - ``discrete_voxels_index`` is ``None`` whenever ``data_object``
              is ``None``.

        Raises:
            ValueError: if neither ``proxy_object_filled_path`` nor
                ``render_path`` is set.
            RuntimeError: if ``_render_result`` is not found in
                ``render_path`` at an index > 0.
        """
        dataset_item = self.data_items[dataset_item_key]
        filled_path = dataset_item["proxy_object_filled_path"]
        render_path = dataset_item["render_path"]

        if filled_path is None and render_path is None:
            raise ValueError(
                "One of: `proxy_object_filled_path`, `render_filename` "
                "is needed for Data3D but neither is given"
            )

        index_cache_path_suffix = (
            "_index_cache_inner_voxel.npz"
            if self.find_inner_material_voxels
            else "_index_cache.npz"
        )
        # The cache file sits next to the input. A render path, when present,
        # determines the cache location instead of the filled path.
        index_cache_path = None
        if filled_path:
            index_cache_path = filled_path.replace("_filled.vdb", index_cache_path_suffix)
        if render_path:
            substring_index = render_path.find("_render_result")
            if substring_index <= 0:
                raise RuntimeError("Render path does not contain substring render_result")
            index_cache_path = render_path[:substring_index] + index_cache_path_suffix

        label_grid = data_object
        rendering_grid = pyopenvdb.read(render_path, "rendering") if render_path else None
        normal_grid = pyopenvdb.read(
            render_path if render_path is not None else filled_path, "normalGrid"
        )
        normal_grid_acc = normal_grid.getConstAccessor()

        if make_index:
            if index_cache_path and os.path.exists(index_cache_path):
                # Use the NpzFile as a context manager so its file handle is
                # closed, and closed before the branch below rewrites the file.
                with numpy.load(index_cache_path, allow_pickle=True) as d:
                    discrete_voxels_index = d["discrete_voxels_index"]
                    material_voxel_found_index = d["material_voxel_found_index"]
                    render_index = d["render_index"]
                    has_distances = "distances" in d.files
                    distances = d["distances"] if has_distances else None
                if not has_distances:  # temporary backward compatibility code
                    s = time()
                    with ThreadPool() as pool:
                        render_index_tuples = [
                            (int(r[0]), int(r[1]), int(r[2])) for r in render_index
                        ]
                        render_voxels_world = numpy.array(
                            pool.map(normal_grid.transform.indexToWorld, render_index_tuples)
                        )
                    discrete_voxels_world = (
                        label_grid.bbox_min + discrete_voxels_index * label_grid.voxel_size
                    )
                    distances = numpy.linalg.norm(
                        discrete_voxels_world - render_voxels_world, axis=1
                    )
                    logger.info(
                        "Calculating distances for {} took {} seconds".format(
                            dataset_item_key, time() - s
                        )
                    )
                    numpy.savez_compressed(
                        index_cache_path,
                        discrete_voxels_index=discrete_voxels_index,
                        material_voxel_found_index=material_voxel_found_index,
                        render_index=render_index,
                        distances=distances,
                    )
            else:
                logger.info("Building indexes for {}".format(dataset_item_key))

                if "find_inner_voxel" not in self.timings:
                    self.timings["find_inner_voxel"] = 0.0

                num_surface_voxels = dataset_item["metadata"]["num_surface_voxels"]
                discrete_voxels_index = numpy.empty(
                    shape=(num_surface_voxels, 3), dtype=numpy.int32
                )
                render_index = numpy.empty(
                    shape=(num_surface_voxels, 3), dtype=numpy.int32
                )
                # Builtin ``bool``: the ``numpy.bool`` alias was deprecated in
                # NumPy 1.20 and removed in 1.24.
                material_voxel_found_index = numpy.empty(
                    shape=(num_surface_voxels,), dtype=bool
                )
                world_coord_index = numpy.empty(
                    shape=(num_surface_voxels, 3), dtype=numpy.float64
                )
                normal_index = numpy.empty(
                    shape=(num_surface_voxels, 3), dtype=numpy.float64
                )
                index_counter = 0
                if rendering_grid:
                    for item in rendering_grid.citerOnValues():
                        if item.count == 1:  # voxel value
                            grid_coord = item.min
                            render_index[index_counter] = grid_coord
                            world_coord_index[
                                index_counter
                            ] = rendering_grid.transform.indexToWorld(grid_coord)
                            normal_index[index_counter], active = normal_grid_acc.probeValue(
                                grid_coord
                            )
                            assert active
                            index_counter += 1
                else:
                    for item in normal_grid.citerOnValues():
                        if item.count == 1:  # voxel value
                            grid_coord = item.min
                            render_index[index_counter] = grid_coord
                            world_coord_index[index_counter] = normal_grid.transform.indexToWorld(
                                grid_coord
                            )
                            normal_index[index_counter] = item.value
                            index_counter += 1
                if label_grid is not None:
                    calculate_discrete_volume_index_func = partial(
                        calculate_discrete_volume_index,
                        label_grid,
                        world_coord_index,
                        normal_index,
                        self.find_inner_material_voxels,
                    )
                    ts = time()
                    with Pool() as pool:
                        indexes = pool.map(
                            calculate_discrete_volume_index_func, range(len(render_index))
                        )
                    for idx, (coord, found) in enumerate(indexes):
                        discrete_voxels_index[idx] = coord
                        material_voxel_found_index[idx] = found
                    find_inner_voxel_timing = time() - ts
                    if self.verbose_logging:
                        logger.info(
                            "find_inner_voxel timing: {} seconds".format(
                                round(find_inner_voxel_timing, 3)
                            )
                        )
                    self.timings["find_inner_voxel"] += find_inner_voxel_timing

                    discrete_voxels_world = (
                        label_grid.bbox_min + discrete_voxels_index * label_grid.voxel_size
                    )
                    distances = numpy.linalg.norm(
                        discrete_voxels_world - world_coord_index, axis=1
                    )
                else:
                    distances = None

                numpy.savez_compressed(
                    index_cache_path,
                    discrete_voxels_index=discrete_voxels_index,
                    material_voxel_found_index=material_voxel_found_index,
                    render_index=render_index,
                    distances=distances,
                )
        else:
            discrete_voxels_index = None
            material_voxel_found_index = None
            render_index = None
            distances = None

        # Without a label grid the discrete index (cached or freshly allocated)
        # is not meaningful, so it is not returned.
        if label_grid is None:
            discrete_voxels_index = None

        return (
            rendering_grid,
            render_index,
            discrete_voxels_index,
            material_voxel_found_index,
            distances,
        )
def main():
    """Evaluate a model on a YAML dataset description and write a summary.

    Results go to ``<base_output_directory>/<dataset name>/<model folder name>``.

    An item is (re)predicted with ``predict_volume_appearance`` when:
    - its output file is missing,
    - ``args.recalculate`` is set, or
    - the model-weights MD5 recorded in a previous ``summary.yml`` differs
      from the current weights (unless ``args.skip_md5_check``).

    Otherwise the stored prediction and the ground-truth render are read back.

    Per item, RMSE (linear and sRGB) is computed, plus SSIM for 2D items. 3D
    outputs are rewritten with two extra visualisation grids,
    "diff_red-green" and "diff_dE2000_20max". All per-item metrics are then
    written to ``summary.yml``.
    """
    args = get_args()

    dataset_base_dir = os.path.abspath(os.path.split(args.dataset)[0])
    dataset_name = os.path.splitext(os.path.split(args.dataset)[1])[0]

    with open(args.dataset) as f:
        dataset_data = yaml.full_load(f)

    model_folder = os.path.normpath(args.model_folder)

    with open(os.path.join(model_folder, args.model_config), "r") as json_file:
        model_config = json.load(json_file)

    _, nn_model_name = os.path.split(model_folder)
    results_output_directory = os.path.join(args.base_output_directory,
                                            dataset_name, nn_model_name)

    model_data = {
        "name": nn_model_name,
        "config": model_config,
    }

    model_weights_filepath = os.path.join(model_folder, args.model_weights)
    model_weights_file_md5 = md5sum_path(model_weights_filepath)
    output_yaml_file = os.path.join(results_output_directory, "summary.yml")

    if os.path.exists(output_yaml_file):
        with open(output_yaml_file, "r") as yaml_file:
            previous_prediction_data = yaml.full_load(yaml_file)
    else:
        previous_prediction_data = None

    results_data = {}
    dataset = dataset_data.get("items", {})

    predictions = {}
    prediction_masks = {}
    gt_renderings = {}
    to_predict_dataset = {}

    for item_key, dataset_item in dataset.items():
        dataset_item["is_2D"] = is_2D_file(
            get_dataset_item_volume_path(dataset_base_dir, dataset_item))

        output_prediction_path = os.path.join(
            results_output_directory,
            "{}.{}".format(item_key,
                           "exr" if dataset_item["is_2D"] else "vdb"),
        )
        dataset_item["output_prediction_path"] = output_prediction_path

        previous_prediction_md5 = ((previous_prediction_data.get(
            "items", {}).get(item_key, {}).get("model_weights_file_md5"))
                                   if previous_prediction_data else None)
        md5_recalculate = (not args.skip_md5_check and previous_prediction_data
                           and previous_prediction_md5 is not None and
                           previous_prediction_md5 != model_weights_file_md5)

        if not os.path.exists(
                output_prediction_path) or args.recalculate or md5_recalculate:
            to_predict_dataset[item_key] = dataset_item
        else:
            render_filename = get_dataset_item_render_path(
                dataset_base_dir, dataset_item)
            if dataset_item["is_2D"]:
                predictions[item_key] = read_image(output_prediction_path)
                gt_renderings[item_key] = read_image(render_filename)
            else:
                predictions[item_key] = pyopenvdb.read(output_prediction_path,
                                                       "prediction")
                gt_renderings[item_key] = pyopenvdb.read(
                    render_filename, "rendering")

    prediction_timings = {}
    if to_predict_dataset:
        predictions_new, gt_renderings_new, prediction_masks_new = predict_volume_appearance(
            dataset=to_predict_dataset,
            dataset_base_dir=dataset_base_dir,
            materials_filename=dataset_data["materials_file"],
            model_config=model_config,
            model_filename=model_weights_filepath,
            stencils_only=False,
            timings=prediction_timings,
            verbose_logging=args.verbose,
        )
        for item_key, prediction in predictions_new.items():
            dataset_item = dataset[item_key]
            folder_path, _ = os.path.split(
                dataset_item["output_prediction_path"])
            ensure_dir(folder_path)
            if dataset_item["is_2D"]:
                dump_image(image=prediction,
                           filepath=dataset_item["output_prediction_path"])
            else:
                pyopenvdb.write(dataset_item["output_prediction_path"],
                                [prediction])
        predictions.update(predictions_new)
        if prediction_masks_new:
            prediction_masks.update(prediction_masks_new)
        gt_renderings.update(gt_renderings_new)

    if args.verbose:
        print("Prediction timings:")
        pprint(prediction_timings)

    # That's temporary for the deadline - normalize diff images
    max_diff_pos, max_diff_neg = 0.0, 0.0

    for item_key, dataset_item in dataset.items():
        prediction = predictions[item_key]
        gt_rendering = gt_renderings[item_key]
        # Per-item path; the variable left over from the first loop would
        # point at the last item for every entry of the summary.
        output_prediction_path = dataset_item["output_prediction_path"]

        if not dataset_item["is_2D"]:
            prediction_grid = prediction
            prediction_acc = prediction_grid.getConstAccessor()
            try:
                prediction_mask_accessor = prediction_masks[
                    item_key].getConstAccessor()
            except KeyError:
                prediction_mask_accessor = None

            coords_list = []
            diff_list = []

            for item in gt_rendering.citerOnValues():
                if item.count == 1:  # voxel value
                    target = item.value
                    predicted_value, prediction_active = prediction_acc.probeValue(
                        item.min)
                    if prediction_mask_accessor:
                        mask_value = prediction_mask_accessor.getValue(
                            item.min)
                    else:
                        mask_value = True
                    if prediction_active and mask_value:
                        coords_list.append(item.min)
                        diff_list.append([target, predicted_value])
            # Fall back to one all-zero pair so slicing and metrics still work.
            if diff_list:
                array = numpy.array(diff_list)
            else:
                array = numpy.zeros((1, 2, 3))
            gt_rendering = array[:, 0, :]
            prediction = array[:, 1, :]

            diff = prediction - gt_rendering
            diff_vis_scaling = 4.0
            diff_vis_array = plt.get_cmap(get_colormap("error"))(
                ((diff_vis_scaling * diff) / 2.0 + 0.5).mean(axis=1))

            diff_positive = numpy.maximum(diff, 0.0)
            diff_negative = numpy.minimum(diff, 0.0)
            max_diff_pos = max(max_diff_pos, diff_positive.mean(axis=1).max())
            max_diff_neg = max(max_diff_neg,
                               (-1.0 * diff_negative.mean(axis=1)).max())

            diff_grid = pyopenvdb.Vec3SGrid((-1, -1, -1))
            diff_grid.transform = prediction_grid.transform
            diff_grid.name = "diff_red-green"
            diff_grid_accessor = diff_grid.getAccessor()
            for coord, diff_vis_value in zip(coords_list, diff_vis_array):
                diff_grid_accessor.setValueOn(
                    coord,
                    (diff_vis_value[0], diff_vis_value[1], diff_vis_value[2]))

            diff_dE = get_difference_metric("ciede2000")(gt_rendering,
                                                         prediction)
            diff_dE_vis_scaling = 20.0
            diff_dE_vis_array = plt.get_cmap(get_colormap("ciede2000"))(
                diff_dE / diff_dE_vis_scaling)

            diff_dE_grid = pyopenvdb.Vec3SGrid((-1, -1, -1))
            diff_dE_grid.transform = prediction_grid.transform
            diff_dE_grid.name = "diff_dE2000_20max"
            diff_dE_grid_accessor = diff_dE_grid.getAccessor()
            for coord, diff_vis_value in zip(coords_list,
                                             diff_dE_vis_array[0]):
                diff_dE_grid_accessor.setValueOn(
                    coord,
                    (diff_vis_value[0], diff_vis_value[1], diff_vis_value[2]))
            pyopenvdb.write(output_prediction_path,
                            [prediction_grid, diff_grid, diff_dE_grid])

        rmse_linear = float(
            get_difference_metric("rms")(gt_rendering, prediction))
        rmse_srgb = float(
            get_difference_metric("rms")(linear_to_sRGB(gt_rendering),
                                         linear_to_sRGB(prediction)))
        if dataset_item["is_2D"]:
            ssim_linear = float(
                get_difference_metric("ssim")(gt_rendering, prediction))
            ssim_srgb = float(
                get_difference_metric("ssim")(linear_to_sRGB(gt_rendering),
                                              linear_to_sRGB(prediction)))
        else:
            ssim_linear = 0.0
            ssim_srgb = 0.0

        print(item_key, "RMSE:", rmse_linear)
        print(item_key, "RMSE SRGB:", rmse_srgb)
        print(item_key, "SSIM:", ssim_linear)
        print(item_key, "SSIM SRGB:", ssim_srgb)

        volume_filename = get_dataset_item_volume_path(dataset_base_dir,
                                                       dataset_item)
        render_filename = get_dataset_item_render_path(dataset_base_dir,
                                                       dataset_item)

        results_data[item_key] = {
            "volume_filename": os.path.abspath(volume_filename),
            "render_filename": os.path.abspath(render_filename),
            "prediction_filename": os.path.abspath(output_prediction_path),
            "base_image_name": dataset_item.get("base_image_name", ""),
            "rmse_linear": rmse_linear,
            "rmse_srgb": rmse_srgb,
            "ssim_linear": ssim_linear,
            "ssim_srgb": ssim_srgb,
            "model_weights_file_md5": model_weights_file_md5,
        }

    print("max_diff_pos=", max_diff_pos, "max_diff_neg=", max_diff_neg)

    # The directory is otherwise only created when a new prediction is written.
    ensure_dir(results_output_directory)
    with open(output_yaml_file, "w") as f:
        yaml.dump({"model": model_data, "items": results_data}, f)
Example #6
0
def predict_volume_appearance(
    dataset_base_dir: str,
    dataset: dict,
    model_filename,
    model_config,
    materials_filename,
    stencils_only,
    timings: Dict = None,
    ignore_md5_checks: bool = False,
    verbose_logging: bool = False,
):
    """Predict appearance for every item of ``dataset`` with a trained model.

    Steps:
    - Build a ``DataPlanar`` (2D) or ``Data3D`` data source, as decided by
      ``classify_dataset_class``.
    - Instantiate the architecture named by ``model_params["model_arch_name"]``
      from ``models_collection`` and load ``model_filename`` into it.
    - Run batched prediction. The batch size comes from the ``BATCH_SIZE``
      environment variable (default 10000).
    - Scatter the predictions back per dataset item.

    When the materials have a channel count other than 3, predictions are
    treated as spectral. They are projected onto CIE XYZ and converted to
    linear RGB.

    Args:
        dataset_base_dir: base directory; a relative ``materials_filename`` is
            joined onto it, and it is passed to the data class.
        dataset: mapping of item key -> dataset item description.
        model_filename: weights file passed to ``model.load_weights``.
        model_config: parsed model config with a ``model_params`` entry.
            NOTE: an int ``stencil_channels`` is rewritten in place.
        materials_filename: materials file path.
        stencils_only: forwarded to the data class; in the 2D branch it also
            skips loading ground-truth renders.
        timings: forwarded to the data class.
        ignore_md5_checks: forwarded to the data class.
        verbose_logging: forwarded to the data class.

    Returns:
        ``(predicted_images, gt_renderings, predicted_masks)``.
        - 2D datasets: float32 ``(height, width, 3)`` arrays per item; the
          items' renders (only when ``stencils_only`` is false); and ``None``
          for the masks.
        - 3D datasets: a ``Vec3SGrid`` named "prediction" per item; the items'
          render grids (possibly ``None``); and a ``BoolGrid`` per item holding
          ``material_voxel_found_index``.

    Raises:
        RuntimeError: for a 3D item with neither a render nor a
            ``proxy_object_filled_path`` to take the grid transform from.
    """
    model_params = model_config["model_params"]

    if not os.path.isabs(materials_filename):
        materials_filename = os.path.join(dataset_base_dir, materials_filename)

    # backward compatibility: an int selects the first N default channel names
    if type(model_params["stencil_channels"]) == int:
        model_params["stencil_channels"] = [
            "scattering", "absorption", "mask"
        ][:model_params["stencil_channels"]]

    model_arch_name = model_params["model_arch_name"]
    patch_size = model_params["patch_size"]
    scale_levels = model_params["scale_levels"]
    stencil_channels = model_params["stencil_channels"]

    is_2D_dataset = classify_dataset_class(dataset_base_dir, dataset)

    data_class = DataPlanar if is_2D_dataset else Data3D
    alignment_z_centered = model_params.get("alignment_z_centered",
                                            data_class == Data3D)
    data = data_class(
        alignment_z_centered=alignment_z_centered,
        data_items=dataset,
        dataset_base_dir=dataset_base_dir,
        find_inner_material_voxels=model_params["find_inner_material_voxels"],
        ignore_md5_checks=ignore_md5_checks,
        materials_file=materials_filename,
        mode=MODE_PREDICT,
        patch_size=patch_size,
        sat_object_class_name="TreeSAT",
        scale_levels=scale_levels,
        shuffle_patches=False,
        sliding_window_length=1,
        stencil_channels=stencil_channels,
        stencils_only=stencils_only,
        timings=timings,
        verbose_logging=verbose_logging,
    )

    batch_size = int(os.getenv("BATCH_SIZE", 10000))

    model_make_function = models_collection[model_arch_name]
    model = model_make_function(params=model_params)
    model.load_weights(model_filename)

    make_batch_function = (make_batch_swap_axes if model_arch_name
                           in ("planar_first",
                               "first_baseline") else make_batch)

    # make_batches_gen fills ``locations`` with one (datafile_idx, channel,
    # pidx) entry per predicted value, in prediction order.
    locations = []
    predictions = model.predict_generator(
        generator=make_batches_gen(data, batch_size, locations,
                                   make_batch_function),
        steps=math.ceil(len(data) / batch_size),
        verbose=1,
    )

    predicted_images = {}
    predicted_images_accessors = {}
    gt_renderings = {}
    predicted_masks = None
    predicted_masks_accessors = {}

    materials = populate_materials(materials_filename)
    material_channels = [m.channels for m in materials.values()]
    assert max(material_channels) == min(
        material_channels
    ), "number of channels in materials file mismatch"  # all elements equal
    materials_channel_count = material_channels[0]

    if materials_channel_count != 3:
        # spectral mode
        material_wavelengths = [m.wavelengths for m in materials.values()]
        for i in range(1, len(material_wavelengths)):
            assert numpy.array_equal(
                material_wavelengths[i - 1], material_wavelengths[i]
            ), "wavelength definition mismatch in materials file"
        # spectral prediction
        X, Y, Z = CIEXYZ_primaries(material_wavelengths[0] *
                                   10)  # convert nm to Angstrom

    if is_2D_dataset:
        for (datafile_idx, channel,
             pidx), pixel_prediction in zip(locations, predictions):
            dataset_item_key, metadata = data.volume_files[datafile_idx]
            if dataset_item_key not in predicted_images:
                if not stencils_only:
                    gt_renderings[dataset_item_key] = data.get_render(
                        dataset_item_key=dataset_item_key)

                predicted_images[dataset_item_key] = numpy.empty(
                    shape=(metadata["height"], metadata["width"], 3),
                    dtype=numpy.float32)
            x, y, z = data.convert_patch_index(datafile_idx, channel, pidx)
            predicted_images[dataset_item_key][y, x,
                                               channel] = pixel_prediction
    else:  # is 3D dataset
        material_voxel_found_index_masks = {
            dataset_item_key: data.material_voxel_found_index[datafile_idx]
            for datafile_idx, (dataset_item_key,
                               _) in enumerate(data.volume_files)
        }
        predicted_masks = {}

        locations_predictions = (pd.DataFrame(
            numpy.concatenate(
                [numpy.array(locations),
                 numpy.array(predictions)], axis=1),
            columns=["datafile_idx", "channel", "pidx", "value"],
        ).astype({
            "datafile_idx": "uint16",
            "channel": "uint8",
            "pidx": "uint32",
        }).sort_values(by=["datafile_idx", "pidx", "channel"])
            # ``.at`` below looks rows up by index *label*, and sort_values
            # keeps the pre-sort labels; resetting makes label == sorted row
            # position, so rows lp_idx .. lp_idx + num_channels - 1 are the
            # channels of one voxel.
            .reset_index(drop=True))
        num_channels = int(locations_predictions["channel"].max()) + 1

        for lp_idx in range(0, len(locations_predictions), num_channels):
            datafile_idx = int(locations_predictions.at[lp_idx,
                                                        "datafile_idx"])
            pidx = int(locations_predictions.at[lp_idx, "pidx"])

            dataset_item_key, metadata = data.volume_files[datafile_idx]
            if dataset_item_key not in predicted_images:
                predicted_images[dataset_item_key] = pyopenvdb.Vec3SGrid(
                    (-1, -1, -1))
                predicted_images_accessors[
                    dataset_item_key] = predicted_images[
                        dataset_item_key].getAccessor()

                predicted_masks[dataset_item_key] = pyopenvdb.BoolGrid(False)
                predicted_masks_accessors[dataset_item_key] = predicted_masks[
                    dataset_item_key].getAccessor()

                render, _, _, _, _ = data.get_render(
                    dataset_item_key=dataset_item_key)
                gt_renderings[dataset_item_key] = render
                # The predicted grid takes its transform from the render, or
                # from the filled proxy's normal grid when there is no render.
                if render is not None:
                    predicted_images[
                        dataset_item_key].transform = render.transform
                else:
                    dataset_item = data.data_items[dataset_item_key]
                    filled_path = dataset_item["proxy_object_filled_path"]
                    if filled_path is not None:
                        normal_grid = pyopenvdb.read(filled_path, "normalGrid")
                        predicted_images[
                            dataset_item_key].transform = normal_grid.transform
                    else:
                        raise RuntimeError(
                            "One of: `render_filename`, `proxy_object_filled_path` "
                            "should be set to extract transform for the predicted grid"
                        )

                predicted_images[dataset_item_key].name = "prediction"
            x, y, z = data.convert_patch_index_to_render_coords(
                datafile_idx, 0, pidx)

            if num_channels == 3:
                value = (
                    locations_predictions.at[lp_idx + 0, "value"],
                    locations_predictions.at[lp_idx + 1, "value"],
                    locations_predictions.at[lp_idx + 2, "value"],
                )
            else:
                # convolve the renderings with the XYZ color matching functions
                value = [0.0, 0.0, 0.0]
                for i, wavelength in enumerate(material_wavelengths[0]):
                    value[0] += X[i] * locations_predictions.at[lp_idx + i,
                                                                "value"]
                    value[1] += Y[i] * locations_predictions.at[lp_idx + i,
                                                                "value"]
                    value[2] += Z[i] * locations_predictions.at[lp_idx + i,
                                                                "value"]
                value = tuple(xyz2linear_rgb([[value]])[0][0])
            predicted_images_accessors[dataset_item_key].setValueOn((x, y, z),
                                                                    value)
            mask_value = material_voxel_found_index_masks[dataset_item_key][
                pidx]
            predicted_masks_accessors[dataset_item_key].setValueOn((x, y, z),
                                                                   mask_value)

    return predicted_images, gt_renderings, predicted_masks