Example #1
    def test_actnorm(self):
        for test in (returns_correct_shape, is_bijective):
            test(self, flows.ActNorm())

        # Test data-dependent initialization
        inputs = random.uniform(random.PRNGKey(0), (20, 3), minval=-10.0, maxval=10.0)
        input_dim = inputs.shape[1]

        init_fun = flows.Serial(flows.ActNorm())
        params, direct_fun, inverse_fun = init_fun(random.PRNGKey(0), input_dim, init_inputs=inputs)
        mapped_inputs, _ = direct_fun(params, inputs)

        # Data-dependent init should leave the inputs with ~zero mean and unit std.
        self.assertTrue(np.allclose(np.zeros(input_dim), mapped_inputs.mean(0), atol=1e-4))
        self.assertTrue(np.allclose(np.ones(input_dim), mapped_inputs.std(0)))
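The shared helpers `returns_correct_shape` and `is_bijective` come from the project's test suite and are not shown here. A minimal sketch of what they check, assuming the `(params, direct_fun, inverse_fun)` layer API used above and the same `jax.random` / `numpy` imports as the test module:

def returns_correct_shape(self, init_fun, input_dim=3, num_samples=20):
    inputs = random.uniform(random.PRNGKey(0), (num_samples, input_dim))
    params, direct_fun, inverse_fun = init_fun(random.PRNGKey(1), input_dim)
    mapped, log_det = direct_fun(params, inputs)
    self.assertEqual(mapped.shape, inputs.shape)     # outputs keep the input shape
    self.assertEqual(log_det.shape, (num_samples,))  # one log-determinant per sample

def is_bijective(self, init_fun, input_dim=3, num_samples=20):
    inputs = random.uniform(random.PRNGKey(0), (num_samples, input_dim))
    params, direct_fun, inverse_fun = init_fun(random.PRNGKey(1), input_dim)
    mapped, _ = direct_fun(params, inputs)
    reconstructed, _ = inverse_fun(params, mapped)   # round-trip through the flow
    self.assertTrue(np.allclose(inputs, reconstructed, atol=1e-5))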
Example #2
import functools
import itertools

import jax
import jax.numpy as jnp
import numpy as np
import scipy.special
from jax.experimental import optimizers  # jax.example_libraries.optimizers in newer JAX
from sklearn import preprocessing

import flows  # jax-flows


def learn_flow(X,
               hidden_dim=48,
               num_hidden=2,
               num_unit=5,
               learning_rate=1e-3,
               num_epochs=200,
               batch_size=4000,
               interval=None,
               seed=123):
    """Train a normalizing flow on X and return a map to ~uniform coordinates.

    Training with 400K riz samples works well with the default args.
    Make sure to remove samples with undetected flux first, since these
    otherwise create a delta function that the flow cannot model.
    """
    # Preprocess input data.
    scaler = preprocessing.StandardScaler().fit(X)
    X_preproc = scaler.transform(X)
    input_dim = X.shape[1]

    # Initialize random numbers.
    rng, flow_rng = jax.random.split(jax.random.PRNGKey(seed))

    # Initialize our flow bijection.
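    # masked_transform is assumed to be defined elsewhere (it builds the
    # masked network used inside each MADE block).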
    transform = functools.partial(masked_transform,
                                  hidden_dim=hidden_dim,
                                  num_hidden=num_hidden)
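    # Alternate MADE with Reverse so successive blocks condition on the
    # dimensions in opposite orders, and no dimension is always "first".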
    bijection_init_fun = flows.Serial(
        *(flows.MADE(transform), flows.Reverse()) * num_unit)

    # Create direct and inverse bijection functions.
    rng, bijection_rng = jax.random.split(rng)
    bijection_params, bijection_direct, bijection_inverse = bijection_init_fun(
        bijection_rng, input_dim)
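    # bijection_params is unused: training below optimizes the Flow's params,
    # and bijection_direct is a pure function that accepts those instead.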

    # Initialize our flow model.
    prior_init_fun = flows.Normal()
    flow_init_fun = flows.Flow(bijection_init_fun, prior_init_fun)
    initial_params, log_pdf, sample = flow_init_fun(flow_rng, input_dim)

    if interval is not None:
        import matplotlib.pyplot as plt
        bins = np.linspace(-0.05, 1.05, 111)

    # Maximum-likelihood objective: mean negative log-density under the flow.
    def loss_fn(params, inputs):
        return -log_pdf(params, inputs).mean()

    @jax.jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        loss_value, gradients = jax.value_and_grad(loss_fn)(params, inputs)
        return opt_update(i, gradients, opt_state), loss_value

    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)
    opt_state = opt_init(initial_params)
    root2 = np.sqrt(2.)  # standard normal CDF: Phi(x) = 0.5 * (1 + erf(x / root2))

    itercount = itertools.count()
    for epoch in range(num_epochs):
        rng, permute_rng = jax.random.split(rng)
        X_epoch = jax.random.permutation(permute_rng, X_preproc)
        for batch_index in range(0, len(X_epoch), batch_size):
            opt_state, loss = step(
                next(itercount), opt_state,
                X_epoch[batch_index:batch_index + batch_size])
        if interval is not None and (epoch + 1) % interval == 0:
            print(f'epoch {epoch + 1} loss {loss:.3f}')
            # Map the input data back through the flow to the prior space.
            epoch_params = get_params(opt_state)
            X_normal, _ = bijection_direct(epoch_params, X_epoch)
            X_uniform = 0.5 * (
                1 + scipy.special.erf(np.array(X_normal, np.float64) / root2))
            for i in range(input_dim):
                plt.hist(X_uniform[:, i], bins, histtype='step')
            plt.show()

    # Return a function that maps samples to a ~uniform distribution on [0,1] ** input_dim.
    # Takes a numpy array as input and returns a numpy array of the same shape.
    final_params = get_params(opt_state)

    def flow_map(Y):
        Y_preproc = scaler.transform(Y)
        Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
        # Map standard-normal coordinates to [0, 1] via the normal CDF.
        Y_uniform = 0.5 * (
            1 + scipy.special.erf(np.array(Y_normal, np.float64) / root2))
        return Y_uniform.astype(np.float32)

    return flow_map
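A hypothetical usage sketch: the file name and flux columns are assumptions, and the positivity cut implements the docstring's advice to drop undetected samples before training.

riz = np.load('riz_fluxes.npy')      # hypothetical (N, 3) array of r, i, z fluxes
detected = (riz > 0).all(axis=1)     # drop undetected (non-positive) fluxes
flow_map = learn_flow(riz[detected], num_epochs=200, interval=20)
U = flow_map(riz[detected])          # ~uniform on [0, 1] ** 3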
Example #3
        X, X_test = X_full[idx_train], X_full[idx_test]

        scaler = preprocessing.StandardScaler()
        X = scaler.fit_transform(X)
        X_test = scaler.transform(X_test)

        # Privacy parameter delta = n**-1.1, i.e. o(1/n).
        delta = 1. / (X.shape[0]**1.1)

        print('X: {}'.format(X.shape))
        print('X test: {}'.format(X_test.shape))
        print('Delta: {}'.format(delta))

        # Create flow
        modules = flow_utils.get_modules(flow, num_blocks, normalization,
                                         num_hidden)
        bijection = flows.Serial(*tuple(modules))

        # Remove any previous experiment directory to avoid ambiguity.
        if log_params and overwrite:
            try:
                shutil.rmtree(output_dir)
            except (NameError, OSError):
                # No directory from a previous fold to remove.
                pass

        # Create experiment directory
        output_dir_tokens = [
            'out', dataset, 'flows', experiment,
            str(fold_iter)
        ]
        output_dir = ''
        for ext in output_dir_tokens:
            output_dir = os.path.join(output_dir, ext)  # assumed loop body: accumulate the nested path
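`flow_utils.get_modules` is project-specific and not shown here. A minimal sketch of what it might build, assuming the jax-flows layers from Example #2 and the same (hypothetical) `masked_transform` helper:

def get_modules(flow, num_blocks, normalization, num_hidden):
    # Sketch only: stack num_blocks autoregressive blocks, each optionally
    # followed by ActNorm, with Reverse permutations in between.
    transform = functools.partial(masked_transform, num_hidden=num_hidden)
    modules = []
    for _ in range(num_blocks):
        modules.append(flows.MADE(transform))
        if normalization:
            modules.append(flows.ActNorm())
        modules.append(flows.Reverse())
    return modules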
Example #4
    def test_serial(self):
        for test in (returns_correct_shape, is_bijective):
            test(self, flows.Serial(flows.Shuffle(), flows.Shuffle()))