예제 #1
0
파일: render.py 프로젝트: zenna/reverseflow
def main(argv):
    options = {
        'batch_size': 128,
        'max_time': 100.0,
        'logdir': '/home/zenna/repos/inverse/log',
        'template': template_dict,
        'nnet_enhanced_pi': False,
        'pointwise_pi': True,
        'min_fx_y': False,
        'nnet': True,
        'min_fx_param': False,
        'rightinv_pi_fx': False,
        'nruns': 2
    }
    min_param_size = 10
    param_types = {
        'theta':
        tensor_type(dtype=tf.float32,
                    shape=(options['batch_size'], min_param_size),
                    name="shrunk_param")
    }

    param_gen = {
        k: infinite_samples(np.random.rand, v['shape'])
        for k, v in param_types.items()
    }
    shrunk_param_gen = dictionary_gen(param_gen)
    return compare(render_gen_graph, render_fwd_f, param_types,
                   shrunk_param_gen, options)
예제 #2
0
def main(argv):
    global stats
    options = {
        'batch_size': 512,
        'max_time': 5.0,
        'logdir': '/home/zenna/repos/inverse/log',
        'template': template_dict,
        'nnet_enhanced_pi': True,
        'pointwise_pi': True,
        'min_fx_y': True,
        'nnet': True
    }
    gen_graph = rand_gen_graph
    fwd_f = rand_fwd_f
    min_param_size = 1
    param_types = {
        'theta':
        tensor_type(dtype=tf.float32,
                    shape=(options['batch_size'], min_param_size),
                    name="shrunk_param")
    }

    param_gen = {
        k: infinite_samples(np.random.rand, v['shape'])
        for k, v in param_types.items()
    }
    np.random.seed(0)
    shrunk_param_gen = dictionary_gen(param_gen)
    np.random.seed(0)
    stats = compare(gen_graph, rand_fwd_f, param_types, shrunk_param_gen,
                    options)
    return stats