Example #1
0
def main(args):
    """Train a CEVAE on synthetic data and print three ATE figures.

    The run prints the true ATE (mean of the true ITE returned by
    ``generate_data``), a naive ATE (difference of mean outcomes between
    the ``t == 1`` and ``t == 0`` groups), and the ATE estimated by the
    trained model.

    Args:
        args: Namespace-like object. This function reads ``cuda``, ``seed``,
            ``feature_dim``, ``latent_dim``, ``hidden_dim``, ``num_layers``,
            ``num_epochs``, ``batch_size``, ``learning_rate``,
            ``learning_rate_decay``, ``weight_decay`` and ``jit``; the whole
            object is also forwarded to ``generate_data``.
    """
    pyro.enable_validation(__debug__)
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # --- Synthetic training data (seeded) ---
    pyro.set_rng_seed(args.seed)
    train_x, train_t, train_y, _ = generate_data(args)

    # --- Training (re-seeded, starting from an empty param store) ---
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    model_options = dict(
        feature_dim=args.feature_dim,
        latent_dim=args.latent_dim,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        num_samples=10,
    )
    model = CEVAE(**model_options)
    fit_options = dict(
        num_epochs=args.num_epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        learning_rate_decay=args.learning_rate_decay,
        weight_decay=args.weight_decay,
    )
    model.fit(train_x, train_t, train_y, **fit_options)

    # --- Evaluation on a second draw of data (RNG is not re-seeded here) ---
    test_x, test_t, test_y, ite_truth = generate_data(args)
    ate_truth = ite_truth.mean()
    print("true ATE = {:0.3g}".format(ate_truth.item()))

    treated_mean = test_y[test_t == 1].mean()
    control_mean = test_y[test_t == 0].mean()
    ate_naive = treated_mean - control_mean
    print("naive ATE = {:0.3g}".format(ate_naive))

    # Optionally swap in the scripted module before estimating.
    predictor = model.to_script_module() if args.jit else model
    ite_estimate = predictor.ite(test_x)
    ate_estimate = ite_estimate.mean()
    print("estimated ATE = {:0.3g}".format(ate_estimate.item()))
Example #2
0
def test_serialization(jit, feature_dim, outcome_dist):
    """Check that a fitted CEVAE gives close ITE estimates after a save/load round trip.

    A model is fitted on a small dataset, its ``ite`` output is recorded under
    a fixed seed, the model is serialized to an in-memory buffer and reloaded,
    and the reloaded model's ``ite`` output (same seed) is compared against
    the recorded one with ``atol=0.1``.

    Args:
        jit: If truthy, round-trip the result of ``to_script_module()`` via
            ``torch.jit.save`` / ``torch.jit.load``; otherwise round-trip the
            module object itself via ``torch.save`` / ``torch.load``.
        feature_dim: Feature dimension passed to ``generate_data`` and ``CEVAE``.
        outcome_dist: Outcome distribution name passed to ``CEVAE``.
    """
    x, t, y = generate_data(num_data=32, feature_dim=feature_dim)
    if outcome_dist == "exponential":
        # NOTE(review): presumably keeps outcomes strictly positive for the
        # exponential likelihood -- confirm against the CEVAE outcome model.
        y.clamp_(min=1e-20)
    cevae = CEVAE(feature_dim, outcome_dist=outcome_dist, num_samples=1000, hidden_dim=32)
    cevae.fit(x, t, y, num_epochs=4, batch_size=8)
    pyro.set_rng_seed(0)
    expected_ite = cevae.ite(x)

    buffer = io.BytesIO()
    if jit:
        traced_cevae = cevae.to_script_module()
        torch.jit.save(traced_cevae, buffer)
        buffer.seek(0)
        loaded_cevae = torch.jit.load(buffer)
    else:
        with warnings.catch_warnings():
            # UserWarnings raised while saving are not what this test checks.
            warnings.filterwarnings("ignore", category=UserWarning)
            torch.save(cevae, buffer)
        buffer.seek(0)
        # The buffer holds a whole pickled module, not just a state dict.
        # Newer PyTorch defaults to weights_only=True, which refuses to
        # unpickle arbitrary classes, so full unpickling is requested
        # explicitly. This is safe only because the buffer was produced above
        # in this same test; never do this with untrusted data.
        loaded_cevae = torch.load(buffer, weights_only=False)

    # Re-seed so both ite() calls start from the same RNG state.
    pyro.set_rng_seed(0)
    actual_ite = loaded_cevae.ite(x)
    assert_close(actual_ite, expected_ite, atol=0.1)