def get_prior(prior_type):
    """Build a prior init function for a normalizing flow.

    Args:
        prior_type: 'normal' selects a standard normal prior; any other
            value selects a Gaussian-mixture prior whose parameters come
            from get_gmm_params.

    Returns:
        A flows prior init function (flows.Normal() or flows.GMM(...)).
    """
    if prior_type == 'normal':
        prior = flows.Normal()
    else:
        # Bug fix: the original called get_gmm_params(prior), reading the
        # local `prior` before it was ever assigned (UnboundLocalError on
        # this branch). Passing prior_type appears to be the intent —
        # NOTE(review): confirm get_gmm_params' expected argument.
        means, covariances, weights = get_gmm_params(prior_type)
        prior = flows.GMM(means, covariances, weights)
    return prior
def test_normal(self):
    """Check flows.Normal log-densities against scipy's reference Gaussian."""
    inputs = random.uniform(
        random.PRNGKey(0), (20, 2), minval=-3.0, maxval=3.0)
    input_dim = inputs.shape[1]
    init_key, sample_key = random.split(random.PRNGKey(0))
    params, log_pdf, sample = flows.Normal()(init_key, input_dim)
    # Reference: zero-mean, identity-covariance multivariate normal.
    expected = multivariate_normal.logpdf(
        inputs, np.zeros(input_dim), np.eye(input_dim))
    self.assertTrue(np.allclose(log_pdf(params, inputs), expected))
    for check in (returns_correct_shape, ):
        check(self, flows.Normal())
def test_flow(self):
    """Smoke-test a composed flow: Reverse bijection over a Normal prior."""
    flow = flows.Flow(flows.Reverse(), flows.Normal())
    for check in (returns_correct_shape, ):
        check(self, flow)
def learn_flow(X, hidden_dim=48, num_hidden=2, num_unit=5, learning_rate=1e-3,
               num_epochs=200, batch_size=4000, interval=None, seed=123):
    """Fit a MADE-based normalizing flow to X and return a map to ~U[0,1]^d.

    Training with 400K of riz data works well with the default args. Make
    sure to remove samples with undetected flux since these otherwise
    create a delta function.

    Args:
        X: 2D array of samples, shape (num_samples, input_dim).
        hidden_dim: width of each hidden layer in the MADE transform.
        num_hidden: number of hidden layers in the MADE transform.
        num_unit: number of (MADE, Reverse) pairs composed in the bijection.
        learning_rate: Adam step size.
        num_epochs: number of passes over the (permuted) data.
        batch_size: minibatch size for each optimizer step.
        interval: if not None, every `interval` epochs print the loss and
            plot per-dimension histograms of the data mapped to [0,1]
            (requires matplotlib).
        seed: PRNG seed for initialization and per-epoch permutations.

    Returns:
        flow_map: function mapping a numpy array of samples (same column
            layout as X) to an array of the same shape, approximately
            uniform on [0,1] ** input_dim, as float32.
    """
    # Preprocess input data: standardize each feature to zero mean, unit
    # variance (the scaler is captured by flow_map below).
    scaler = preprocessing.StandardScaler().fit(X)
    X_preproc = scaler.transform(X)
    input_dim = X.shape[1]
    # Initialize random numbers.
    rng, flow_rng = jax.random.split(jax.random.PRNGKey(seed))
    # Initialize our flow bijection: num_unit alternating MADE / Reverse
    # layers so successive MADEs see a different variable ordering.
    transform = functools.partial(masked_transform, hidden_dim=hidden_dim,
                                  num_hidden=num_hidden)
    bijection_init_fun = flows.Serial(
        *(flows.MADE(transform), flows.Reverse()) * num_unit)
    # Create direct and inverse bijection functions.
    rng, bijection_rng = jax.random.split(rng)
    bijection_params, bijection_direct, bijection_inverse = bijection_init_fun(
        bijection_rng, input_dim)
    # Initialize our flow model (bijection composed with a Normal prior).
    prior_init_fun = flows.Normal()
    flow_init_fun = flows.Flow(bijection_init_fun, prior_init_fun)
    initial_params, log_pdf, sample = flow_init_fun(flow_rng, input_dim)
    if interval is not None:
        # Plotting is optional; only import matplotlib when monitoring.
        import matplotlib.pyplot as plt
        bins = np.linspace(-0.05, 1.05, 111)

    def loss_fn(params, inputs):
        # Negative mean log-likelihood of the batch under the flow.
        return -log_pdf(params, inputs).mean()

    # NOTE: `step` closes over opt_update/get_params, which are assigned
    # just below — Python resolves these at call time, so the ordering is
    # safe even though it reads oddly.
    @jax.jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        loss_value, gradients = jax.value_and_grad(loss_fn)(params, inputs)
        return opt_update(i, gradients, opt_state), loss_value

    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)
    opt_state = opt_init(initial_params)
    # sqrt(2), hoisted for the Gaussian CDF via erf below.
    root2 = np.sqrt(2.)
    itercount = itertools.count()
    for epoch in range(num_epochs):
        # Fresh permutation of the training data each epoch.
        rng, permute_rng = jax.random.split(rng)
        X_epoch = jax.random.permutation(permute_rng, X_preproc)
        for batch_index in range(0, len(X), batch_size):
            opt_state, loss = step(
                next(itercount), opt_state,
                X_epoch[batch_index:batch_index + batch_size])
        if interval is not None and (epoch + 1) % interval == 0:
            print(f'epoch {epoch + 1} loss {loss:.3f}')
            # Map the input data back through the flow to the prior space.
            epoch_params = get_params(opt_state)
            X_normal, _ = bijection_direct(epoch_params, X_epoch)
            # Gaussian CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2, mapping
            # the (ideally standard-normal) latents to [0,1].
            X_uniform = 0.5 * (
                1 + scipy.special.erf(np.array(X_normal, np.float64) / root2))
            for i in range(input_dim):
                plt.hist(X_uniform[:, i], bins, histtype='step')
            plt.show()
    # Return a function that maps samples to a ~uniform distribution on
    # [0,1] ** input_dim. Takes a numpy array as input and returns a numpy
    # array of the same shape.
    final_params = get_params(opt_state)

    def flow_map(Y):
        # Apply the same standardization as training, push through the
        # learned bijection, then map latents to [0,1] via the normal CDF.
        Y_preproc = scaler.transform(Y)
        Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
        Y_uniform = 0.5 * (
            1 + scipy.special.erf(np.array(Y_normal, np.float64) / root2))
        return Y_uniform.astype(np.float32)

    return flow_map
num_blocks = int(config['num_blocks']) num_hidden = int(config['num_hidden']) private = str(config['private']).lower() == 'true' sampling = config['sampling'].lower() X, X_val, X_test = utils.get_datasets(dataset) num_samples, input_dim = X.shape delta = 1e-4 if dataset == 'lifesci' else 1 / num_samples shutil.copyfile(flow_path + 'flow_utils.py', 'analysis/flow_utils.py') from analysis import flow_utils modules = flow_utils.get_modules(flow, num_blocks, normalization, num_hidden) bijection = flows.Serial(*tuple(modules)) prior = flows.Normal() init_fun = flows.Flow(bijection, prior) temp_key, key = random.split(key) _, log_pdf, sample = init_fun(temp_key, input_dim) iterations = sorted([ int(d) for d in os.listdir(flow_path) if os.path.isdir(flow_path + d) ]) print('δ = {}'.format(delta)) for composition in ['gdp', 'ma']: print('Composing in {}...'.format(composition)) for iteration in iterations: epsilon = utils.get_epsilon( private, composition,