Example #1
# Imports assumed by this snippet (not part of the original listing); the generative
# `model` being traced is defined elsewhere (a sketch of one plausible definition
# follows the example).
import torch
import tqdm

from pyro import poutine
from pyro.infer import SVI, JitTraceEnum_ELBO
from pyro.infer.autoguide import AutoDelta, init_to_sample
from pyro.optim import Adam


def main(_argv):
    # Ground-truth Dirichlet concentrations used to simulate the dataset.
    transition_alphas = torch.tensor([[10., 90.],
                                      [90., 10.]])
    emission_alphas = torch.tensor([[[30., 20., 5.]],
                                    [[5., 10., 100.]]])
    # Simulate 10,000 sequences of random length by running the generative model
    # forward once, then read the sampled observations off the trace.
    lengths = torch.randint(10, 30, (10000,))
    trace = poutine.trace(model).get_trace(transition_alphas, emission_alphas, lengths)
    obs_sequences = [site['value'] for name, site in
                     trace.nodes.items() if name.startswith("element_")]
    obs_sequences = torch.stack(obs_sequences, dim=-2)
    # MAP-estimate only the global probabilities: the discrete state sites are hidden
    # from the guide and are instead enumerated out by JitTraceEnum_ELBO.
    guide = AutoDelta(poutine.block(model, hide_fn=lambda site: site['name'].startswith('state')),
                      init_loc_fn=init_to_sample)
    svi = SVI(model, guide, Adam({"lr": 0.1}), JitTraceEnum_ELBO())
    # Run SVI with weak Dirichlet concentrations (0.5 and 0.3) as the priors, so the
    # recovered probabilities are driven mostly by the simulated data.
    total = 1000
    with tqdm.trange(total) as t:
        for i in t:
            loss = svi.step(0.5 * torch.ones((2, 2), dtype=torch.float),
                            0.3 * torch.ones((2, 1, 3), dtype=torch.float),
                            lengths, obs_sequences)
            t.set_description_str(f"SVI ({i}/{total}): {loss}")
    # For an AutoDelta guide the median is the MAP point estimate itself.
    median = guide.median()
    print("Transition probs: ", median['transition_probs'].detach().numpy())
    print("Emission probs: ", median['emission_probs'].squeeze().detach().numpy())
Example #2
# Imports assumed by this snippet (not part of the original listing); `model`,
# `init_loc_fn`, `get_data`, and `compute_posterior_stats` are defined elsewhere
# in the example.
import numpy as np
import torch
from torch.optim import Adam

import pyro
from pyro import poutine
from pyro.infer import Trace_ELBO
from pyro.infer.autoguide import AutoDelta


def main(args):
    # setup hyperparameters for the model
    hypers = {
        'expected_sparsity': max(1.0, args.num_dimensions / 10),
        'alpha1': 3.0,
        'beta1': 1.0,
        'alpha2': 3.0,
        'beta2': 1.0,
        'alpha3': 1.0,
        'c': 1.0
    }

    P = args.num_dimensions
    S = args.active_dimensions
    Q = args.quadratic_dimensions

    # generate artificial dataset
    X, Y, expected_thetas, expected_quad_dims = get_data(N=args.num_data,
                                                         P=P,
                                                         S=S,
                                                         Q=Q,
                                                         sigma_obs=args.sigma)

    loss_fn = Trace_ELBO().differentiable_loss

    # We initialize the AutoDelta guide (for MAP estimation) with args.num_restarts many
    # initial parameters sampled from the vicinity of the median of the prior distribution
    # and then continue optimizing with the best performing initialization.
    init_losses = []
    for restart in range(args.num_restarts):
        pyro.clear_param_store()
        pyro.set_rng_seed(restart)
        guide = AutoDelta(model, init_loc_fn=init_loc_fn)
        with torch.no_grad():
            init_losses.append(loss_fn(model, guide, X, Y, hypers).item())

    pyro.set_rng_seed(np.argmin(init_losses))
    pyro.clear_param_store()
    guide = AutoDelta(model, init_loc_fn=init_loc_fn)

    # Instead of using pyro.infer.SVI and pyro.optim, we construct our own PyTorch
    # optimizer and take charge of gradient-based optimization ourselves
    # (a standalone toy sketch of this recipe follows the example).
    with poutine.block(), poutine.trace(param_only=True) as param_capture:
        guide(X, Y, hypers)
    params = [pyro.param(name).unconstrained() for name in param_capture.trace]
    adam = Adam(params, lr=args.lr)

    report_frequency = 50
    print("Beginning MAP optimization...")

    # the optimization loop
    for step in range(args.num_steps):
        loss = loss_fn(model, guide, X, Y, hypers) / args.num_data
        loss.backward()
        adam.step()
        adam.zero_grad()

        # we manually reduce the learning rate according to this schedule
        if step in [100, 300, 700, 900]:
            adam.param_groups[0]['lr'] *= 0.2

        if step % report_frequency == 0 or step == args.num_steps - 1:
            print("[step %04d]  loss: %.5f" % (step, loss))

    print("Expected singleton thetas:\n", expected_thetas.data.numpy())

    # we do the final computation using double precision
    median = guide.median()  # == mode for MAP inference
    active_dims, active_quad_dims = \
        compute_posterior_stats(X.double(), Y.double(), median['msq'].double(),
                                median['lambda'].double(), median['eta1'].double(),
                                median['xisq'].double(), torch.tensor(hypers['c']).double(),
                                median['sigma'].double())

    expected_active_dims = np.arange(S).tolist()

    tp_singletons = len(set(active_dims) & set(expected_active_dims))
    fp_singletons = len(set(active_dims) - set(expected_active_dims))
    fn_singletons = len(set(expected_active_dims) - set(active_dims))
    singleton_stats = (tp_singletons, fp_singletons, fn_singletons)

    tp_quads = len(set(active_quad_dims) & set(expected_quad_dims))
    fp_quads = len(set(active_quad_dims) - set(expected_quad_dims))
    fn_quads = len(set(expected_quad_dims) - set(active_quad_dims))
    quad_stats = (tp_quads, fp_quads, fn_quads)

    # Report how well we did, i.e. whether we recovered the sparse set of coefficients
    # that we expected for our artificial dataset.
    print("[SUMMARY STATS]")
    print("Singletons (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % singleton_stats)
    print("Quadratic  (true positive, false positive, false negative): " +
          "(%d, %d, %d)" % quad_stats)