import torch
import tqdm

from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.optim import Adam


def main(_argv):
    # Ground-truth Dirichlet concentrations used to generate synthetic data:
    # a 2-state HMM with categorical emissions over 3 symbols.
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    lengths = torch.randint(10, 30, (10000,))

    # Generate observation sequences by running the model forward and
    # collecting the observed sites from the trace.
    trace = poutine.trace(model).get_trace(
        transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in trace.nodes.items()
                     if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)

    # MAP inference: the AutoDelta guide covers only the global parameters;
    # the discrete "state" sites are hidden so TraceEnum can enumerate them.
    guide = AutoDelta(
        poutine.block(model,
                      hide_fn=lambda site: site['name'].startswith('state')),
        init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam({'lr': 0.1}), JitTraceEnum_ELBO())

    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            # Fit under weak priors rather than the sharp generating priors.
            loss = svi.step(
                0.5 * torch.ones((2, 2), dtype=torch.float),
                0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")

    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ",
          median['emission_probs'].squeeze().detach().numpy())
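
# NOTE: `model` is not shown in this snippet. The sketch below is an *assumed*
# minimal reconstruction of a model compatible with main() above: a
# discrete-state HMM with Dirichlet priors over the transition and emission
# rows, hidden sites named "state_{t}" (enumerated during inference), and
# observed sites named "element_{t}". Names and shapes are inferred from how
# main() uses the trace, not taken from the original source.
import pyro
import pyro.distributions as dist
from pyro import poutine


def model(transition_alphas, emission_alphas, lengths, sequences=None):
    # Independent Dirichlet priors over each row of the transition and
    # emission matrices.
    transition_probs = pyro.sample(
        "transition_probs", dist.Dirichlet(transition_alphas).to_event(1))
    emission_probs = pyro.sample(
        "emission_probs", dist.Dirichlet(emission_alphas).to_event(2))
    max_length = int(lengths.max())
    with pyro.plate("sequences", lengths.shape[0]):
        state = 0
        for t in pyro.markov(range(max_length)):
            # Mask out time steps beyond each sequence's length.
            with poutine.mask(mask=(t < lengths)):
                state = pyro.sample(
                    f"state_{t}",
                    dist.Categorical(transition_probs[state]),
                    infer={"enumerate": "parallel"})
                obs_t = None if sequences is None else sequences[..., t, :]
                pyro.sample(
                    f"element_{t}",
                    dist.Categorical(emission_probs[state]).to_event(1),
                    obs=obs_t)
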

import numpy as np
import torch
from torch.optim import Adam

import pyro
from pyro import poutine
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoDelta


def main(args):
    # setup hyperparameters for the model
    hypers = {'expected_sparsity': max(1.0, args.num_dimensions / 10),
              'alpha1': 3.0, 'beta1': 1.0,
              'alpha2': 3.0, 'beta2': 1.0,
              'alpha3': 1.0, 'c': 1.0}

    P = args.num_dimensions
    S = args.active_dimensions
    Q = args.quadratic_dimensions

    # generate artificial dataset
    X, Y, expected_thetas, expected_quad_dims = get_data(
        N=args.num_data, P=P, S=S, Q=Q, sigma_obs=args.sigma)

    loss_fn = Trace_ELBO().differentiable_loss

    # We initialize the AutoDelta guide (for MAP estimation) with
    # args.num_restarts many initial parameters sampled from the vicinity of
    # the median of the prior distribution and then continue optimizing with
    # the best-performing initialization.
    init_losses = []
    for restart in range(args.num_restarts):
        pyro.clear_param_store()
        pyro.set_rng_seed(restart)
        guide = AutoDelta(model, init_loc_fn=init_loc_fn)
        with torch.no_grad():
            init_losses.append(loss_fn(model, guide, X, Y, hypers).item())

    pyro.set_rng_seed(np.argmin(init_losses))
    pyro.clear_param_store()
    guide = AutoDelta(model, init_loc_fn=init_loc_fn)

    # Instead of using pyro.infer.SVI and pyro.optim we construct our own
    # PyTorch optimizer and take charge of gradient-based optimization
    # ourselves.
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        guide(X, Y, hypers)
    params = [pyro.param(name).unconstrained() for name in param_capture.trace]

    adam = Adam(params, lr=args.lr)

    report_frequency = 50
    print("Beginning MAP optimization...")

    # the optimization loop
    for step in range(args.num_steps):
        loss = loss_fn(model, guide, X, Y, hypers) / args.num_data
        loss.backward()
        adam.step()
        adam.zero_grad()

        # we manually reduce the learning rate according to this schedule
        if step in [100, 300, 700, 900]:
            adam.param_groups[0]['lr'] *= 0.2

        if step % report_frequency == 0 or step == args.num_steps - 1:
            print("[step %04d] loss: %.5f" % (step, loss))

    print("Expected singleton thetas:\n", expected_thetas.data.numpy())

    # we do the final computation using double precision
    median = guide.median()  # == mode for MAP inference
    active_dims, active_quad_dims = compute_posterior_stats(
        X.double(), Y.double(), median['msq'].double(),
        median['lambda'].double(), median['eta1'].double(),
        median['xisq'].double(), torch.tensor(hypers['c']).double(),
        median['sigma'].double())

    expected_active_dims = np.arange(S).tolist()

    tp_singletons = len(set(active_dims) & set(expected_active_dims))
    fp_singletons = len(set(active_dims) - set(expected_active_dims))
    fn_singletons = len(set(expected_active_dims) - set(active_dims))
    singleton_stats = (tp_singletons, fp_singletons, fn_singletons)

    tp_quads = len(set(active_quad_dims) & set(expected_quad_dims))
    fp_quads = len(set(active_quad_dims) - set(expected_quad_dims))
    fn_quads = len(set(expected_quad_dims) - set(active_quad_dims))
    quad_stats = (tp_quads, fp_quads, fn_quads)

    # We report how well we did, i.e. did we recover the sparse set of
    # coefficients that we expected for our artificial dataset?
    print("[SUMMARY STATS]")
    print("Singletons (true positive, false positive, false negative): "
          "(%d, %d, %d)" % singleton_stats)
    print("Quadratic (true positive, false positive, false negative): "
          "(%d, %d, %d)" % quad_stats)
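
# NOTE: `model`, `get_data`, `init_loc_fn`, and `compute_posterior_stats` are
# defined elsewhere in the original script and are not shown here. As one
# *assumed* illustration of an init_loc_fn matching the comment above
# (restarts sampled from the vicinity of the prior median), one could write
# something like:
import torch
from pyro.infer.autoguide.initialization import init_to_median


def init_loc_fn(site):
    # Estimate the prior median at this sample site from a handful of prior
    # samples, then jitter it so each random restart begins from a different
    # nearby point. (This ignores support constraints for simplicity; a real
    # helper may need to respect them, e.g. for positive-scale parameters.)
    value = init_to_median(site, num_samples=50)
    return value + 0.05 * torch.randn(value.shape)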