def main(args):
    """Fit a CEVAE model on synthetic data and report average treatment effects.

    Prints three numbers:

    * the true ATE, i.e. the mean of the true individual treatment effects
      returned by ``generate_data``;
    * a naive ATE, the difference in mean outcome between treated
      (``t == 1``) and untreated (``t == 0``) test units;
    * the ATE estimated by the fitted CEVAE (mean of ``cevae.ite``).

    :param args: parsed command-line namespace. Reads ``seed``, ``cuda``,
        ``jit``, the model sizes ``feature_dim``, ``latent_dim``,
        ``hidden_dim`` and ``num_layers``, and the training settings
        ``num_epochs``, ``batch_size``, ``learning_rate``,
        ``learning_rate_decay`` and ``weight_decay``. It is also passed whole
        to ``generate_data``.
    """
    pyro.enable_validation(__debug__)
    if args.cuda:
        # NOTE(review): torch.set_default_tensor_type is deprecated in newer
        # PyTorch releases. Consider torch.set_default_device/dtype once the
        # minimum supported torch version allows it.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    # Generate synthetic training data.
    pyro.set_rng_seed(args.seed)
    x_train, t_train, y_train, _ = generate_data(args)

    # Train. Reseed so that model initialisation and training start from a
    # known RNG state, regardless of how many draws data generation consumed.
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    cevae = CEVAE(feature_dim=args.feature_dim,
                  latent_dim=args.latent_dim,
                  hidden_dim=args.hidden_dim,
                  num_layers=args.num_layers,
                  num_samples=10)
    cevae.fit(x_train, t_train, y_train,
              num_epochs=args.num_epochs,
              batch_size=args.batch_size,
              learning_rate=args.learning_rate,
              learning_rate_decay=args.learning_rate_decay,
              weight_decay=args.weight_decay)

    # Evaluate on a second synthetic draw.
    x_test, t_test, y_test, true_ite = generate_data(args)
    true_ate = true_ite.mean()
    print("true ATE = {:0.3g}".format(true_ate.item()))

    # The naive estimate does not use the covariates x_test at all; it is a
    # plain treated-minus-control difference of mean outcomes.
    naive_ate = y_test[t_test == 1].mean() - y_test[t_test == 0].mean()
    # Format a Python float, consistent with the other two reports.
    print("naive ATE = {:0.3g}".format(naive_ate.item()))

    if args.jit:
        cevae = cevae.to_script_module()
    est_ite = cevae.ite(x_test)
    est_ate = est_ite.mean()
    print("estimated ATE = {:0.3g}".format(est_ate.item()))
def test_serialization(jit, feature_dim, outcome_dist):
    """Round-trip a fitted CEVAE through serialization and compare ITEs.

    When ``jit`` is true, the model is converted with ``to_script_module`` and
    written with ``torch.jit.save``. Otherwise, the whole Python object is
    pickled with ``torch.save``. In both cases the reloaded model must produce
    ITE estimates within ``atol=0.1`` of the original when both are evaluated
    right after ``pyro.set_rng_seed(0)``.
    """
    import inspect  # local: only used to feature-detect torch.load kwargs

    x, t, y = generate_data(num_data=32, feature_dim=feature_dim)
    if outcome_dist == "exponential":
        # Keep outcomes strictly positive for the exponential likelihood.
        y.clamp_(min=1e-20)
    cevae = CEVAE(feature_dim, outcome_dist=outcome_dist,
                  num_samples=1000, hidden_dim=32)
    cevae.fit(x, t, y, num_epochs=4, batch_size=8)
    pyro.set_rng_seed(0)
    expected_ite = cevae.ite(x)

    f = io.BytesIO()
    if jit:
        # to_script_module() yields a scripted (not traced) module.
        scripted_cevae = cevae.to_script_module()
        torch.jit.save(scripted_cevae, f)
        f.seek(0)
        loaded_cevae = torch.jit.load(f)
    else:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            torch.save(cevae, f)
        f.seek(0)
        # Newer torch.load defaults to weights_only=True, which refuses to
        # unpickle a full module object. The buffer was written just above by
        # this test, so full unpickling is safe here. Only pass the keyword
        # when this torch version supports it.
        load_kwargs = {}
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        loaded_cevae = torch.load(f, **load_kwargs)

    pyro.set_rng_seed(0)
    actual_ite = loaded_cevae.ite(x)
    assert_close(actual_ite, expected_ite, atol=0.1)