Example #1
def save_config(path, config):
    """
    Save config dictionary as a YAML file
    """
    out_dir = os.path.dirname(path)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    if os.path.isfile(path):
        logger_py.warning(
            "Found existing file {}, overwriting it.".format(path))

    with open(path, 'w') as f:
        yaml.dump(config, f, sort_keys=False)

    logger_py.info("Saved config to {}".format(path))
Example #2
def create_animation(pts_dir, show_max=-1):
    figs = []
    # points
    pts_files = [
        f for f in os.listdir(pts_dir)
        if 'pts' in f and f[-4:].lower() in ('.ply', '.obj')
    ]
    if len(pts_files) == 0:
        logger_py.info("Couldn't find '*pts*' files in {}".format(pts_dir))
    else:
        pts_files.sort()
        if show_max > 0:
            pts_files = pts_files[::max(len(pts_files) // show_max, 1)]
        pts_names = list(
            map(lambda x: os.path.basename(x)[:-4].split('_')[0], pts_files))
        pts_paths = [os.path.join(pts_dir, fname) for fname in pts_files]
        fig = animate_points(pts_paths, pts_names)
        figs.append(fig)
    # mesh
    mesh_files = [
        f for f in os.listdir(pts_dir)
        if 'mesh' in f and f[-4:].lower() in ('.ply', '.obj')
    ]
    # mesh_files = list(filter(lambda x: x.split('_')
    #                          [1] == '000.obj', mesh_files))
    if len(mesh_files) == 0:
        logger_py.info("Couldn't find '*mesh*' files in {}".format(pts_dir))
    else:
        mesh_files.sort()
        if show_max > 0:
            mesh_files = mesh_files[::max(len(mesh_files) // show_max, 1)]
        mesh_names = list(
            map(lambda x: os.path.basename(x)[:-4].split('_')[0], mesh_files))
        mesh_paths = [os.path.join(pts_dir, fname) for fname in mesh_files]
        fig = animate_mesh(mesh_paths, mesh_names)
        figs.append(fig)

    save_html = os.path.join(pts_dir, 'animation.html')
    os.makedirs(os.path.dirname(save_html), exist_ok=True)
    figures_to_html(figs, save_html)
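
A hedged usage sketch for create_animation: assuming a results directory that contains files such as pts_000.ply and mesh_000.ply (directory and file names are illustrative), the call below writes animation.html into that directory, keeping at most 10 frames per sequence.

# hypothetical call: build point/mesh animations from a results directory
create_animation(os.path.join('out', 'example_run', 'vis'), show_max=10)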
Example #3
def create_model(cfg, device, mode="train", camera_model=None, **kwargs):
    ''' Returns model

    Args:
        cfg (edict): imported yaml config
        device (device): pytorch device
    '''
    # this variant builds only the point-based model, which has no implicit
    # decoder; keep a placeholder so the texture branch below can reference it
    decoder = None

    texture = None
    use_lighting = (cfg.renderer is not None
                    and not cfg.renderer.get('is_neural_texture', True))
    if use_lighting:
        texture = LightingTexture()
    else:
        if 'rgb' not in cfg.model.decoder_kwargs.out_dims:
            Texture = get_class_from_string(cfg.model.texture_type)
            cfg.model.texture_kwargs[
                'c_dim'] = cfg.model.decoder_kwargs.out_dims.get('latent', 0)
            texture_decoder = Texture(**cfg.model.texture_kwargs)
        else:
            texture_decoder = decoder
            logger_py.info("Decoder used as NeuralTexture")

        texture = NeuralTexture(
            view_dependent=cfg.model.texture_kwargs.view_dependent,
            decoder=texture_decoder).to(device=device)
        logger_py.info("Created NeuralTexture {}".format(texture.__class__))
        logger_py.info(texture)

    Model = get_class_from_string("DSS.models.{}_modeling.Model".format(
        cfg.model.type))

    # if not using decoder, then use non-parameterized point renderer
    # create icosphere as initial point cloud
    sphere_mesh = ico_sphere(level=4)
    sphere_mesh.scale_verts_(0.5)
    points, normals = sample_points_from_meshes(
        sphere_mesh,
        num_samples=int(cfg['model']['model_kwargs']['n_points_per_cloud']),
        return_normals=True)
    colors = torch.ones_like(points)
    renderer = create_renderer(cfg.renderer).to(device)
    model = Model(
        points,
        normals,
        colors,
        renderer,
        device=device,
        texture=texture,
        **cfg.model.model_kwargs,
    ).to(device=device)

    return model
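
A rough usage sketch for this point-only variant of create_model, assuming cfg is an attribute-style dict (edict) loaded from the project's YAML config; the config path and the config.load_config helper are assumptions, not part of the snippet above.

# hypothetical: load the YAML config as an edict and build a point-based model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg = config.load_config('configs/example.yaml')  # assumed loader returning an edict
model = create_model(cfg, device, mode='train')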
Example #4
def create_model(cfg, device, mode="train", camera_model=None, **kwargs):
    ''' Returns model

    Args:
        cfg (edict): imported yaml config
        device (device): pytorch device
    '''
    decoder = cfg['model']['decoder']
    encoder = cfg['model']['encoder']

    if mode == 'test' and cfg.model.type == 'combined':
        cfg.model.type = 'implicit'

    if cfg.model.type == 'point':
        decoder = None

    if decoder is not None:
        c_dim = cfg['model']['c_dim']
        Decoder = get_class_from_string(cfg.model.decoder_type)
        decoder = Decoder(c_dim=c_dim, dim=3,
                          **cfg.model.decoder_kwargs).to(device=device)
        logger_py.info("Created Decoder {}".format(decoder.__class__))
        logger_py.info(decoder)
        # initialize siren model to be a sphere
        if cfg.model.decoder_type == 'DSS.models.common.Siren':
            decoder_kwargs = cfg.model.decoder_kwargs
            if cfg.training.init_siren_from_sphere:
                try:
                    pretrained_model_file = os.path.join(
                        'data', 'trained_model', 'siren_l{}_c{}_o{}.pt'.format(
                            decoder_kwargs.n_layers,
                            decoder_kwargs.hidden_size,
                            decoder_kwargs.first_omega_0))
                    loaded_state_dict = torch.load(pretrained_model_file)
                    decoder.load_state_dict(loaded_state_dict)
                    logger_py.info('initialized Siren decoder with {}'.format(
                        pretrained_model_file))
                except Exception:
                    # fall back to the default Siren initialization
                    pass

    texture = None
    use_lighting = (cfg.renderer is not None
                    and not cfg.renderer.get('is_neural_texture', True))
    if use_lighting:
        texture = LightingTexture()
    else:
        if 'rgb' not in cfg.model.decoder_kwargs.out_dims:
            Texture = get_class_from_string(cfg.model.texture_type)
            cfg.model.texture_kwargs[
                'c_dim'] = cfg.model.decoder_kwargs.out_dims.get('latent', 0)
            texture_decoder = Texture(**cfg.model.texture_kwargs)
        else:
            texture_decoder = decoder
            logger_py.info("Decoder used as NeuralTexture")

        texture = NeuralTexture(
            view_dependent=cfg.model.texture_kwargs.view_dependent,
            decoder=texture_decoder).to(device=device)
        logger_py.info("Created NeuralTexture {}".format(texture.__class__))
        logger_py.info(texture)

    Model = get_class_from_string("DSS.models.{}_modeling.Model".format(
        cfg.model.type))

    if cfg.model.type == 'implicit':
        model = Model(decoder,
                      renderer=None,
                      texture=texture,
                      encoder=encoder,
                      cameras=camera_model,
                      device=device,
                      **cfg.model.model_kwargs)

    elif cfg.model.type == 'combined':
        renderer = create_renderer(cfg.renderer).to(device)
        # TODO: load
        points = None
        point_file = os.path.join(cfg.training.out_dir, cfg.name,
                                  cfg.training.point_file)
        if os.path.isfile(point_file):
            # load point or mesh then sample
            loaded_shape = trimesh.load(point_file)
            if isinstance(loaded_shape, trimesh.PointCloud):
                # override n_points_per_cloud with the loaded point count
                cfg.model.model_kwargs.n_points_per_cloud = loaded_shape.vertices.shape[0]
                points = loaded_shape.vertices
            else:
                n_points = cfg.model.model_kwargs['n_points_per_cloud']
                try:
                    # rejection sampling can return fewer points, so sample extra
                    points = trimesh.sample.sample_surface_even(
                        loaded_shape, int(n_points * 1.1), radius=0.01)[0]
                    # sub-sample the returned surface points down to n_points
                    p_idx = np.random.permutation(points.shape[0])[:n_points]
                    points = points[p_idx, ...]
                except Exception:
                    # fall back to randomly sub-sampling the mesh vertices
                    p_idx = np.random.permutation(
                        loaded_shape.vertices.shape[0])[:n_points]
                    points = loaded_shape.vertices[p_idx, ...]

            points = torch.tensor(points, dtype=torch.float, device=device)

        model = Model(decoder,
                      renderer,
                      texture=texture,
                      encoder=encoder,
                      cameras=camera_model,
                      device=device,
                      points=points,
                      **cfg.model.model_kwargs)

    else:
        raise ValueError('model type must be combined|point|implicit|occupancy')

    return model
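
Both create_model variants resolve classes such as the decoder, texture, and Model types from dotted paths in the config via get_class_from_string. A minimal sketch of what such a helper typically does, assuming it simply resolves the dotted path with importlib (the actual DSS implementation may differ):

import importlib

def get_class_from_string_sketch(dotted_path):
    # e.g. "DSS.models.point_modeling.Model" -> the Model class object
    module_name, _, class_name = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_name), class_name)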
Example #5
epoch_it = load_dict.get('epoch_it', -1)
it = load_dict.get('it', -1)

# Save config to log directory
config.save_config(os.path.join(out_dir, 'config.yaml'), cfg)

# Update Metrics from loaded
model_selection_metric = cfg['training']['model_selection_metric']
metric_val_best = load_dict.get(
    'loss_val_best', -model_selection_sign * np.inf)

if metric_val_best == np.inf or metric_val_best == -np.inf:
    metric_val_best = -model_selection_sign * np.inf

logger_py.info('Current best validation metric (%s): %.8f'
               % (model_selection_metric, metric_val_best))

# Shorthands
print_every = cfg['training']['print_every']
checkpoint_every = cfg['training']['checkpoint_every']
validate_every = cfg['training']['validate_every']
visualize_every = cfg['training']['visualize_every']
debug_every = cfg['training']['debug_every']
reweight_every = cfg['training']['reweight_every']

scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, cfg['training']['scheduler_milestones'],
    gamma=cfg['training']['scheduler_gamma'], last_epoch=epoch_it)

# Set mesh extraction to low resolution for fast visualization
# during training
model = config.create_model(
    cfg, mode='test', device=device,
    camera_model=dataset.get_cameras()).to(device=device)

checkpoint_io = CheckpointIO(out_dir, model=model)
checkpoint_io.load(cfg['test']['model_file'])

# Generator
generator = config.create_generator(cfg, model, device=device)

torch.manual_seed(0)

# Generate
with torch.autograd.no_grad():
    model.eval()
    # Generate meshes
    if not args.render_only:
        logger_py.info('Generating mesh...')
        mesh = get_surface_high_res_mesh(
            lambda x: model.decode(x).sdf.squeeze(), resolution=args.resolution)
        if cfg.data.type == 'DTU':
            mesh.apply_transform(dataset.get_scale_mat())
        mesh_out_file = os.path.join(generation_dir, 'mesh.%s' % mesh_extension)
        mesh.export(mesh_out_file)

    # Generate cuts
    logger_py.info('Generating cross section plots')
    img = generator.generate_iso_contour(imgs_per_cut=5)
    out_file = os.path.join(generation_dir, 'iso')
    img.write_html(out_file + '.html')

    if not args.mesh_only:
        # generate images
        for i, batch in enumerate(test_loader):