Exemple #1
0
def main(**kwargs):
    """Fit an AIR model to the loaded data with SVI.

    The keyword arguments are wrapped in an ``argparse.Namespace`` and read
    as options. Required options: ``seed``, ``cuda``, ``z_pres_prior``,
    ``z_pres_prior_raw``, ``anneal_prior`` (``'none'``, ``'lin'`` or
    ``'exp'``), ``anneal_prior_to``, ``anneal_prior_begin``,
    ``anneal_prior_duration``, ``model_steps``, ``no_masking``,
    ``no_baselines``, ``encoder_latent_size``, ``verbose``, ``visdom_env``,
    ``viz``, ``learning_rate``, ``baseline_learning_rate``, ``num_steps``,
    ``batch_size``, ``progress_every``, ``viz_every``, ``eval_every`` and
    ``save_every``.

    Optional options (used only when present): ``save``, ``load`` and any of
    the names in ``model_arg_keys``, which are forwarded to ``AIR``.

    A non-positive ``progress_every``, ``viz_every``, ``eval_every`` or
    ``save_every`` disables the corresponding periodic action.

    Raises:
        RuntimeError: if ``save`` is given and that path already exists.
    """
    args = argparse.Namespace(**kwargs)

    # Refuse to overwrite an existing parameter file.
    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(
                args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:

        def base_z_pres_prior_p(t):
            # The configured value is used unchanged, whatever the time step.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        """Return the z_pres prior for ``time_step`` at optimisation step ``opt_step``."""
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            # Pick the decay schedule by name; an unknown name raises KeyError.
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Optional model settings: only those present on `args` are forwarded.
    model_arg_keys = [
        'window_size', 'rnn_hidden_size', 'decoder_output_bias',
        'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
        'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
        'non_linearity', 'pos_prior_mean', 'pos_prior_sd', 'scale_prior_mean',
        'scale_prior_sd'
    ]
    model_args = {
        key: getattr(args, key)
        for key in model_arg_keys if key in args
    }
    air = AIR(num_steps=args.model_steps,
              x_size=50,
              use_masking=not args.no_masking,
              use_baselines=not args.no_baselines,
              z_what_size=args.encoder_latent_size,
              use_cuda=args.cuda,
              **model_args)

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # NOTE: the visdom client is created whether or not `viz` is set.
    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior.
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name, tags):
        """Return optimiser settings; 'baseline'-tagged parameters get their own rate."""
        lr = args.baseline_learning_rate if 'baseline' in tags else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model,
              air.guide,
              optim.Adam(per_param_optim_args),
              loss='ELBO',
              trace_graph=True)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        # Bind the current optimisation step so annealing can depend on it.
        loss = svi.step(X,
                        args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # `elapsed` is reported in hours (seconds / 3600).
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i, (i * args.batch_size) / X_size, (time.time() - t0) / 3600,
                loss / X_size))

        # Guard against a zero interval, as the progress/eval checks do;
        # otherwise `i % 0` raises ZeroDivisionError.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior,
                                       trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc,
                counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                # Show at most the first five misclassified examples.
                vis.images(draw_many(X[error_ix[0:5]],
                                     tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        # Same zero-interval guard for periodic checkpointing.
        if 'save' in args and args.save_every > 0 and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)
Exemple #2
0
def main(**kwargs):
    """Fit an AIR model to the loaded data with SVI using ``TraceGraph_ELBO``.

    The keyword arguments are wrapped in an ``argparse.Namespace`` and read
    as options. Required options: ``seed``, ``cuda``, ``z_pres_prior``,
    ``z_pres_prior_raw``, ``anneal_prior`` (``'none'``, ``'lin'`` or
    ``'exp'``), ``anneal_prior_to``, ``anneal_prior_begin``,
    ``anneal_prior_duration``, ``model_steps``, ``no_masking``,
    ``no_baselines``, ``encoder_latent_size``, ``verbose``, ``visdom_env``,
    ``viz``, ``learning_rate``, ``baseline_learning_rate``, ``num_steps``,
    ``batch_size``, ``progress_every``, ``viz_every``, ``eval_every`` and
    ``save_every``.

    Optional options (used only when present): ``save``, ``load`` and any of
    the names in ``model_arg_keys``, which are forwarded to ``AIR``.

    A non-positive ``progress_every``, ``viz_every``, ``eval_every`` or
    ``save_every`` disables the corresponding periodic action.

    Raises:
        RuntimeError: if ``save`` is given and that path already exists.
    """
    args = argparse.Namespace(**kwargs)

    # Refuse to overwrite an existing parameter file.
    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            # The configured value is used unchanged, whatever the time step.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    def z_pres_prior_p(opt_step, time_step):
        """Return the z_pres prior for ``time_step`` at optimisation step ``opt_step``."""
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            # Pick the decay schedule by name; an unknown name raises KeyError.
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Optional model settings: only those present on `args` are forwarded.
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # NOTE: the visdom client is created whether or not `viz` is set.
    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior.
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name):
        """Return optimiser settings; parameters named with 'bl_' get the baseline rate."""
        lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model, air.guide,
              optim.Adam(per_param_optim_args),
              loss=TraceGraph_ELBO())

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        # Bind the current optimisation step so annealing can depend on it.
        loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # `elapsed` is reported in hours (seconds / 3600).
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        # Guard against a zero interval, as the progress/eval checks do;
        # otherwise `i % 0` raises ZeroDivisionError.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                # Show at most the first five misclassified examples.
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        # Same zero-interval guard for periodic checkpointing.
        if 'save' in args and args.save_every > 0 and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)
Exemple #3
0
# Forward only those optional model settings that are actually present on
# `args`; absent keys are simply not passed to the AIR constructor.
model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
air = AIR(num_steps=args.model_steps,
          x_size=50,
          use_masking=not args.no_masking,
          use_baselines=not args.no_baselines,
          z_what_size=args.encoder_latent_size,
          use_cuda=args.cuda,
          **model_args)

# Optionally echo the model and the full option set.
if args.verbose:
    print(air)
    print(args)

# `load` is optional: parameters are restored only when it is present.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# NOTE: the visdom client is created whether or not `args.viz` is set.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.
if args.viz:
    # `partial` fixes the first positional argument of z_pres_prior_p to 0
    # (presumably the optimisation step -- its definition is not in view).
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, post_process_latents(z)))

# Timestamp taken here; how it is consumed is outside this fragment.
t0 = time.time()
# Fixed slice of the data (rows 9 through 13) kept for visualisation.
examples_to_viz = X[9:14]


# Do inference.
def per_param_optim_args(module_name, param_name, tags):
    """Return per-parameter optimiser keyword arguments.

    Parameters whose ``tags`` contain ``'baseline'`` get a learning rate of
    1e-3; all others get 1e-4. ``module_name`` and ``param_name`` are
    accepted but not consulted.
    """
    if 'baseline' in tags:
        return {'lr': 1e-3}
    return {'lr': 1e-4}
Exemple #4
0
    num_steps=args.model_steps,
    x_size=50,
    use_masking=not args.no_masking,
    use_baselines=not args.no_baselines,
    z_what_size=args.encoder_latent_size,
    use_cuda=args.cuda,
    **model_args
)

# Optionally echo the model and the full option set.
if args.verbose:
    print(air)
    print(args)

# `load` is optional: parameters are restored only when it is present.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# NOTE: the visdom client is created whether or not `args.viz` is set.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.
if args.viz:
    # `partial` fixes the first positional argument of z_pres_prior_p to 0
    # (presumably the optimisation step -- its definition is not in view).
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, post_process_latents(z)))

# Timestamp taken here; how it is consumed is outside this fragment.
t0 = time.time()
# Fixed slice of the data (rows 9 through 13) kept for visualisation.
examples_to_viz = X[9:14]


# Do inference.
def per_param_optim_args(module_name, param_name, tags):
    """Choose the learning rate for one parameter from its tags.

    Returns ``{'lr': 1e-3}`` when ``'baseline'`` is found in ``tags`` and
    ``{'lr': 1e-4}`` otherwise. The module and parameter names are part of
    the signature but do not influence the result.
    """
    is_baseline = 'baseline' in tags
    rate = (1e-4, 1e-3)[is_baseline]
    return dict(lr=rate)