コード例 #1
0
ファイル: config.py プロジェクト: b7leung/occ_uda
def get_data_fields(mode, cfg):
    ''' Assembles the dataset fields needed for the given mode.

    Args:
        mode (str): the mode which is used; the 'val' and 'test' modes
            additionally get the IoU-points and voxel fields when their
            files are configured
        cfg (dict): imported yaml config

    Returns:
        dict: field name -> field object. Always holds 'points'; for
        'val'/'test' it may also hold 'points_iou' and 'voxels'.
    '''
    data_cfg = cfg['data']
    subsample = data.SubsamplePoints(data_cfg['points_subsample'])
    use_camera = cfg['model']['use_camera']

    fields = {
        'points': data.PointsField(
            data_cfg['points_file'],
            subsample,
            with_transforms=use_camera,
            unpackbits=data_cfg['points_unpackbits'],
        ),
    }

    # Only the 'val' and 'test' modes get the extra fields below.
    if mode not in ('val', 'test'):
        return fields

    iou_path = data_cfg['points_iou_file']
    voxels_path = data_cfg['voxels_file']
    if iou_path is not None:
        # No subsampling transform is passed for the IoU points.
        fields['points_iou'] = data.PointsField(
            iou_path,
            with_transforms=use_camera,
            unpackbits=data_cfg['points_unpackbits'],
        )
    if voxels_path is not None:
        fields['voxels'] = data.VoxelsField(voxels_path)
    return fields
コード例 #2
0
def get_transforms(cfg):
    ''' Builds the point subsampling transforms for training and evaluation.

    Args:
        cfg (yaml): yaml config; reads ``cfg['data']['n_training_points']``
            and ``cfg['training']['n_eval_points']``

    Returns:
        tuple: ``(train_transform, eval_transform)`` where the evaluation
        transform is a ``SubsamplePointsSeq`` built with ``random=False``.
    '''
    n_train = cfg['data']['n_training_points']
    n_eval = cfg['training']['n_eval_points']
    return (
        data.SubsamplePoints(n_train),
        data.SubsamplePointsSeq(n_eval, random=False),
    )
コード例 #3
0
def get_transforms(cfg):
    ''' Builds the point and point-cloud subsampling transforms.

    Args:
        cfg (yaml config): yaml config object; reads
            ``cfg['data']['n_training_pcl_points']``,
            ``cfg['data']['n_training_points']`` and
            ``cfg['training']['n_eval_points']``

    Returns:
        tuple: ``(point_train, point_val, pcl_train, pcl_val)``. Both
        validation transforms use ``n_eval_points`` with ``random=False``;
        the training point-cloud transform uses ``connected_samples=True``.
    '''
    data_cfg = cfg['data']
    n_pcl = data_cfg['n_training_pcl_points']
    n_points = data_cfg['n_training_points']
    n_eval = cfg['training']['n_eval_points']

    point_train = data.SubsamplePoints(n_points)
    point_val = data.SubsamplePointsSeq(n_eval, random=False)
    pcl_val = data.SubsamplePointcloudSeq(n_eval, random=False)
    pcl_train = data.SubsamplePointcloudSeq(n_pcl, connected_samples=True)

    return point_train, point_val, pcl_train, pcl_val
コード例 #4
0
ファイル: config.py プロジェクト: ykkawana/neural_star_domain
def get_data_fields(mode, cfg):
    ''' Returns the data fields.

    Args:
        mode (str): the mode which is used; it is also forwarded to
            ``nsd_data.SphericalCoordinateField``
        cfg (dict): imported yaml config

    Returns:
        dict: field name -> field object. Always holds 'points',
        'pointcloud' and 'angles'. 'sdf_points' is added when
        ``cfg['trainer']['is_sdf']`` is truthy and ``cfg['sdf_generation']``
        is not; 'points_iou' / 'voxels' are added for 'val'/'test' when
        the corresponding files are configured.
    '''
    data_cfg = cfg['data']
    # Read the flag once. When it is set, no subsampling transforms are
    # passed to the points / point-cloud fields, so they are not built
    # (and their config keys are not required).
    sdf_generation = cfg.get('sdf_generation', False)
    with_transforms = cfg['model']['use_camera']

    points_transform = None
    if not sdf_generation:
        points_transform = data.SubsamplePoints(data_cfg['points_subsample'])

    fields = {}
    fields['points'] = data.PointsField(
        data_cfg['points_file'],
        points_transform,
        with_transforms=with_transforms,
        unpackbits=data_cfg['points_unpackbits'],
    )

    # cfg['trainer'] is only consulted when not generating SDFs
    # (short-circuit of ``and``).
    if not sdf_generation and cfg['trainer'].get('is_sdf', False):
        sdf_points_transform = data.SubsampleSDFPoints(
            data_cfg['points_subsample'])
        fields['sdf_points'] = data.SDFPointsField(
            data_cfg['sdf_points_file'],
            sdf_points_transform,
            with_transforms=with_transforms)

    pointcloud_transform = None
    if not sdf_generation:
        pointcloud_transform = data.SubsamplePointcloud(
            data_cfg['pointcloud_target_n'])

    # NOTE(review): the point cloud always uses with_transforms=True,
    # independent of cfg['model']['use_camera'] — confirm this is intended.
    fields['pointcloud'] = data.PointCloudField(data_cfg['pointcloud_file'],
                                                pointcloud_transform,
                                                with_transforms=True)
    fields['angles'] = nsd_data.SphericalCoordinateField(
        data_cfg['primitive_points_sample_n'],
        mode,
        is_normal_icosahedron=data_cfg.get('is_normal_icosahedron', False),
        is_normal_uv_sphere=data_cfg.get('is_normal_uv_sphere', False),
        icosahedron_subdiv=data_cfg.get('icosahedron_subdiv', 2),
        icosahedron_uv_margin=data_cfg.get('icosahedron_uv_margin', 1e-5),
        icosahedron_uv_margin_phi=data_cfg.get('icosahedron_uv_margin_phi',
                                               1e-5),
        uv_sphere_length=data_cfg.get('uv_sphere_length', 20),
        normal_mesh_no_invert=data_cfg.get('normal_mesh_no_invert', False))

    if mode in ('val', 'test'):
        points_iou_file = data_cfg['points_iou_file']
        voxels_file = data_cfg['voxels_file']
        if points_iou_file is not None:
            # No subsampling transform is passed for the IoU points.
            fields['points_iou'] = data.PointsField(
                points_iou_file,
                with_transforms=with_transforms,
                unpackbits=data_cfg['points_unpackbits'],
            )
        if voxels_file is not None:
            fields['voxels'] = data.VoxelsField(voxels_file)

    return fields
コード例 #5
0
def get_occ_data_fields(mode, cfg):
    ''' Returns the data fields.

    Points files are dispatched on their extension: ``.npz`` files are
    read with ``data.PointsField`` and ``.h5`` files with
    ``data.PointsH5Field``.

    Args:
        mode (str): the mode which is used; 'train' reads the optional
            ``cfg['data']['input_range']``, every other mode reads the
            optional ``cfg['data']['test_range']``
        cfg (dict): imported yaml config

    Returns:
        dict: field name -> field object. Always holds 'points'; for
        'val'/'test' it may also hold 'points_iou' and 'voxels' when the
        corresponding files are configured.

    Raises:
        NotImplementedError: if a points file ends in neither ``.npz``
            nor ``.h5``.
    '''
    data_cfg = cfg['data']
    n_subsample = data_cfg['points_subsample']
    with_transforms = cfg['model']['use_camera']

    # Pick the optional range key for this mode; it is echoed when present.
    if mode == 'train':
        range_key, range_label = 'input_range', 'Input range:'
    else:
        range_key, range_label = 'test_range', 'Test range:'
    if range_key in data_cfg:
        input_range = data_cfg[range_key]
        print(range_label, input_range)
    else:
        input_range = None

    fields = {}
    points_file = data_cfg['points_file']
    if points_file.endswith('.npz'):
        fields['points'] = data.PointsField(
            points_file,
            data.SubsamplePoints(n_subsample),
            with_transforms=with_transforms,
            unpackbits=data_cfg['points_unpackbits'],
            input_range=input_range,
        )
    elif points_file.endswith('.h5'):
        # The .h5 reader receives the subsample count directly instead of
        # a transform object.
        fields['points'] = data.PointsH5Field(
            points_file,
            subsample_n=n_subsample,
            with_transforms=with_transforms,
            input_range=input_range,
        )
    else:
        raise NotImplementedError(
            'Unsupported points file {!r}: expected a .npz or .h5 file'
            .format(points_file))

    if mode in ('val', 'test'):
        points_iou_file = data_cfg['points_iou_file']
        voxels_file = data_cfg['voxels_file']
        if points_iou_file is not None:
            # No transform / subsample_n is passed for the IoU points, so
            # the field classes' defaults apply.
            if points_iou_file.endswith('.npz'):
                fields['points_iou'] = data.PointsField(
                    points_iou_file,
                    with_transforms=with_transforms,
                    unpackbits=data_cfg['points_unpackbits'],
                    input_range=input_range,
                )
            elif points_iou_file.endswith('.h5'):
                fields['points_iou'] = data.PointsH5Field(
                    points_iou_file,
                    with_transforms=with_transforms,
                    input_range=input_range,
                )
            else:
                raise NotImplementedError(
                    'Unsupported IoU points file {!r}: expected a .npz or '
                    '.h5 file'.format(points_iou_file))

        if voxels_file is not None:
            fields['voxels'] = data.VoxelsField(voxels_file)

    return fields