def joint_step(i, opt_state, rng, fused_lhs):
    """Run one joint optimizer step over the fused and inner parameters.

    Returns the updated optimizer state and a fresh PRNG key.
    """
    fused_params, inner_params = get_params_upper(opt_state)
    # Fresh uniform sample batch for every box in this step.
    batches = fused_sample_our_uniform(
        batch_size_inner, num_boxes_per_step, dim, rng, tol=tol)
    grads = grad_joint_mean_elbo(fused_params, inner_params, batches, fused_lhs)
    # Split the key so the caller never reuses this step's randomness.
    next_rng, _ = random.split(rng, 2)
    return opt_update_upper(i, grads, opt_state), next_rng
def lower_step(i, opt_state, fused_params, rng, fused_lhs):
    """Run one optimizer step on the lower (inner) parameters only.

    ``fused_params`` is held fixed; only the params in ``opt_state`` move.
    Returns the updated optimizer state and a fresh PRNG key.
    """
    inner_params = get_params_lower(opt_state)
    batches = fused_sample_our_uniform(
        batch_size_inner, num_boxes_per_step, dim, rng, tol=tol)
    # grad_lower returns a tuple; the gradient w.r.t. inner params is first.
    grads = grad_lower(fused_params, inner_params, batches, fused_lhs)[0]
    next_rng, _ = random.split(rng, 2)
    return opt_update_lower(i, grads, opt_state), next_rng
def upper_average_compute_step(i, opt_state, our_ps, rng, fused_lhs):
    """Separate function so we can use a lower learning rate.

    Returns the updated optimizer state, a fresh PRNG key, and the
    gradient that was applied (so callers can inspect it).
    """
    params = get_params_compute_upper(opt_state)
    batches = fused_sample_our_uniform(
        batch_size_inner, compute_objective_num_box, dim, rng, tol=tol)
    # First element of the returned tuple is the gradient w.r.t. params.
    grads = grad_upper_mean_elbo(params, our_ps, batches, fused_lhs)[0]
    next_rng, _ = random.split(rng, 2)
    return opt_update_compute_upper(i, grads, opt_state), next_rng, grads
def compute_objective(fixed_ps, num_iters, num_boxes, num_cells, rng):
    """Train a fresh collection of upper models against ``fixed_ps`` and
    return an ELBO estimate from the best parameters found.

    Runs ``num_iters`` steps of ``upper_compute_step``; every
    ``save_frequency`` steps each of the ``num_boxes`` models is evaluated
    and its parameters are kept if they improve that box's best ELBO so
    far.  The final estimate is the negated batch loss of the best
    per-box parameters on a large fresh batch.
    """
    deep_ps, deep_cs, lhs = init_double_deep_collection(
        dim, n_temp, num_boxes, num_cells, rng)
    fused_ps = fuse_params(deep_ps)
    fused_lhs = fuse_params(lhs)
    opt_state = opt_init_compute_upper(fused_ps)
    best_fused_elbos = onp.ones(num_boxes) * -np.inf
    # NOTE(review): this aliases deep_ps, which is then mutated by item
    # assignment below.  Harmless here because deep_ps is not reused, but a
    # copy would be safer — confirm deep_ps's type before changing.
    best_unfused_ps = deep_ps
    for i in range(num_iters):
        opt_state, rng = upper_compute_step(i, opt_state, fixed_ps, rng, fused_lhs)
        if i % save_frequency == 0:
            fused_ps = get_params_compute_upper(opt_state)
            fused_batches = fused_sample_our_uniform(
                5 * compute_objective_batch_size, num_boxes, dim, rng, tol=tol)
            new_objective_losses = vmapped_nelbo(
                fused_ps, fixed_ps, fused_batches, fused_lhs)
            # Hoisted out of the per-box loop: unfuse_params(fused_ps) is
            # invariant in k, so unfuse once per evaluation instead of
            # once per box.
            unfused_ps = unfuse_params(fused_ps)
            for k in range(num_boxes):
                if -new_objective_losses[k] > best_fused_elbos[k]:
                    best_fused_elbos[k] = -new_objective_losses[k]
                    best_unfused_ps[k] = unfused_ps[k]
    # Score the best per-box parameters on a large fresh batch.
    fused_ps = fuse_params(best_unfused_ps)
    fused_batches = fused_sample_our_uniform(
        5 * compute_objective_batch_size, num_boxes, dim, rng, tol=tol)
    new_objective_loss = batch_loss(fused_ps, fixed_ps, fused_batches, fused_lhs)
    return -new_objective_loss
def joint_procedure(rng, in_ps):
    """Outer training loop over ``num_big_steps`` rounds.

    Each round draws new cells, optimizes the upper (fused) parameters for
    ``num_inner_iters`` steps — tracking the best-ELBO fused params seen —
    then (for the "pure_joint" method) takes one lower step on ``in_ps``
    and records an ELBO estimate via ``compute_objective``.  Estimates
    gathered after ``burn_in`` rounds are returned, NaNs filtered out.
    """
    elbo_estimates = []
    if joint_method == "pure_joint":
        # Only pure_joint keeps a separate optimizer for the inner params.
        opt_state_lower = opt_init_lower(in_ps)
    burn_in_counter = 0
    for j in range(num_big_steps):
        burn_in_counter += 1
        lhs, rng = choose_cells(dim, num_boxes_per_step, num_cells, rng)
        fused_ps = fuse_params(static_deep_ps_collection)
        fused_lhs = fuse_params(lhs)
        if joint_method == "pure_joint":
            opt_state_upper = opt_init_upper(fused_ps)
        else:
            # Use a joint training method: optimizer carries both params.
            opt_state_upper = opt_init_upper((fused_ps, in_ps))
        best_elbo = -np.inf
        best_fused_ps = None
        for k in range(num_inner_iters):
            # Dispatch one inner optimization step by method.
            if joint_method == "pure_joint":
                opt_state_upper, rng = upper_step(k, opt_state_upper, in_ps, rng, fused_lhs)
            elif joint_method == "rectified":
                opt_state_upper, rng = upper_average_rectfied_step(
                    k, opt_state_upper, rng, fused_lhs)
            else:
                assert joint_method == "joint_naive"
                opt_state_upper, rng = joint_step(k, opt_state_upper, rng, fused_lhs)
            if k % save_frequency == 0:
                # Periodically evaluate and keep the best fused params.
                fused_batches = fused_sample_our_uniform(
                    compute_objective_batch_size, num_boxes_per_step, dim, rng, tol=tol,
                )
                if joint_method == "pure_joint":
                    fused_ps = get_params_upper(opt_state_upper)
                else:
                    fused_ps, in_ps = get_params_upper(opt_state_upper)
                new_objective_loss = batch_loss(fused_ps, in_ps, fused_batches, fused_lhs)
                # print("ELBO running estimate: {}".format(-new_objective_loss))
                if -new_objective_loss > best_elbo:
                    best_elbo = -new_objective_loss
                    best_fused_ps = fused_ps
        print("At iter {}".format(j))
        fused_ps = best_fused_ps
        if joint_method == "pure_joint":
            # One lower step against the best fused params of this round.
            opt_state_lower, rng = lower_step(j, opt_state_lower, fused_ps, rng, fused_lhs)
            in_ps = get_params_lower(opt_state_lower)
        elbo_estimate = compute_objective(
            in_ps,
            compute_objective_inner_iters,
            compute_objective_num_box,
            num_cells,
            rng,
        )
        if burn_in_counter > burn_in:
            # Only record estimates after the burn-in period.
            elbo_estimates.append(elbo_estimate)
            print(elbo_estimate)
        # with open(filestring, 'rb') as f:
        #     new_ps = pickle.load(f)
        # plot_save_2_lines_density(in_ps, cs, new_ps, cs, input_fun)
        print("ELBO real estimate: {}".format(elbo_estimate))
    # Drop NaN estimates and stack the rest into a single array.
    elbo_estimates = np.concatenate(
        [x.reshape((1, )) for x in elbo_estimates if not np.isnan(x)])
    return elbo_estimates