Example #1
File: config.py Project: yifita/DSS
def create_model(cfg, device, mode="train", camera_model=None, **kwargs):
    ''' Returns the model object.

    Args:
        cfg (edict): imported yaml config
        device (device): pytorch device
    '''
    # this variant only builds the point-based model, so no decoder is used
    decoder = None

    texture = None
    use_lighting = (cfg.renderer is not None
                    and not cfg.renderer.get('is_neural_texture', True))
    if use_lighting:
        texture = LightingTexture()
    else:
        if 'rgb' not in cfg.model.decoder_kwargs.out_dims:
            Texture = get_class_from_string(cfg.model.texture_type)
            cfg.model.texture_kwargs['c_dim'] = \
                cfg.model.decoder_kwargs.out_dims.get('latent', 0)
            texture_decoder = Texture(**cfg.model.texture_kwargs)
        else:
            texture_decoder = decoder
            logger_py.info("Decoder used as NeuralTexture")

        texture = NeuralTexture(
            view_dependent=cfg.model.texture_kwargs.view_dependent,
            decoder=texture_decoder).to(device=device)
        logger_py.info("Created NeuralTexture {}".format(texture.__class__))
        logger_py.info(texture)

    Model = get_class_from_string("DSS.models.{}_modeling.Model".format(
        cfg.model.type))

    # without a decoder, fall back to the non-parameterized point renderer;
    # create an icosphere as the initial point cloud
    sphere_mesh = ico_sphere(level=4)
    sphere_mesh.scale_verts_(0.5)
    points, normals = sample_points_from_meshes(
        sphere_mesh,
        num_samples=int(cfg['model']['model_kwargs']['n_points_per_cloud']),
        return_normals=True)
    colors = torch.ones_like(points)
    renderer = create_renderer(cfg.renderer).to(device)
    model = Model(
        points,
        normals,
        colors,
        renderer,
        device=device,
        texture=texture,
        **cfg.model.model_kwargs,
    ).to(device=device)

    return model
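
For reference, a minimal invocation might look like the sketch below. The config keys mirror the lookups in create_model above, but the concrete values are illustrative assumptions, the class paths under 'renderer' are the ones hard-coded in Example 3, and the model class path 'DSS.models.point_modeling.Model' is assumed to exist as the type string implies.

# Hypothetical usage sketch; values are illustrative, not project defaults.
from easydict import EasyDict as edict
import torch

cfg = edict({
    'model': {
        'type': 'point',
        'model_kwargs': {'n_points_per_cloud': 3000},
        'decoder_kwargs': {'out_dims': {}},
        'texture_kwargs': {},
    },
    'renderer': {
        'is_neural_texture': False,   # selects the LightingTexture branch
        'renderer_type': 'DSS.core.renderer.SurfaceSplattingRenderer',
        'raster_type': 'DSS.core.rasterizer.SurfaceSplatting',
        'compositor_type': 'pytorch3d.renderer.NormWeightedCompositor',
        'raster_params': {'image_size': 512, 'points_per_pixel': 5},
    },
})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model(cfg, device)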
Example #2
def create_renderer(render_opt):
    """ Create rendere """
    Renderer = get_class_from_string(render_opt.renderer_type)
    Raster = get_class_from_string(render_opt.raster_type)
    i = render_opt.raster_type.rfind('.')
    raster_setting_type = render_opt.raster_type[:i] + \
        '.PointsRasterizationSettings'
    if render_opt.compositor_type is not None:
        Compositor = get_class_from_string(render_opt.compositor_type)
        compositor = Compositor()
    else:
        compositor = None

    RasterSetting = get_class_from_string(raster_setting_type)
    raster_settings = RasterSetting(**render_opt.raster_params)

    renderer = Renderer(
        rasterizer=Raster(cameras=None, raster_settings=raster_settings),
        compositor=compositor,
    )
    return renderer
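
create_renderer only reads four fields from render_opt, and the rasterization-settings class is looked up next to the rasterizer class (here that resolves to DSS.core.rasterizer.PointsRasterizationSettings). A minimal option object, reusing the class paths that Example 3 below hard-codes, could look like this sketch:

from easydict import EasyDict as edict

# render_opt needs exactly the attributes create_renderer reads above;
# class paths are the ones hard-coded in Example 3, params illustrative.
render_opt = edict({
    'renderer_type': 'DSS.core.renderer.SurfaceSplattingRenderer',
    'raster_type': 'DSS.core.rasterizer.SurfaceSplatting',
    'compositor_type': 'pytorch3d.renderer.NormWeightedCompositor',
    'raster_params': {'image_size': 512, 'points_per_pixel': 5},
})
renderer = create_renderer(render_opt)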
Example #3
def create_splatting_renderer():
    Renderer = get_class_from_string(
        'DSS.core.renderer.SurfaceSplattingRenderer')
    Raster = get_class_from_string('DSS.core.rasterizer.SurfaceSplatting')
    # the compositor class is hard-coded in this variant, so create it directly
    Compositor = get_class_from_string(
        'pytorch3d.renderer.NormWeightedCompositor')
    compositor = Compositor()

    raster_params = {
        'backface_culling': False,
        'Vrk_invariant': True,
        'Vrk_isotropic': False,
        'bin_size': None,
        'clip_pts_grad': 0.05,
        'cutoff_threshold': 1.0,
        'depth_merging_threshold': 0.05,
        'image_size': 512,
        'max_points_per_bin': None,
        'points_per_pixel': 5,
        'radii_backward_scaler': 5,
    }

    RasterSetting = get_class_from_string(
        'DSS.core.rasterizer.PointsRasterizationSettings')
    raster_settings = RasterSetting(**raster_params)

    renderer = Renderer(
        rasterizer=Raster(cameras=FoVPerspectiveCameras(),
                          raster_settings=raster_settings),
        compositor=compositor,
    )
    return renderer
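
All of these factories rely on get_class_from_string to resolve dotted class paths at runtime. The DSS project ships its own helper; a typical implementation of such a dotted-path importer is sketched below (an assumption about its shape, not the project's exact code):

import importlib

def get_class_from_string(dotted_path):
    # Resolve 'pkg.module.ClassName' to the class object.
    # Sketch of a dotted-path importer; DSS's own helper may differ.
    module_path, _, class_name = dotted_path.rpartition('.')
    module = importlib.import_module(module_path)
    return getattr(module, class_name)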
Example #4
File: config.py Project: yifita/DSS
def create_generator(cfg, model, device, **kwargs):
    ''' Returns the generator object.

    Args:
        model (nn.Module): model
        cfg (dict): imported yaml config
        device (device): pytorch device
    '''
    Generator = get_class_from_string(
        'DSS.models.{}_modeling.Generator'.format(cfg.model.type))

    generator = Generator(model,
                          device,
                          threshold=cfg['test']['threshold'],
                          **cfg.generation)
    return generator
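
Given a model and device as in Example 1, only cfg.model.type, cfg['test']['threshold'], and the cfg.generation kwargs are read here, so a minimal config could be as sketched below (the threshold value is an illustrative assumption):

from easydict import EasyDict as edict

# Hypothetical minimal config for create_generator above.
cfg = edict({
    'model': {'type': 'point'},
    'test': {'threshold': 0.5},
    'generation': {},   # extra kwargs forwarded to the Generator
})
generator = create_generator(cfg, model, device)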
Example #5
def create_model(cfg, device, mode="train", camera_model=None, **kwargs):
    ''' Returns the model object.

    Args:
        cfg (edict): imported yaml config
        device (device): pytorch device
    '''
    decoder = cfg['model']['decoder']
    encoder = cfg['model']['encoder']

    if mode == 'test' and cfg.model.type == 'combined':
        cfg.model.type = 'implicit'

    if cfg.model.type == 'point':
        decoder = None

    if decoder is not None:
        c_dim = cfg['model']['c_dim']
        Decoder = get_class_from_string(cfg.model.decoder_type)
        decoder = Decoder(c_dim=c_dim, dim=3,
                          **cfg.model.decoder_kwargs).to(device=device)
        logger_py.info("Created Decoder {}".format(decoder.__class__))
        logger_py.info(decoder)
        # initialize siren model to be a sphere
        if cfg.model.decoder_type == 'DSS.models.common.Siren':
            decoder_kwargs = cfg.model.decoder_kwargs
            if cfg.training.init_siren_from_sphere:
                try:
                    pretrained_model_file = os.path.join(
                        'data', 'trained_model', 'siren_l{}_c{}_o{}.pt'.format(
                            decoder_kwargs.n_layers,
                            decoder_kwargs.hidden_size,
                            decoder_kwargs.first_omega_0))
                    loaded_state_dict = torch.load(pretrained_model_file)
                    decoder.load_state_dict(loaded_state_dict)
                    logger_py.info('initialized Siren decoder with {}'.format(
                        pretrained_model_file))
                except Exception:
                    # fall back to the default initialization if the
                    # pretrained weights are unavailable
                    pass

    texture = None
    use_lighting = (cfg.renderer is not None
                    and not cfg.renderer.get('is_neural_texture', True))
    if use_lighting:
        texture = LightingTexture()
    else:
        if 'rgb' not in cfg.model.decoder_kwargs.out_dims:
            Texture = get_class_from_string(cfg.model.texture_type)
            cfg.model.texture_kwargs['c_dim'] = \
                cfg.model.decoder_kwargs.out_dims.get('latent', 0)
            texture_decoder = Texture(**cfg.model.texture_kwargs)
        else:
            texture_decoder = decoder
            logger_py.info("Decoder used as NeuralTexture")

        texture = NeuralTexture(
            view_dependent=cfg.model.texture_kwargs.view_dependent,
            decoder=texture_decoder).to(device=device)
        logger_py.info("Created NeuralTexture {}".format(texture.__class__))
        logger_py.info(texture)

    Model = get_class_from_string("DSS.models.{}_modeling.Model".format(
        cfg.model.type))

    if cfg.model.type == 'implicit':
        model = Model(decoder,
                      renderer=None,
                      texture=texture,
                      encoder=encoder,
                      cameras=camera_model,
                      device=device,
                      **cfg.model.model_kwargs)

    elif cfg.model.type == 'combined':
        renderer = create_renderer(cfg.renderer).to(device)
        # TODO: load
        points = None
        point_file = os.path.join(cfg.training.out_dir, cfg.name,
                                  cfg.training.point_file)
        if os.path.isfile(point_file):
            # load point or mesh then sample
            loaded_shape = trimesh.load(point_file)
            if isinstance(loaded_shape, trimesh.PointCloud):
                # override n_points_per_cloud with the loaded point count
                cfg.model.model_kwargs.n_points_per_cloud = \
                    loaded_shape.vertices.shape[0]
                points = loaded_shape.vertices
            else:
                n_points = cfg.model.model_kwargs['n_points_per_cloud']
                try:
                    # rejection sampling can return fewer samples than
                    # requested, hence sample ~10% more points
                    points = trimesh.sample.sample_surface_even(
                        loaded_shape, int(n_points * 1.1), radius=0.01)[0]
                    # subsample the surface samples down to n_points
                    p_idx = np.random.permutation(points.shape[0])[:n_points]
                    points = points[p_idx, ...]
                except Exception:
                    # fall back to randomly subsampling the mesh vertices
                    p_idx = np.random.permutation(
                        loaded_shape.vertices.shape[0])[:n_points]
                    points = loaded_shape.vertices[p_idx, ...]

            points = torch.tensor(points, dtype=torch.float, device=device)

        model = Model(decoder,
                      renderer,
                      texture=texture,
                      encoder=encoder,
                      cameras=camera_model,
                      device=device,
                      points=points,
                      **cfg.model.model_kwargs)

    else:
        raise ValueError('model type must be combined|point|implicit|occupancy')

    return model
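
The point-loading logic in the 'combined' branch is self-contained enough to factor out. A standalone sketch with the same fallback order (the function name is hypothetical, error handling simplified):

import numpy as np
import trimesh

def load_initial_points(point_file, n_points):
    # Standalone sketch of the point-loading fallback in the 'combined'
    # branch above: use a stored point cloud directly, otherwise sample
    # the mesh surface, otherwise subsample the mesh vertices.
    loaded_shape = trimesh.load(point_file)
    if isinstance(loaded_shape, trimesh.PointCloud):
        return np.asarray(loaded_shape.vertices)
    try:
        # rejection sampling can return fewer samples, so ask for ~10% more
        points = trimesh.sample.sample_surface_even(
            loaded_shape, int(n_points * 1.1), radius=0.01)[0]
        p_idx = np.random.permutation(points.shape[0])[:n_points]
        return points[p_idx]
    except Exception:
        p_idx = np.random.permutation(
            loaded_shape.vertices.shape[0])[:n_points]
        return loaded_shape.vertices[p_idx]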