예제 #1
0
    def testDefaultConfigFileParser_Basic(self):
        """Round-trip a single ``key: value`` line through the default parser.

        The parser must describe its syntax, parse ``a: 3`` into
        ``{'a': '3'}`` (the value stays a string), and serialize the result
        back in the ``key = value`` form.
        """
        parser = configargparse.DefaultConfigFileParser()
        self.assertGreater(len(parser.get_syntax_description()), 0)

        # simplest possible input: one colon-separated pair
        source = StringIO("a: 3\n")
        parsed = parser.parse(source)
        serialized = parser.serialize(parsed)

        # serialization normalizes ": " to " = "
        expected = source.getvalue().replace(": ", " = ")
        self.assertEqual(expected, serialized)

        self.assertDictEqual(parsed, {'a': '3'})
예제 #2
0
def load_env(env_file: str = '.env', sh_exp: bool = True):
    """
    Load a .env file and set its entries as environment variables.

    Variables already present in ``os.environ`` are never overwritten.

    :param env_file: env filename (default ".env")
    :param sh_exp: apply sh expansion (default True). When enabled, the file
        is *executed* with ``bash -x`` and the assignments are recovered from
        bash's trace output (stderr), so only use it on trusted files.
    :raises subprocess.CalledProcessError: if ``sh_exp`` is set and bash
        exits with a non-zero status (e.g. the file does not exist).
    :raises FileNotFoundError: if ``sh_exp`` is False and the file is missing.
    """
    import re  # local import: only needed to clean up the bash trace

    env_conf = configargparse.DefaultConfigFileParser()
    if sh_exp:
        sh_result = subprocess.run(["bash", "-x", env_file],
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE,
                                   encoding='utf-8',
                                   check=True)
        # Each traced command starts with one or more '+' and a space.
        # Strip only that leading marker so values which themselves contain
        # "+ " (e.g. A='1 + 2') are left intact.
        env_str = re.sub(r'^\++ ', '', sh_result.stderr, flags=re.MULTILINE)
        # Keep only shell-assignment tokens (NAME=value); bare words such as
        # "export" would otherwise be parsed as variables set to "true".
        assignment = re.compile(r'[A-Za-z_][A-Za-z0-9_]*=')
        tokens = [tok for tok in shlex.split(env_str) if assignment.match(tok)]
        sio = io.StringIO("\n".join(tokens))
        envs = env_conf.parse(sio)
    else:
        with open(env_file) as conf_file:
            envs = env_conf.parse(conf_file)
    envs.update(os.environ)  # don't overwrite original environment
    os.environ.update(envs)
예제 #3
0
    def testDefaultConfigFileParser_All(self):
        """Exercise every syntax element of the default config-file format.

        Comment lines (``#`` / ``;``), section headers and ``---`` separators
        must not produce entries; ``:`` / ``=`` assignments and bracketed
        lists must. Serializing the result must reproduce exactly the
        ``_``-prefixed assignment lines, normalized to ``key = value``.
        """
        parser = configargparse.DefaultConfigFileParser()

        # one document containing every supported construct
        lines = [
            "# comment1 ",
            "[ some section ]",
            "----",
            "---------",
            "_a: 3",
            "; comment2 ",
            "_b = c",
            "_list_arg1 = [a, b, c]",
            "_str_arg = true",
            "_list_arg2 = [1, 2, 3]",
        ]

        # parse
        parsed = parser.parse(StringIO("\n".join(lines) + "\n"))

        # serialize: only the assignment lines (all prefixed with '_') remain
        serialized = parser.serialize(parsed)
        kept = [line.replace(': ', ' = ')
                for line in lines if line.startswith('_')]
        self.assertEqual("\n".join(kept) + "\n", serialized)

        expected = {
            '_a': '3',
            '_b': 'c',
            '_list_arg1': ['a', 'b', 'c'],
            '_str_arg': 'true',
            '_list_arg2': ['1', '2', '3'],
        }
        self.assertDictEqual(parsed, expected)

        # bracketed values must come back as real lists of strings
        for key, items in (('_list_arg1', ['a', 'b', 'c']),
                           ('_list_arg2', ['1', '2', '3'])):
            self.assertListEqual(parsed[key], items)
예제 #4
0
def _as_bool(value):
    """Return True when a config-file value spells ``true`` in any letter case."""
    # Case-insensitive so that e.g. "True" (as produced by str(bool)) is not
    # silently read as False.
    return str(value).lower() == 'true'


def _build_test_dataset():
    """Create the test-split dataset selected by ``opt.dataset``.

    :returns: ``(dataset, use_ndc)``; ``use_ndc`` is True only for ``llff``.
    :raises NotImplementedError: for an unknown ``opt.dataset``.
    """
    if opt.dataset == 'deepvoxels':
        dataset = dataio.DeepVoxelDataset(opt.dv_dataset_path,
                                          mode='test',
                                          resize_to=2 * (opt.img_size, ))
        return dataset, False
    if opt.dataset == 'llff':
        dataset = dataio.LLFFDataset(opt.llff_dataset_path,
                                     mode='test',
                                     final_render=False)
        return dataset, True
    if opt.dataset == 'blender':
        dataset = dataio.NerfBlenderDataset(opt.nerf_dataset_path,
                                            splits=['test'],
                                            mode='test',
                                            select_idx=opt.render_select_idx,
                                            resize_to=2 * (opt.img_size, ))
        return dataset, False
    raise NotImplementedError('dataset not implemented')


def _load_radiance_net(out_features, ckpt_prefix, path, epoch, sampler):
    """Build a ``RadianceNet`` from ``opt`` and load its checkpoint for *epoch*.

    :param out_features: number of network outputs (1 for sigma, 3 for rgb)
    :param ckpt_prefix: checkpoint file prefix, e.g. ``'model_sigma'``
    :param path: run directory containing ``checkpoints/``
    :param epoch: epoch number used in the checkpoint file name
    :param sampler: optional ``SamplingNet`` (or None) passed to the model
    :returns: the model in eval mode, moved to CUDA
    """
    add_pe_ray_samples = 10  # 10 cos + sin
    add_pe_orientations = 4  # 4 cos + sin

    model = modules.RadianceNet(
        out_features=out_features,
        hidden_layers=opt.hidden_layers,
        hidden_features=opt.hidden_features,
        nl=opt.activation,
        use_grad=opt.use_grad,
        input_name=['ray_samples', 'ray_orientations'],
        input_processing_fn=modules.input_processing_fn,
        input_pe_params={
            'ray_samples': add_pe_ray_samples,
            'ray_orientations': add_pe_orientations
        },
        sampler=sampler,
        normalize_pe=opt.normalize_pe)

    ckpt_path = os.path.join(path, 'checkpoints',
                             f'{ckpt_prefix}_epoch_{epoch:04d}.pth')
    ckpt_dict = torch.load(ckpt_path)
    state_dict = translate_saved_weights(ckpt_dict, model)
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.cuda()
    return model


def render_model():
    """Restore a trained sigma/rgb model pair and render its test views.

    ``opt.render_model`` supplies ``(path, epoch)``. *path* is a run
    directory holding ``config.ini`` and ``checkpoints/``, and *epoch*
    selects the checkpoint files. Architecture options are read back from
    ``config.ini`` into ``opt`` before the models are built. Renders go to
    ``opt.render_output``, or to ``<path>/render`` when that is None. The
    function ends by calling ``sys.exit()``.

    :raises FileNotFoundError: if *path* or its ``config.ini`` is missing.
    :raises NotImplementedError: for an unsupported ``dataset`` in the config.
    """
    # get model paths
    path, epoch = opt.render_model
    epoch = int(epoch)
    config_path = os.path.join(path, 'config.ini')
    # Explicit checks instead of ``assert`` so they still run under ``-O``.
    if not os.path.isdir(path):
        raise FileNotFoundError(f'model directory not found: {path}')
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f'config file not found: {config_path}')

    # Restore training-time options; the parser yields scalar values as
    # strings, hence the explicit int/bool conversions.
    p = configargparse.DefaultConfigFileParser()
    with open(config_path) as f:
        args = p.parse(f)
    opt.hidden_layers = int(args['hidden_layers'])
    opt.hidden_features = int(args['hidden_features'])
    opt.use_piecewise_model = _as_bool(args['use_piecewise_model'])
    opt.use_grad = _as_bool(args['use_grad'])
    opt.activation = args['activation']
    opt.normalize_pe = _as_bool(args['normalize_pe'])
    opt.img_size = int(args['img_size'])
    opt.num_cuts = int(args['num_cuts'])
    opt.use_sampler = _as_bool(args['use_sampler'])
    opt.dataset = args['dataset']

    dataset, use_ndc = _build_test_dataset()

    if opt.use_sampler:
        cam_params = dataset.get_camera_params()
        sampler = modules.SamplingNet(Nt=opt.samples_per_ray,
                                      ncuts=opt.num_cuts,
                                      sampling_interval=(cam_params['near'],
                                                         cam_params['far']))
    else:
        sampler = None

    # Both networks share the same sampler instance, as in training.
    model_sigma = _load_radiance_net(1, 'model_sigma', path, epoch, sampler)
    model_rgb = _load_radiance_net(3, 'model_rgb', path, epoch, sampler)
    models = {'sigma': model_sigma, 'rgb': model_rgb}

    # set up dataset
    coords_dataset = dataio.Implicit6DMultiviewDataWrapper(
        dataset,
        dataset.get_img_shape(),
        dataset.get_camera_params(),
        samples_per_ray=opt.samples_per_ray,
        samples_per_view=np.prod(dataset.get_img_shape()[:2]),
        use_ndc=use_ndc)
    coords_dataset.toggle_logging_sampling()

    if opt.render_output is None:
        output_path = os.path.join(path, 'render')
    else:
        output_path = opt.render_output

    utils.render_views(output_path,
                       models,
                       coords_dataset,
                       use_piecewise_model=opt.use_piecewise_model,
                       num_cuts=opt.num_cuts,
                       use_sampler=opt.use_sampler,
                       integral_render=True,
                       chunk_size=opt.chunk_size_train,
                       video=False)
    sys.exit()
예제 #5
0
def read_config_file(file_path):
    """Parse *file_path* using configargparse's default config-file syntax.

    The parsed mapping is printed, then returned.
    """
    parser = configargparse.DefaultConfigFileParser()
    with open(file_path, 'r') as handle:
        parsed = parser.parse(handle)
    print(parsed)
    return parsed
예제 #6
0
args = parser.parse_args()

# Show every resolved value and where each one came from.
print(parser.format_values())

WINDOW_SIZE = args.window_size
EMBEDDING_SIZE = args.embedding_size
DENSE_SIZE = args.dense_size

if args.create_config:
    # Write every option whose value is not None (excluding the two
    # config-file flags themselves) to the requested file, then exit
    # before the TensorFlow imports below are reached.
    file_name = args.create_config
    options = {
        attr: value
        for attr, value in vars(args).items()
        if attr not in ("config", "create_config") and value is not None
    }
    content = configargparse.DefaultConfigFileParser().serialize(options)
    Path(file_name).write_text(content)
    print(f"configuration saved to file: {file_name}")
    sys.exit(0)

import tensorflow as tf
import tensorflow.keras as K
import tensorflow_datasets as tfds


# Implement simple SpaceTokenizer - built-in tokenizer in tf filter-out
# non alphanumeric tokens
# see https://www.tensorflow.org/datasets/api_docs/python/tfds/features/text/TokenTextEncoder
class SpaceTokenizer(object):
    def tokenize(self, s):
        toks = []