Example #1
def get_prior(prior_type):
  if prior_type == 'normal':
    # Standard normal prior.
    prior = flows.Normal()
  else:
    # Gaussian mixture prior with parameters looked up for the requested type.
    means, covariances, weights = get_gmm_params(prior_type)
    prior = flows.GMM(means, covariances, weights)
  return prior
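A minimal usage sketch, assuming the jax-flows init pattern seen in the other examples: the returned init_fun is called with a PRNG key and the input dimension, and yields parameters, a log-density function, and a sampler.

from jax import random

prior_init_fun = get_prior('normal')
params, log_pdf, sample = prior_init_fun(random.PRNGKey(0), 2)  # hypothetical 2-D prior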
Example #2
    def test_normal(self):
        # Draw random test points in [-3, 3)^2.
        inputs = random.uniform(random.PRNGKey(0), (20, 2),
                                minval=-3.0,
                                maxval=3.0)
        input_dim = inputs.shape[1]
        init_key, sample_key = random.split(random.PRNGKey(0))

        # Initialize the standard normal prior and evaluate its log-density.
        init_fun = flows.Normal()
        params, log_pdf, sample = init_fun(init_key, input_dim)
        log_pdfs = log_pdf(params, inputs)

        # Compare against the reference multivariate normal log-density.
        mean = np.zeros(input_dim)
        covariance = np.eye(input_dim)
        true_log_pdfs = multivariate_normal.logpdf(inputs, mean, covariance)

        self.assertTrue(np.allclose(log_pdfs, true_log_pdfs))

        for test in (returns_correct_shape, ):
            test(self, flows.Normal())
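Note that sample_key is split off above but never used; a small follow-on sketch that would also exercise the sampler, assuming jax-flows' sample(rng, params, num_samples) signature:

        # Assumed signature: sample(rng, params, num_samples).
        samples = sample(sample_key, params, 20)  # 20 draws from the prior
        self.assertEqual(samples.shape, (20, input_dim))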
Example #3
    def test_flow(self):
        for test in (returns_correct_shape, ):
            test(self, flows.Flow(flows.Reverse(), flows.Normal()))
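For context, a composed Flow is initialized and queried exactly like the bare prior in Example #2; a minimal sketch under that assumption:

from jax import random
import flows

inputs = random.uniform(random.PRNGKey(1), (20, 2))
init_fun = flows.Flow(flows.Reverse(), flows.Normal())
params, log_pdf, sample = init_fun(random.PRNGKey(0), 2)
log_probs = log_pdf(params, inputs)  # one log-density per row of inputs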
Example #4
def learn_flow(X,
               hidden_dim=48,
               num_hidden=2,
               num_unit=5,
               learning_rate=1e-3,
               num_epochs=200,
               batch_size=4000,
               interval=None,
               seed=123):
    """Training with 400K of riz data works well the default args.
    Make sure to remove samples with undetected flux since these otherwise create a delta function.
    """
    # Preprocess input data.
    scaler = preprocessing.StandardScaler().fit(X)
    X_preproc = scaler.transform(X)
    input_dim = X.shape[1]

    # Initialize random numbers.
    rng, flow_rng = jax.random.split(jax.random.PRNGKey(seed))

    # Initialize our flow bijection.
    transform = functools.partial(masked_transform,
                                  hidden_dim=hidden_dim,
                                  num_hidden=num_hidden)
    bijection_init_fun = flows.Serial(
        *(flows.MADE(transform), flows.Reverse()) * num_unit)

    # Create direct and inverse bijection functions.
    rng, bijection_rng = jax.random.split(rng)
    bijection_params, bijection_direct, bijection_inverse = bijection_init_fun(
        bijection_rng, input_dim)

    # Initialize our flow model.
    prior_init_fun = flows.Normal()
    flow_init_fun = flows.Flow(bijection_init_fun, prior_init_fun)
    initial_params, log_pdf, sample = flow_init_fun(flow_rng, input_dim)

    if interval is not None:
        import matplotlib.pyplot as plt
        bins = np.linspace(-0.05, 1.05, 111)

    def loss_fn(params, inputs):
        return -log_pdf(params, inputs).mean()

    @jax.jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        loss_value, gradients = jax.value_and_grad(loss_fn)(params, inputs)
        return opt_update(i, gradients, opt_state), loss_value

    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)
    opt_state = opt_init(initial_params)
    root2 = np.sqrt(2.)

    itercount = itertools.count()
    for epoch in range(num_epochs):
        rng, permute_rng = jax.random.split(rng)
        X_epoch = jax.random.permutation(permute_rng, X_preproc)
        for batch_index in range(0, len(X), batch_size):
            opt_state, loss = step(
                next(itercount), opt_state,
                X_epoch[batch_index:batch_index + batch_size])
        if interval is not None and (epoch + 1) % interval == 0:
            print(f'epoch {epoch + 1} loss {loss:.3f}')
            # Map the input data back through the flow to the prior space.
            epoch_params = get_params(opt_state)
            X_normal, _ = bijection_direct(epoch_params, X_epoch)
            X_uniform = 0.5 * (
                1 + scipy.special.erf(np.array(X_normal, np.float64) / root2))
            for i in range(input_dim):
                plt.hist(X_uniform[:, i], bins, histtype='step')
            plt.show()

    # Return a function that maps samples to a ~uniform distribution on [0,1] ** input_dim.
    # Takes a numpy array as input and returns a numpy array of the same shape.
    final_params = get_params(opt_state)

    def flow_map(Y):
        Y_preproc = scaler.transform(Y)
        Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
        #return np.array(Y_normal)
        Y_uniform = 0.5 * (
            1 + scipy.special.erf(np.array(Y_normal, np.float64) / root2))
        return Y_uniform.astype(np.float32)

    return flow_map
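A hedged end-to-end sketch of calling learn_flow; the input here is random placeholder data standing in for the riz fluxes mentioned in the docstring:

import numpy as np

X = np.random.default_rng(0).normal(size=(400_000, 3)).astype(np.float32)  # placeholder, not real riz data
flow_map = learn_flow(X, num_epochs=20)
U = flow_map(X[:1000])  # ~uniform on [0, 1] ** 3
assert U.shape == (1000, 3)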
Example #5
    num_blocks = int(config['num_blocks'])
    num_hidden = int(config['num_hidden'])
    private = str(config['private']).lower() == 'true'
    sampling = config['sampling'].lower()

    X, X_val, X_test = utils.get_datasets(dataset)
    num_samples, input_dim = X.shape
    delta = 1e-4 if dataset == 'lifesci' else 1 / num_samples

    shutil.copyfile(flow_path + 'flow_utils.py', 'analysis/flow_utils.py')
    from analysis import flow_utils

    modules = flow_utils.get_modules(flow, num_blocks, normalization,
                                     num_hidden)
    bijection = flows.Serial(*tuple(modules))
    prior = flows.Normal()
    init_fun = flows.Flow(bijection, prior)
    temp_key, key = random.split(key)
    _, log_pdf, sample = init_fun(temp_key, input_dim)

    iterations = sorted([
        int(d) for d in os.listdir(flow_path) if os.path.isdir(flow_path + d)
    ])

    print('δ = {}'.format(delta))
    for composition in ['gdp', 'ma']:
        print('Composing in {}...'.format(composition))
        for iteration in iterations:
            epsilon = utils.get_epsilon(
                private,
                composition,