Example #1
def compress():
    assert args.resolution > 0, 'resolution must be positive'
    assert args.data_format in ['channels_first', 'channels_last']
    with_normals = args.input_normals is not None
    validate_opt_metrics(args.opt_metrics, with_normals=with_normals)

    files_mult = 1
    if with_normals:
        files_mult *= 2
        assert files_mult * len(args.input_files) == len(args.output_files)
        assert files_mult * len(args.input_normals) == len(args.output_files)
    else:
        assert files_mult * len(args.input_files) == len(args.output_files)
    decode_files = args.dec_files is not None
    if decode_files:
        assert files_mult * len(args.input_files) == len(args.dec_files)

    assert args.model_config in ModelConfigType.keys()

    p_min, p_max, dense_tensor_shape = pc_io.get_shape_data(
        args.resolution, args.data_format)
    points = pc_io.load_points(args.input_files,
                               batch_size=args.read_batch_size)
    if with_normals:
        normals = [
            PyntCloud.from_file(x).points[['nx', 'ny', 'nz']].values
            for x in args.input_normals
        ]
        points = [np.hstack((p, n)) for p, n in zip(points, normals)]

    logger.info('Performing octree partitioning')
    # Hardcode bbox_min
    bbox_min = [0, 0, 0]
    if args.data_format == 'channels_first':
        bbox_max = dense_tensor_shape[1:].copy()
        dense_tensor_shape[1:] = dense_tensor_shape[1:] // (2 ** args.octree_level)
    else:
        bbox_max = dense_tensor_shape[:3].copy()
        dense_tensor_shape[:3] = dense_tensor_shape[:3] // (2 ** args.octree_level)
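    # partition_octree splits each cloud into octree blocks (2**octree_level per
    # axis) and, presumably, also returns the occupancy bitstream (binstr) needed
    # to put the blocks back in place at decode time.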
    blocks_list, binstr_list = zip(*[
        partition_octree(p, bbox_min, bbox_max, args.octree_level)
        for p in points
    ])
    blocks_list_flat = [y for x in blocks_list for y in x]
    logger.info(
        f'Processing resolution {args.resolution} with octree level {args.octree_level} '
        f'resulting in dense_tensor_shape {dense_tensor_shape} and {len(blocks_list_flat)} blocks')

    batch_size = 1
    x_shape = np.concatenate(((batch_size, ), dense_tensor_shape))

    model = ModelConfigType[args.model_config].build()
    model.compress(x_shape)

    # Checkpoints
    saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)
    init = tf.global_variables_initializer()
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    with tf.Session(config=tf_config) as sess:
        logger.info('Init session')
        sess.run(init)

        checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
        assert checkpoint is not None, f'Checkpoint {args.checkpoint_dir} was not found'
        saver.restore(sess, checkpoint)

        for i in trange(len(args.input_files)):
            ori_file, cur_points, blocks, binstr = [
                x[i]
                for x in (args.input_files, points, blocks_list, binstr_list)
            ]
            n_blocks = len(blocks)

            cur_output_files = [
                args.output_files[i * files_mult + j]
                for j in range(files_mult)
            ]
            if decode_files:
                cur_dec_files = [
                    args.dec_files[i * files_mult + j]
                    for j in range(files_mult)
                ]
            assert len(set(cur_output_files)) == len(cur_output_files), \
                f'{cur_output_files} should have no duplicates'
            logger.info(
                f'Starting {ori_file} to {", ".join(cur_output_files)} with {n_blocks} blocks'
            )
            data_list, data, debug_t_list = model.compress_blocks(
                sess,
                blocks,
                binstr,
                cur_points,
                args.resolution,
                args.octree_level,
                with_normals=with_normals,
                opt_metrics=args.opt_metrics,
                max_deltas=args.max_deltas,
                fixed_threshold=args.fixed_threshold,
                debug=args.debug)
            assert len(data_list) == files_mult

            for j in range(len(cur_output_files)):
                of, cur_data_list, cur_data = [
                    x[j] for x in (cur_output_files, data_list, data)
                ]
                os.makedirs(os.path.split(of)[0], exist_ok=True)
                with gzip.open(of, "wb") as f:
                    ret = save_compressed_file(binstr, cur_data_list,
                                               args.resolution,
                                               args.octree_level)
                    f.write(ret)
                if decode_files:
                    pc_io.write_df(cur_dec_files[j],
                                   pc_io.pa_to_df(cur_data['blocks_full']))
                with open(of + '.enc.metric.json', 'w') as f:
                    json.dump(cur_data['metrics'], f, sort_keys=True, indent=4)
                if args.debug:
                    pc_io.write_df(of + '.enc.ply',
                                   pc_io.pa_to_df(cur_data['blocks_full']))

                    write_pcs(blocks, of + '.ori.blocks')
                    write_pcs(cur_data['x_hat_list'], of + '.enc.blocks')
                    write_pcs(cur_data['blocks_depart'],
                              of + '.enc.blocks.depart')
                    np.savez_compressed(of + '.enc.data.npz',
                                        data=cur_data_list,
                                        debug_t_list=debug_t_list)

            logger.info(
                f'Finished {ori_file} to {", ".join(cur_output_files)} with {n_blocks} blocks'
            )
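
For context, the following is a minimal, self-contained sketch of the kind of block partitioning that partition_octree performs in Example #1. It only illustrates the general technique (bucketing points into 2**level-per-axis blocks with local coordinates) and is not the repository's implementation.

import numpy as np


def partition_octree_sketch(points, bbox_max, level):
    """Bucket points into 2**level-per-axis blocks with local coordinates."""
    block_size = np.asarray(bbox_max, dtype=np.float32) / (2 ** level)
    # Block index of each point along x, y and z.
    block_idx = np.floor(points[:, :3] / block_size).astype(np.int64)
    blocks, occupied = [], []
    for idx in np.unique(block_idx, axis=0):
        mask = np.all(block_idx == idx, axis=1)
        local = points[mask].astype(np.float32)
        # Shift geometry into local block coordinates; extra columns
        # (e.g. normals) pass through unchanged.
        local[:, :3] -= idx * block_size
        blocks.append(local)
        occupied.append(idx)  # stand-in for the real octree occupancy bitstream
    return blocks, occupied

In compress() above, the real per-cloud binstr is what save_compressed_file serializes alongside the per-block payloads.
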
Example #2

def compute_optimal_thresholds(block,
                               x_hat,
                               thresholds,
                               resolution,
                               normals=None,
                               opt_metrics=['d1_mse'],
                               max_deltas=[np.inf],
                               fixed_threshold=False):
    validate_opt_metrics(opt_metrics, with_normals=normals is not None)
    assert len(max_deltas) > 0
    best_thresholds = []
    ret_opt_metrics = [
        f'{opt_metric}_{max_delta}' for max_delta in max_deltas
        for opt_metric in opt_metrics
    ]
    if fixed_threshold:
        half_thr = len(thresholds) // 2
        half_pa = np.argwhere(x_hat > thresholds[half_thr]).astype('float32')
        logger.info(
            f'Fixed threshold {half_thr}/{len(thresholds)} with {len(half_pa)}/{len(block)} points (ratio {len(half_pa)/len(block):.2f})'
        )
        return ret_opt_metrics, [half_thr] * len(max_deltas) * len(opt_metrics)

    pa_list = build_points_threshold(x_hat, thresholds, len(block))
    max_threshold_idx = len(thresholds) - 1
    if len(pa_list) == 0:
        # One threshold per (max_delta, opt_metric) pair, matching ret_opt_metrics.
        return ret_opt_metrics, [max_threshold_idx] * len(max_deltas) * len(opt_metrics)

    t1 = cKDTree(block[:, :3], balanced_tree=False)
    pa_metrics = [
        compute_metrics(block[:, :3], pa, resolution - 1, p1_n=normals, t1=t1)
        for _, pa in pa_list
    ]

    log_message = f'Processing max_deltas {max_deltas} on block with {len(block)} points'
    for max_delta in max_deltas:
        if max_delta is not None:
            cur_pa_list = build_points_threshold(x_hat, thresholds, len(block),
                                                 max_delta)
            if len(cur_pa_list) > 0:
                idx_mask = [x[0] for x in cur_pa_list]
                cur_pa_metrics = [pa_metrics[i] for i in idx_mask]
            else:
                cur_pa_list = pa_list
                cur_pa_metrics = pa_metrics
        else:
            cur_pa_list = pa_list
            cur_pa_metrics = pa_metrics
        log_message += f'\n{len(cur_pa_list)}/{len(thresholds)} thresholds eligible for max_delta {max_delta}'
        for opt_metric in opt_metrics:
            best_threshold_idx = np.argmin(
                [x[opt_metric] for x in cur_pa_metrics])
            cur_best_metric = cur_pa_metrics[best_threshold_idx][opt_metric]

            # Check for failure scenarios
            mean_point_metric = compute_metrics(
                block[:, :3],
                np.round(np.mean(block[:, :3], axis=0))[np.newaxis, :],
                resolution - 1,
                p1_n=normals,
                t1=t1)[opt_metric]
            # In case a single point is better than the network output, this is a failure case
            # Do not output any points
            if cur_best_metric > mean_point_metric:
                best_threshold_idx = max_threshold_idx
                final_idx = best_threshold_idx
                log_message += f', {opt_metric} {final_idx} 0/{len(block)}, metric {cur_best_metric:.2e} > mean point metric {mean_point_metric:.2e}'
            else:
                final_idx = cur_pa_list[best_threshold_idx][0]
                cur_n_points = len(cur_pa_list[best_threshold_idx][1])
                log_message += f', {opt_metric} {final_idx} {cur_n_points}/{len(block)} points (ratio {cur_n_points/len(block):.2f}) {cur_best_metric:.2e} < mean point metric {mean_point_metric:.2e}'
            best_thresholds.append(final_idx)
    logger.info(log_message)
    assert len(ret_opt_metrics) == len(best_thresholds)

    return ret_opt_metrics, best_thresholds
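
As a rough illustration of what compute_optimal_thresholds optimizes, here is a self-contained sketch that scores each threshold with a simple point-to-point (D1-style) squared error and keeps the best one. The repository's compute_metrics and build_points_threshold additionally handle D2, normals and max_delta filtering; none of that is reproduced here.

import numpy as np
from scipy.spatial import cKDTree


def sweep_thresholds_sketch(block, x_hat, thresholds):
    """Return the index of the threshold whose surviving voxels best match block."""
    t1 = cKDTree(block[:, :3])
    best_idx, best_err = len(thresholds) - 1, np.inf
    for i, thr in enumerate(thresholds):
        pa = np.argwhere(x_hat > thr).astype('float32')
        if len(pa) == 0:
            continue
        # Symmetric nearest-neighbour squared error, a rough D1 proxy.
        d_ab = cKDTree(pa).query(block[:, :3])[0]
        d_ba = t1.query(pa)[0]
        err = max(np.mean(d_ab ** 2), np.mean(d_ba ** 2))
        if err < best_err:
            best_idx, best_err = i, err
    return best_idx
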
Example #3

    parser.add_argument('output_path', help='Output folder.')
    args = parser.parse_args()

    with open(args.experiment_path, 'r') as f:
        experiments = yaml.load(f.read(), Loader=yaml.FullLoader)
    keys = [
        'MPEG_DATASET_DIR', 'EXPERIMENT_DIR', 'mpeg_modes', 'model_configs',
        'opt_metrics', 'eval_modes'
    ]
    MPEG_DATASET_DIR, EXPERIMENT_DIR, mpeg_modes, model_configs, opt_metrics, eval_modes = [
        experiments[k] for k in keys
    ]
    mpeg_path = os.path.join(EXPERIMENT_DIR, 'gpcc')
    eval_modes = index_by_id(eval_modes)

    validate_opt_metrics(opt_metrics, with_normals=True)
    os.makedirs(args.output_path, exist_ok=True)

    opt_groups = ['d1', 'd2']

    logger.info('Build tables')
    # eval_id, label, metric, mode_id, opt_group, pc_name
    bdsnr_df = pd.read_csv(os.path.join(EXPERIMENT_DIR, 'results',
                                        'bdsnr.csv'))

    # Alpha table
    alpha_bdsnr_ref = ['trisoup-predlift/lossy-geom-lossy-attrs']
    alpha_modes = index_by_id(eval_modes['alpha']['modes'])
    alpha_mode_ids = [x for x in alpha_modes if x not in alpha_bdsnr_ref]
    alpha_df = bdsnr_df.query(
        f'eval_id == "alpha" & mode_id in {alpha_mode_ids}')
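
index_by_id is not shown in these excerpts; judging from how eval_modes['alpha']['modes'] is used above, a plausible minimal definition (an assumption, not the repository's code) would be:

def index_by_id(items):
    # Assumed helper: key a list of dicts by their 'id' field.
    return {item['id']: item for item in items}
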
Example #4

def run_experiment(output_dir,
                   model_dir,
                   model_config,
                   pc_name,
                   pcerror_path,
                   pcerror_cfg_path,
                   input_pc,
                   input_norm,
                   opt_metrics,
                   max_deltas,
                   fixed_threshold,
                   no_merge_coding,
                   num_parallel,
                   no_stream_redirection=False):
    for f in [model_dir, pcerror_path, pcerror_cfg_path, input_pc, input_norm]:
        assert_exists(f)
    validate_opt_metrics(opt_metrics, with_normals=input_norm is not None)
    with open(pcerror_cfg_path, 'r') as f:
        pcerror_cfg = yaml.load(f.read(), Loader=yaml.FullLoader)

    opt_group = ['d1', 'd2']
    enc_pc_filenames = [f'{pc_name}_{x}.ply.bin' for x in opt_group]
    dec_pc_filenames = [f'{x}.ply' for x in enc_pc_filenames]
    dec_pc_color_filenames = [f'{x}.color.ply' for x in dec_pc_filenames]
    pcerror_result_filenames = [f'{x}.pc_error' for x in dec_pc_filenames]
    enc_pcs = [os.path.join(output_dir, x) for x in enc_pc_filenames]
    dec_pcs = [os.path.join(output_dir, x) for x in dec_pc_filenames]
    dec_pcs_color = [
        os.path.join(output_dir, x) for x in dec_pc_color_filenames
    ]
    pcerror_results = [
        os.path.join(output_dir, x) for x in pcerror_result_filenames
    ]
    exp_reports = [
        os.path.join(output_dir, f'report_{x}.json') for x in opt_group
    ]

    compress_log = os.path.join(output_dir, 'compress.log')
    decompress_log = os.path.join(output_dir, 'decompress.log')

    # Create folder
    os.makedirs(output_dir, exist_ok=True)

    resolution = pcerror_cfg['resolution']

    # Encoding or Encoding/Decoding with merge_coding option
    if (all(os.path.exists(x) for x in enc_pcs)
            and (no_merge_coding or all(os.path.exists(x) for x in dec_pcs))):
        print_progress(input_pc, enc_pcs, '(exists)')
    else:
        print_progress(input_pc, enc_pcs)
        with ExitStack() as stack:
            if no_stream_redirection:
                f = None
            else:
                f = open(compress_log, 'w')
                stack.enter_context(f)
            additional_params = []
            if not no_merge_coding:
                additional_params += ['--dec_files', *dec_pcs]
            if fixed_threshold:
                additional_params += ['--fixed_threshold']
            subprocess.run(
                [
                    'python',
                    'compress_octree.py',  # '--debug',
                    '--input_files',
                    input_pc,
                    '--input_normals',
                    input_norm,
                    '--output_files',
                    *enc_pcs,
                    '--checkpoint_dir',
                    model_dir,
                    '--opt_metrics',
                    *opt_metrics,
                    '--max_deltas',
                    *map(str, max_deltas),
                    '--resolution',
                    str(resolution + 1),
                    '--model_config',
                    model_config
                ] + additional_params,
                stdout=f,
                stderr=f,
                check=True)

    # Decoding, skipped with merge_coding option
    if all(os.path.exists(x) for x in dec_pcs):
        print_progress(enc_pcs, dec_pcs, '(exists)')
    elif not no_merge_coding:
        print_progress(enc_pcs, dec_pcs, '(merge_coding)')
    else:
        print_progress(enc_pcs, dec_pcs)
        with ExitStack() as stack:
            if no_stream_redirection:
                f = None
            else:
                f = open(decompress_log, 'w')
                stack.enter_context(f)
            subprocess.run(
                [
                    'python',
                    'decompress_octree.py',  # '--debug',
                    '--input_files',
                    *enc_pcs,
                    '--output_files',
                    *dec_pcs,
                    '--checkpoint_dir',
                    model_dir,
                    '--model_config',
                    model_config
                ],
                stdout=f,
                stderr=f,
                check=True)

    # Color mapping
    mc_params = []
    if all(os.path.exists(x) for x in dec_pcs_color):
        print_progress(dec_pcs, dec_pcs_color, '(exists)')
    else:
        for dp, dpc in zip(dec_pcs, dec_pcs_color):
            print_progress(dp, dpc)
            mc_params.append((input_pc, dp, dpc))
    parallel_process(run_mapcolor, mc_params, num_parallel)

    pcerror_cfg_params = [[f'--{k}', str(v)] for k, v in pcerror_cfg.items()]
    pcerror_cfg_params = flatten(pcerror_cfg_params)
    params = []
    for pcerror_result, decoded_pc in zip(pcerror_results, dec_pcs):
        if os.path.exists(pcerror_result):
            print_progress(decoded_pc, pcerror_result, '(exists)')
        else:
            print_progress(decoded_pc, pcerror_result)
            params.append((decoded_pc, input_norm, input_pc,
                           pcerror_cfg_params, pcerror_path, pcerror_result))
    parallel_process(run_pcerror, params, num_parallel)

    for pcerror_result, enc_pc, decoded_pc, experiment_report in zip(
            pcerror_results, enc_pcs, dec_pcs, exp_reports):
        if os.path.exists(experiment_report):
            print_progress('all', experiment_report, '(exists)')
        else:
            print_progress('all', experiment_report)
            pcerror_data = mpeg_parsing.parse_pcerror(pcerror_result)

            pos_total_size_in_bytes = os.stat(enc_pc).st_size
            input_point_count = len(PyntCloud.from_file(input_pc).points)
            data = {
                'pos_total_size_in_bytes': pos_total_size_in_bytes,
                'pos_bits_per_input_point':
                pos_total_size_in_bytes * 8 / input_point_count,
                'input_point_count': input_point_count
            }
            data = {**data, **pcerror_data}
            with open(experiment_report, 'w') as f:
                json.dump(data, f, sort_keys=True, indent=4)

            # Debug
            with open(enc_pc + '.enc.metric.json', 'r') as f:
                enc_metrics = json.load(f)
            diff = abs(enc_metrics['d1_psnr'] - data['d1_psnr'])
            logger.info(f'D1 PSNR diff between encoder and decoder: {diff}')
            assert diff < 0.01, f'encoded {enc_pc} with D1 {enc_metrics["d1_psnr"]} but decoded {decoded_pc} with D1 {data["d1_psnr"]}dB'

    logger.info('Done')
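
Finally, a hypothetical call to run_experiment; every path and configuration value below is a placeholder rather than something taken from the repository.

import numpy as np

if __name__ == '__main__':
    run_experiment(
        output_dir='experiments/loot/c1',         # placeholder
        model_dir='models/c1',                    # placeholder
        model_config='c1',                        # placeholder; must be a ModelConfigType key
        pc_name='loot_vox10_1200',                # placeholder
        pcerror_path='bin/pc_error_d',            # placeholder path to the MPEG pc_error binary
        pcerror_cfg_path='cfg/pc_error.yaml',     # placeholder; must define at least 'resolution'
        input_pc='data/loot_vox10_1200.ply',      # placeholder
        input_norm='data/loot_vox10_1200_n.ply',  # placeholder
        opt_metrics=['d1_mse', 'd2_mse'],
        max_deltas=[np.inf],
        fixed_threshold=False,
        no_merge_coding=False,
        num_parallel=8)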