Example 1
def run_vip_hmc_discrete(model_config,
                         parameterisation,
                         num_samples=2000,
                         burnin=1000,
                         num_leapfrog_steps=4,
                         num_adaptation_steps=500,
                         num_optimization_steps=2000):

    tf.reset_default_graph()

    (_, insightful_parametrisation,
     _) = ed_transforms.make_learnable_parametrisation(
         learnable_parameters=parameterisation)

    results = run_parametrised_hmc(
        model_config=model_config,
        interceptor=insightful_parametrisation,
        num_samples=num_samples,
        burnin=burnin,
        num_leapfrog_steps=num_leapfrog_steps,
        num_adaptation_steps=num_adaptation_steps,
        num_optimization_steps=num_optimization_steps)

    results['parameterisation'] = parameterisation

    return results
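Usage sketch (hypothetical, not from this code base): `parameterisation` maps each learnable centering parameter to a fixed value, so a discrete run pins every weight to 0 (non-centered) or 1 (centered). The key names below are illustrative placeholders, and `model_config` is assumed to be built elsewhere; the real keys come from the model's tape.

import numpy as np

discrete_parameterisation = {
    'theta_a': np.float32(1.),  # keep theta centered
    'tau_a': np.float32(0.),    # make tau non-centered
}
results = run_vip_hmc_discrete(model_config, discrete_parameterisation)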
Example 2
def make_cvip_graph(model_config,
                    parameterisation_type='exp',
                    tied_pparams=False):
    """
                Constructs the cVIP graph of the given model.
                Resets the default TF graph.
        """

    tf.reset_default_graph()

    (learnable_parameters, learnable_parametrisation,
     _) = ed_transforms.make_learnable_parametrisation(
         tau=1.,
         parameterisation_type=parameterisation_type,
         tied_pparams=tied_pparams)

    def model_vip(*params):
        with ed.interception(learnable_parametrisation):
            return model_config.model(*params)

    if model_config.bijectors_fn is not None:
        model_vip = ed_transforms.transform_with_bijectors(
            model_vip, model_config.bijectors_fn)

    log_joint_vip = ed.make_log_joint_fn(model_vip)  # log_joint_fn

    with ed.tape() as model_tape:
        _ = model_vip(*model_config.model_args)

    target_vip_kwargs = {}
    for param in model_tape.keys():
        if param in model_config.observed_data.keys():
            target_vip_kwargs[param] = model_config.observed_data[param]

    def target_vip(*param_args):  # latent_log_joint_fn
        i = 0
        for param in model_tape.keys():
            if param not in model_config.observed_data.keys():
                target_vip_kwargs[param] = param_args[i]
                i = i + 1
        return log_joint_vip(*model_config.model_args, **target_vip_kwargs)

    elbo, variational_parameters = util.get_mean_field_elbo(
        model_vip,
        target_vip,
        num_mc_samples=FLAGS.num_mc_samples,
        model_args=model_config.model_args,
        model_obs_kwargs=model_config.observed_data,
        vi_kwargs={'parameterisation': learnable_parameters})

    return target_vip, model_vip, elbo, variational_parameters, learnable_parameters
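The `target_vip` closure above turns the full log joint into a function of the free latents only: it walks the tape in order and splices the positional latent arguments into the kwargs alongside the observed data. A self-contained toy version of the same pattern (plain Python, no Edward2):

def make_latent_log_joint(log_joint, tape_order, observed):
    """Turns log_joint(**all_vars) into f(*latents) with observed fixed."""
    kwargs = dict(observed)
    def latent_log_joint(*latent_args):
        it = iter(latent_args)
        for name in tape_order:
            if name not in observed:
                kwargs[name] = next(it)
        return log_joint(**kwargs)
    return latent_log_joint

# Toy joint over (mu, y); quadratic penalties stand in for log-densities.
toy_log_joint = lambda mu, y: -(mu ** 2) - (y - mu) ** 2
f = make_latent_log_joint(toy_log_joint, tape_order=['mu', 'y'],
                          observed={'y': 1.5})
print(f(0.3))  # log joint at mu=0.3 with y fixed at 1.5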
Example 3
def make_dvip_graph(model_config, reparam, parameterisation_type='exp'):
    """
                Constructs the dVIP graph of the given model, where `reparam` is
                a cVIP
                reparameterisation obtained previously.
                Resets the default TF graph.
        """

    tf.reset_default_graph()

    (_, insightful_parametrisation,
     _) = ed_transforms.make_learnable_parametrisation(
         learnable_parameters=reparam,
         parameterisation_type=parameterisation_type)

    def model_vip(*params):
        with ed.interception(insightful_parametrisation):
            return model_config.model(*params)

    if model_config.bijectors_fn is not None:
        model_vip = ed_transforms.transform_with_bijectors(
            model_vip, model_config.bijectors_fn)

    log_joint_vip = ed.make_log_joint_fn(model_vip)  # log_joint_fn

    with ed.tape() as model_tape:
        _ = model_vip(*model_config.model_args)

    target_vip_kwargs = {}
    for param in model_tape.keys():
        if param in model_config.observed_data.keys():
            target_vip_kwargs[param] = model_config.observed_data[param]

    def target_vip(*param_args):  # latent_log_joint_fn
        i = 0
        for param in model_tape.keys():
            if param not in model_config.observed_data.keys():
                target_vip_kwargs[param] = param_args[i]
                i = i + 1
        return log_joint_vip(*model_config.model_args, **target_vip_kwargs)

    elbo, variational_parameters = util.get_mean_field_elbo(
        model_vip,
        target_vip,
        num_mc_samples=FLAGS.num_mc_samples,
        model_args=model_config.model_args,
        model_obs_kwargs=model_config.observed_data,
        vi_kwargs={'parameterisation': reparam})

    return target_vip, model_vip, elbo, variational_parameters, None
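dVIP differs from cVIP only in that the reparameterisation is supplied rather than learned. A typical workflow (a hedged sketch: the optimisation step, `session`, and the 0.5 threshold are assumptions, not fixed by this code) learns continuous weights with `make_cvip_graph`, snaps them to {0, 1}, and rebuilds the graph with `make_dvip_graph`:

import numpy as np

target, model, elbo, vparams, learnable = make_cvip_graph(model_config)
# ... maximise `elbo` over the variational and learnable parameters ...
cvip_reparam = {k: session.run(v) for k, v in learnable.items()}
# Snap each continuous centering weight to a discrete 0/1 choice.
dvip_reparam = {k: np.float32(v > 0.5) for k, v in cvip_reparam.items()}
target_d, model_d, elbo_d, vparams_d, _ = make_dvip_graph(
    model_config, dvip_reparam)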
Example 4
    def make_to_centered(**centering_kwargs):
        (_, parametrisation, _) = ed_transforms.make_learnable_parametrisation(
            learnable_parameters=centering_kwargs)

        def to_centered(uncentered_state):
            set_values = ed_transforms.make_value_setter(*uncentered_state)
            with ed.interception(set_values):
                with ed.interception(parametrisation):
                    with ed.tape() as centered_tape:
                        model(*model_args)

            param_vals = [
                tf.identity(v) for k, v in centered_tape.items()
                if k not in observed_data.keys()
            ]
            return param_vals

        return to_centered
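Behind `make_learnable_parametrisation` sits an interpolation between the centered and non-centered forms of each Gaussian latent. A self-contained numpy sketch of that transformation, assuming the standard VIP interpolation with centering weight `a` in [0, 1] (`a=1` centered, `a=0` non-centered):

import numpy as np

def sample_partially_centered(mu, sigma, a, rng):
    """Draws z ~ N(mu, sigma) via a partially centered auxiliary variable."""
    z_tilde = rng.normal(loc=a * mu, scale=sigma ** a)  # what VI/HMC sees
    z = mu + sigma ** (1. - a) * (z_tilde - a * mu)     # recentered value
    return z_tilde, z

rng = np.random.default_rng(0)
for a in (0., 0.5, 1.):
    z_tilde, z = sample_partially_centered(mu=2., sigma=3., a=a, rng=rng)
    print(a, z_tilde, z)  # a=1: z == z_tilde; a=0: z = mu + sigma * z_tilde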
Example 5
def run_vip_hmc_continuous(model_config,
                           num_samples=2000,
                           burnin=1000,
                           use_iaf_posterior=False,
                           num_leapfrog_steps=4,
                           num_adaptation_steps=500,
                           num_optimization_steps=2000,
                           num_mc_samples=32,
                           tau=1.,
                           do_sample=True,
                           description='',
                           experiments_dir=''):

    tf.reset_default_graph()

    if use_iaf_posterior:
        # IAF posterior doesn't give us stddevs for step sizes for HMC (we could
        # extract them by sampling but I haven't implemented that), and we mostly
        # care about it for ELBOs anyway.
        do_sample = False

    init_val_loc = tf.placeholder('float', shape=())
    init_val_scale = tf.placeholder('float', shape=())

    (learnable_parameters, learnable_parametrisation,
     _) = ed_transforms.make_learnable_parametrisation(
         init_val_loc=init_val_loc, init_val_scale=init_val_scale, tau=tau)

    def model_vip(*params):
        with ed.interception(learnable_parametrisation):
            return model_config.model(*params)

    log_joint_vip = ed.make_log_joint_fn(model_vip)

    with ed.tape() as model_tape:
        _ = model_vip(*model_config.model_args)

    param_shapes = collections.OrderedDict()
    target_vip_kwargs = {}
    for param in model_tape.keys():
        if param not in model_config.observed_data.keys():
            param_shapes[param] = model_tape[param].shape
        else:
            target_vip_kwargs[param] = model_config.observed_data[param]

    def target_vip(*param_args):
        i = 0
        for param in model_tape.keys():
            if param not in model_config.observed_data.keys():
                target_vip_kwargs[param] = param_args[i]
                i = i + 1
        return log_joint_vip(*model_config.model_args, **target_vip_kwargs)

    full_kwargs = collections.OrderedDict(model_config.observed_data.items())
    full_kwargs['parameterisation'] = collections.OrderedDict()
    for k in learnable_parameters.keys():
        full_kwargs['parameterisation'][k] = learnable_parameters[k]

    if use_iaf_posterior:
        elbo = util.get_iaf_elbo(target_vip,
                                 num_mc_samples=num_mc_samples,
                                 param_shapes=param_shapes)
        variational_parameters = {}
    else:
        elbo, variational_parameters = util.get_mean_field_elbo(
            model_vip,
            target_vip,
            num_mc_samples=num_mc_samples,
            model_args=model_config.model_args,
            vi_kwargs=full_kwargs)
        vip_step_size_approx = util.get_approximate_step_size(
            variational_parameters, num_leapfrog_steps)

    ##############################################################################

    best_elbo = None
    model_dir = os.path.join(
        experiments_dir, str(description + '_' + model_config.model.__name__))

    if not tf.gfile.Exists(model_dir):
        tf.gfile.MakeDirs(model_dir)

    saver = tf.train.Saver()
    dir_save = os.path.join(model_dir, 'saved_params_{}'.format(gen_id()))

    if not tf.gfile.Exists(dir_save):
        tf.gfile.MakeDirs(dir_save)

    best_lr = None
    best_init_loc = None
    best_init_scale = None

    learning_rate_ph = tf.placeholder(shape=[], dtype=tf.float32)
    learning_rate = tf.Variable(learning_rate_ph, trainable=False)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    train = optimizer.minimize(-elbo)
    init = tf.global_variables_initializer()

    learning_rates = [0.003, 0.01, 0.01, 0.1, 0.003, 0.01]
    if use_iaf_posterior:
        learning_rates = [3e-5, 1e-4, 3e-4, 1e-4]

    start_time = time.time()
    for learning_rate_val in learning_rates:
        for init_loc in [0.]:  #, 10., -10.]:
            for init_scale in [init_loc]:

                timeline = []

                with tf.Session() as sess:

                    init.run(
                        feed_dict={
                            init_val_loc: init_loc,
                            init_val_scale: init_scale,
                            learning_rate_ph: learning_rate_val
                        })

                    this_timeline = []
                    for i in range(num_optimization_steps):
                        _, e = sess.run([train, elbo])

                        if np.isnan(e):
                            util.print(
                                'got NaN in ELBO optimization, stopping...')
                            break

                        this_timeline.append(e)

                    this_elbo = np.mean(this_timeline[-100:])
                    info_str = ('finished cVIP optimization with ELBO {} vs '
                                'best ELBO {}'.format(this_elbo, best_elbo))
                    util.print(info_str)
                    if best_elbo is None or best_elbo < this_elbo:
                        best_elbo = this_elbo
                        timeline = this_timeline

                        vals = sess.run(list(learnable_parameters.values()))
                        learned_reparam = collections.OrderedDict(
                            zip(learnable_parameters.keys(), vals))
                        vals = sess.run(list(variational_parameters.values()))
                        learned_variational_params = collections.OrderedDict(
                            zip(variational_parameters.keys(), vals))

                        util.print('learned params {}'.format(learned_reparam))
                        util.print('learned variational params {}'.format(
                            learned_variational_params))

                        _ = saver.save(sess, dir_save)
                        best_lr = learning_rate_val
                        best_init_loc = init_loc
                        best_init_scale = init_scale

    vi_time = time.time() - start_time

    util.print('BEST: LR={}, init={}, {}'.format(best_lr, best_init_loc,
                                                 best_init_scale))
    util.print('ELBO: {}'.format(best_elbo))

    to_centered = model_config.make_to_centered(**learned_reparam)

    results = collections.OrderedDict()
    results['elbo'] = best_elbo

    with tf.Session() as sess:

        saver.restore(sess, dir_save)
        results['vp'] = learned_variational_params

        if do_sample:

            vip_step_size_init = sess.run(vip_step_size_approx)

            vip_step_size = [
                tf.get_variable(
                    name='step_size_vip' + str(i),
                    initializer=np.array(vip_step_size_init[i],
                                         dtype=np.float32),
                    use_resource=True,  # For TFE compatibility.
                    trainable=False) for i in range(len(vip_step_size_init))
            ]

            kernel_vip = mcmc.HamiltonianMonteCarlo(
                target_log_prob_fn=target_vip,
                step_size=vip_step_size,
                num_leapfrog_steps=num_leapfrog_steps,
                step_size_update_fn=mcmc.make_simple_step_size_update_policy(
                    num_adaptation_steps=num_adaptation_steps,
                    target_rate=0.85))

            states, kernel_results_vip = mcmc.sample_chain(
                num_results=num_samples,
                num_burnin_steps=burnin,
                current_state=[
                    tf.zeros(param_shapes[param])
                    for param in param_shapes.keys()
                ],
                kernel=kernel_vip,
                num_steps_between_results=1)

            states_vip = transform_mcmc_states(states, to_centered)

            init_again = tf.global_variables_initializer()
            init_again.run(
                feed_dict={
                    init_val_loc: best_init_loc,
                    init_val_scale: best_init_scale,
                    learning_rate_ph: 1.0
                })  # learning rate doesn't matter for HMC.
            # The global initializer also resets the restored learned
            # parameters, so restore them again before sampling.
            saver.restore(sess, dir_save)

            ess_vip = tfp.mcmc.effective_sample_size(states_vip)

            start_time = time.time()
            samples, is_accepted, ess, ss_vip, log_accept_ratio = sess.run(
                (states_vip, kernel_results_vip.is_accepted, ess_vip,
                 kernel_results_vip.extra.step_size_assign,
                 kernel_results_vip.log_accept_ratio))

            sampling_time = time.time() - start_time

            results['samples'] = collections.OrderedDict()
            results['is_accepted'] = is_accepted
            results['acceptance_rate'] = np.sum(is_accepted) * 100. / float(
                num_samples)
            results['ess'] = ess
            results['sampling_time'] = sampling_time
            results['log_accept_ratio'] = log_accept_ratio
            results['step_size'] = [s[0] for s in ss_vip]

            for i, param in enumerate(param_shapes.keys()):
                results['samples'][param] = samples[i]

        # end if

        results['parameterisation'] = collections.OrderedDict()

        # Record the learned 'a'/'b' parameterisation values for each latent
        # (keys share the latent's name prefix).
        for param in param_shapes.keys():
            name_a = param[:-5] + 'a'
            name_b = param[:-5] + 'b'
            try:
                results['parameterisation'][name_a] = learned_reparam[name_a]
                results['parameterisation'][name_b] = learned_reparam[name_b]
            except KeyError:
                continue

        results['elbo_timeline'] = timeline
        results['vi_time'] = vi_time

        results['init_pos'] = best_init_loc

        return results
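The sampling block above follows the standard tfp.mcmc pattern: an HMC kernel with a per-variable step size and a simple adaptation policy, driven through sample_chain. Below is a self-contained toy version against a standard normal target, in the same TF1 graph style (and the same since-deprecated step-size-update API) as the code above; the target, sizes, and names are illustrative:

import tensorflow as tf
import tensorflow_probability as tfp

mcmc = tfp.mcmc

tf.reset_default_graph()

def target_log_prob(x):
    # Standard normal in two dimensions.
    return -0.5 * tf.reduce_sum(x ** 2)

step_size = tf.get_variable(
    name='toy_step_size', initializer=tf.constant(0.1),
    use_resource=True, trainable=False)

kernel = mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn=target_log_prob,
    step_size=step_size,
    num_leapfrog_steps=4,
    step_size_update_fn=mcmc.make_simple_step_size_update_policy(
        num_adaptation_steps=200, target_rate=0.85))

states, kernel_results = mcmc.sample_chain(
    num_results=500,
    num_burnin_steps=200,
    current_state=tf.zeros([2]),
    kernel=kernel)

ess = tfp.mcmc.effective_sample_size(states)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    samples, accepted, ess_val = sess.run(
        (states, kernel_results.is_accepted, ess))
    print(samples.shape, accepted.mean(), ess_val)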