Example no. 1
0
 def joint_step(i, opt_state, rng, fused_lhs):
     """Run one joint optimizer step on the upper-level state.

     Draws a fresh uniform batch, differentiates the joint mean ELBO with
     respect to both parameter groups held in ``opt_state``, and applies
     the update.  Returns the new optimizer state and a freshly split key.
     """
     fused_ps, our_ps = get_params_upper(opt_state)
     batches = fused_sample_our_uniform(
         batch_size_inner, num_boxes_per_step, dim, rng, tol=tol)
     grads = grad_joint_mean_elbo(fused_ps, our_ps, batches, fused_lhs)
     # Advance the key so the next call samples a different batch.
     next_rng, _ = random.split(rng, 2)
     return opt_update_upper(i, grads, opt_state), next_rng
Example no. 2
0
 def lower_step(i, opt_state, fused_params, rng, fused_lhs):
     """Run one optimizer step on the lower-level parameters.

     Samples a fresh uniform batch, takes the gradient of the lower
     objective with respect to the lower parameters (first element of the
     gradient tuple), and applies the update.  Returns the new optimizer
     state and a freshly split key.
     """
     lower_ps = get_params_lower(opt_state)
     batches = fused_sample_our_uniform(
         batch_size_inner, num_boxes_per_step, dim, rng, tol=tol)
     grads = grad_lower(fused_params, lower_ps, batches, fused_lhs)[0]
     # Advance the key so the next call samples a different batch.
     next_rng, _ = random.split(rng, 2)
     return opt_update_lower(i, grads, opt_state), next_rng
Example no. 3
0
 def upper_average_compute_step(i, opt_state, our_ps, rng, fused_lhs):
     """One upper-level step for objective computation.

     Kept separate from the training step so it can use a lower learning
     rate.  Returns the updated optimizer state, a freshly split key, and
     the gradient itself (useful for monitoring convergence).
     """
     fused_ps = get_params_compute_upper(opt_state)
     batches = fused_sample_our_uniform(
         batch_size_inner, compute_objective_num_box, dim, rng, tol=tol)
     grads = grad_upper_mean_elbo(fused_ps, our_ps, batches, fused_lhs)[0]
     # Advance the key so the next call samples a different batch.
     next_rng, _ = random.split(rng, 2)
     return opt_update_compute_upper(i, grads, opt_state), next_rng, grads
Example no. 4
0
 def compute_objective(fixed_ps, num_iters, num_boxes, num_cells, rng):
     """Estimate the objective by optimizing a fresh parameter collection.

     Initializes a new double-deep collection, runs ``num_iters`` upper
     optimization steps against the fixed parameters, keeps the per-box
     best parameters seen at each checkpoint, and returns the negative
     batch loss (ELBO estimate) of that best collection.

     Args:
         fixed_ps: parameters held fixed during the inner optimization.
         num_iters: number of inner optimizer iterations.
         num_boxes: number of boxes in the collection.
         num_cells: number of cells used for initialization.
         rng: JAX PRNG key (threaded through sampling and init).

     Returns:
         Scalar negative loss of the best parameters found.
     """
     deep_ps, deep_cs, lhs = init_double_deep_collection(
         dim, n_temp, num_boxes, num_cells, rng)
     fused_ps = fuse_params(deep_ps)
     fused_lhs = fuse_params(lhs)
     opt_state = opt_init_compute_upper(fused_ps)
     best_fused_elbos = onp.ones(num_boxes) * -np.inf
     # Copy so per-box updates below do not mutate deep_ps through aliasing.
     best_unfused_ps = list(deep_ps)
     for i in range(num_iters):
         opt_state, rng = upper_compute_step(i, opt_state, fixed_ps, rng,
                                             fused_lhs)
         if i % save_frequency == 0:
             fused_ps = get_params_compute_upper(opt_state)
             fused_batches = fused_sample_our_uniform(
                 5 * compute_objective_batch_size,
                 num_boxes,
                 dim,
                 rng,
                 tol=tol)
             new_objective_losses = vmapped_nelbo(fused_ps, fixed_ps,
                                                  fused_batches, fused_lhs)
             # Hoisted out of the per-box loop: unfuse once per checkpoint
             # instead of once per box.
             unfused_ps = unfuse_params(fused_ps)
             for k in range(num_boxes):
                 # Track the best (highest) ELBO seen for each box.
                 if -new_objective_losses[k] > best_fused_elbos[k]:
                     best_fused_elbos[k] = -new_objective_losses[k]
                     best_unfused_ps[k] = unfused_ps[k]
     # Evaluate the final loss on the best per-box parameters.
     fused_ps = fuse_params(best_unfused_ps)
     fused_batches = fused_sample_our_uniform(5 *
                                              compute_objective_batch_size,
                                              num_boxes,
                                              dim,
                                              rng,
                                              tol=tol)
     new_objective_loss = batch_loss(fused_ps, fixed_ps, fused_batches,
                                     fused_lhs)
     return -new_objective_loss
Example no. 5
0
    def joint_procedure(rng, in_ps):
        """Run the full joint training loop and collect ELBO estimates.

        For each of ``num_big_steps`` outer iterations: choose a set of
        cells, run ``num_inner_iters`` upper-level steps (the exact step
        function depends on ``joint_method``), keep the best upper params
        seen at the periodic checkpoints, optionally take one lower-level
        step (``"pure_joint"`` only), and record an ELBO estimate after
        the burn-in period.  Returns the non-NaN estimates as an array.
        """
        elbo_estimates = []
        if joint_method == "pure_joint":
            # Only the pure-joint method maintains a separate lower-level
            # optimizer; the joint methods carry in_ps inside the upper state.
            opt_state_lower = opt_init_lower(in_ps)
        burn_in_counter = 0
        for j in range(num_big_steps):
            burn_in_counter += 1
            lhs, rng = choose_cells(dim, num_boxes_per_step, num_cells, rng)
            fused_ps = fuse_params(static_deep_ps_collection)
            fused_lhs = fuse_params(lhs)

            if joint_method == "pure_joint":
                opt_state_upper = opt_init_upper(fused_ps)
            else:
                # Use a joint training method
                opt_state_upper = opt_init_upper((fused_ps, in_ps))
            best_elbo = -np.inf
            # NOTE(review): best_fused_ps stays None if no checkpoint fires
            # (save_frequency > num_inner_iters) — confirm that cannot occur.
            best_fused_ps = None
            for k in range(num_inner_iters):
                # Dispatch on joint_method: pure_joint keeps in_ps fixed,
                # rectified/joint_naive optimize both parameter groups.
                if joint_method == "pure_joint":
                    opt_state_upper, rng = upper_step(k, opt_state_upper,
                                                      in_ps, rng, fused_lhs)
                elif joint_method == "rectified":
                    opt_state_upper, rng = upper_average_rectfied_step(
                        k, opt_state_upper, rng, fused_lhs)
                else:
                    assert joint_method == "joint_naive"
                    opt_state_upper, rng = joint_step(k, opt_state_upper, rng,
                                                      fused_lhs)
                # Periodic checkpoint: evaluate the current params and keep
                # the best ELBO seen so far.
                if k % save_frequency == 0:
                    fused_batches = fused_sample_our_uniform(
                        compute_objective_batch_size,
                        num_boxes_per_step,
                        dim,
                        rng,
                        tol=tol,
                    )
                    if joint_method == "pure_joint":
                        fused_ps = get_params_upper(opt_state_upper)
                    else:
                        # Joint methods store (fused_ps, in_ps) together.
                        fused_ps, in_ps = get_params_upper(opt_state_upper)
                    new_objective_loss = batch_loss(fused_ps, in_ps,
                                                    fused_batches, fused_lhs)
                    # print("ELBO running estimate: {}".format(-new_objective_loss))

                    if -new_objective_loss > best_elbo:
                        best_elbo = -new_objective_loss
                        best_fused_ps = fused_ps

            print("At iter {}".format(j))
            # Continue from the best checkpointed params, not the last ones.
            fused_ps = best_fused_ps
            if joint_method == "pure_joint":
                # One lower-level step against the best upper params.
                opt_state_lower, rng = lower_step(j, opt_state_lower, fused_ps,
                                                  rng, fused_lhs)
                in_ps = get_params_lower(opt_state_lower)
            elbo_estimate = compute_objective(
                in_ps,
                compute_objective_inner_iters,
                compute_objective_num_box,
                num_cells,
                rng,
            )
            # Only record estimates after the burn-in period.
            if burn_in_counter > burn_in:
                elbo_estimates.append(elbo_estimate)
            print(elbo_estimate)

            # with open(filestring, 'rb') as f:
            # new_ps = pickle.load(f)
            # plot_save_2_lines_density(in_ps, cs, new_ps, cs, input_fun)

            print("ELBO real estimate: {}".format(elbo_estimate))
        # Drop NaN estimates and return the rest as a 1-D array.
        elbo_estimates = np.concatenate(
            [x.reshape((1, )) for x in elbo_estimates if not np.isnan(x)])
        return elbo_estimates