Example #1
0
def resolve(model, inference_method=None, deep=False, **inference_args):
    """Build the inference object selected by ``inference_method``.

    Args:
        model: The model handed to the chosen inference class.
        inference_method: One of ``'momentum'``, ``'rws'`` or ``'air'``.
            Any other value, including the default ``None``, is rejected.
        deep: When true, select the ``Deep*`` variant of the inference class.
        **inference_args: Extra keyword arguments forwarded unchanged to the
            inference class constructor.

    Returns:
        A new instance of the selected inference class.

    Raises:
        NotImplementedError: for ``deep=True`` combined with ``'momentum'``.
        ValueError: if ``inference_method`` is not a recognised name.
    """
    if deep:
        if inference_method == 'momentum':
            # The deep momentum variant is not supported: fail loudly rather
            # than constructing it.
            raise NotImplementedError(inference_method)
        elif inference_method == 'rws':
            return DeepRWS(model, **inference_args)
        elif inference_method == 'air':
            return DeepAIR(model, **inference_args)
        else:
            raise ValueError(inference_method)
    else:
        if inference_method == 'momentum':
            return MomentumGDIR(model, **inference_args)
        elif inference_method == 'rws':
            return RWS(model, **inference_args)
        elif inference_method == 'air':
            return AIR(model, **inference_args)
        else:
            raise ValueError(inference_method)
Example #2
0
from torch import autograd

if __name__ == '__main__':
    trainset = MultiMNIST(path=cfg.multi_mnist_path, mode='train')
    validset = MultiMNIST(path=cfg.multi_mnist_path, mode='test')
    trainloader = DataLoader(trainset,
                             batch_size=cfg.train.batch_size,
                             shuffle=True,
                             num_workers=4)
    validloader = DataLoader(validset,
                             batch_size=cfg.valid.batch_size,
                             shuffle=False,
                             num_workers=4)

    device = torch.device(cfg.device)
    model = AIR().to(device)
    optimizer = optim.Adam([{
        'params': model.air_modules.parameters(),
        'lr': cfg.train.model_lr
    }, {
        'params': model.baseline_modules.parameters(),
        'lr': cfg.train.baseline_lr
    }])

    # checkpoint
    start_epoch = 0
    checkpoint_path = os.path.join(cfg.checkpointdir, cfg.exp_name)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpointer = Checkpointer(path=checkpoint_path)
    if cfg.resume:
Example #3
0
def main(**kwargs):
    """Train an AIR model with Pyro SVI, optionally visualising progress.

    Every setting arrives as a keyword argument and is wrapped in an
    ``argparse.Namespace``. Options listed in ``model_arg_keys`` are forwarded
    to ``AIR`` only when present. ``save`` and ``load`` are honoured only when
    present *and* not ``None``. A zero (or negative) ``viz_every`` /
    ``save_every`` disables that action, matching how ``progress_every`` and
    ``eval_every`` already behave.

    Side effects: optionally seeds Pyro's RNG, prints model and progress
    information, creates a visdom client (and sends images to it when
    ``viz`` is set), and reads/writes parameter files via
    ``torch.load`` / ``torch.save``.

    Raises:
        RuntimeError: if ``save`` names a file that already exists.
    """

    args = argparse.Namespace(**kwargs)

    # An explicit ``None`` counts as "not given"; handing it on to
    # os.path.exists / torch.load / torch.save would raise instead.
    save_path = getattr(args, 'save', None)
    load_path = getattr(args, 'load', None)

    if save_path is not None:
        # Refuse to overwrite an existing parameter file.
        if os.path.exists(save_path):
            raise RuntimeError('Output file "{}" already exists.'.format(
                save_path))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:

        def base_z_pres_prior_p(t):
            # Constant prior: the time step is ignored.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing. ``opt_step`` is the optimiser
    # step (bound via functools.partial below); ``time_step`` is passed to
    # the base prior.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Optional AIR constructor arguments, forwarded only when present.
    model_arg_keys = [
        'window_size', 'rnn_hidden_size', 'decoder_output_bias',
        'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
        'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
        'non_linearity', 'pos_prior_mean', 'pos_prior_sd', 'scale_prior_mean',
        'scale_prior_sd'
    ]
    model_args = {
        key: getattr(args, key)
        for key in model_arg_keys if key in args
    }
    air = AIR(num_steps=args.model_steps,
              x_size=50,
              use_masking=not args.no_masking,
              use_baselines=not args.no_baselines,
              z_what_size=args.encoder_latent_size,
              use_cuda=args.cuda,
              **model_args)

    if args.verbose:
        print(air)
        print(args)

    if load_path is not None:
        print('Loading parameters...')
        # NOTE(review): torch.load is pickle-based; only load trusted files.
        air.load_state_dict(torch.load(load_path))

    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior (optimiser step 0).
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name, tags):
        # Parameters tagged 'baseline' get their own learning rate.
        lr = args.baseline_learning_rate if 'baseline' in tags else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model,
              air.guide,
              optim.Adam(per_param_optim_args),
              loss='ELBO',
              trace_graph=True)

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X,
                        args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # ``elapsed`` is wall-clock hours; the value labelled ``elbo`` is
            # svi.step's returned loss divided by the dataset size.
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i, (i * args.batch_size) / X_size, (time.time() - t0) / 3600,
                loss / X_size))

        # Guard the interval like progress_every/eval_every so that zero
        # disables visualisation instead of raising ZeroDivisionError.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior,
                                       trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred objection positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc,
                counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]],
                                     tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        # Same zero-interval guard for checkpointing.
        if (save_path is not None and args.save_every > 0
                and i % args.save_every == 0):
            print('Saving parameters...')
            torch.save(air.state_dict(), save_path)
Example #4
0
        decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
        return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                     args.anneal_prior_duration, opt_step)


# Optional AIR constructor arguments. A key is forwarded only when it is
# present on ``args`` (``key in args``); absent options are not passed at all.
# NOTE(review): ``args`` is defined outside this excerpt - presumably an
# argparse.Namespace, which supports the ``in`` test used below.
model_arg_keys = [
    'window_size', 'rnn_hidden_size', 'decoder_output_bias',
    'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
    'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
    'non_linearity', 'fudge_z_pres'
]
model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
# Required settings are read directly from ``args``; the ``no_*`` flags are
# inverted into ``use_*`` switches.
# NOTE(review): x_size is hard-coded to 50 (presumably the image side
# length) - confirm it matches the loaded data.
air = AIR(num_steps=args.model_steps,
          x_size=50,
          use_masking=not args.no_masking,
          use_baselines=not args.no_baselines,
          z_what_size=args.encoder_latent_size,
          use_cuda=args.cuda,
          **model_args)

if args.verbose:
    # Print the model and the full configuration.
    print(air)
    print(args)

# Restore saved parameters when a ``load`` option is present.
# NOTE(review): an explicit ``load=None`` also passes this test and would be
# handed to torch.load. torch.load is pickle-based, so only load trusted files.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# Visdom client for visualisation; ``visdom_env`` selects the environment.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.
if args.viz:
Example #5
0
def main(**kwargs):
    """Train an AIR model with Pyro SVI (TraceGraph_ELBO), optionally visualising.

    Every setting arrives as a keyword argument and is wrapped in an
    ``argparse.Namespace``. Options listed in ``model_arg_keys`` are forwarded
    to ``AIR`` only when present. ``save`` and ``load`` are honoured only when
    present *and* not ``None``. A zero (or negative) ``viz_every`` /
    ``save_every`` disables that action, matching how ``progress_every`` and
    ``eval_every`` already behave.

    Side effects: optionally seeds Pyro's RNG, prints model and progress
    information, creates a visdom client (and sends images to it when
    ``viz`` is set), and reads/writes parameter files via
    ``torch.load`` / ``torch.save``.

    Raises:
        RuntimeError: if ``save`` names a file that already exists.
    """

    args = argparse.Namespace(**kwargs)

    # An explicit ``None`` counts as "not given"; handing it on to
    # os.path.exists / torch.load / torch.save would raise instead.
    save_path = getattr(args, 'save', None)
    load_path = getattr(args, 'load', None)

    if save_path is not None:
        # Refuse to overwrite an existing parameter file.
        if os.path.exists(save_path):
            raise RuntimeError('Output file "{}" already exists.'.format(save_path))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            # Constant prior: the time step is ignored.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing. ``opt_step`` is the optimiser
    # step (bound via functools.partial below); ``time_step`` is passed to
    # the base prior.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Optional AIR constructor arguments, forwarded only when present.
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if load_path is not None:
        print('Loading parameters...')
        # NOTE(review): torch.load is pickle-based; only load trusted files.
        air.load_state_dict(torch.load(load_path))

    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior (optimiser step 0).
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name):
        # Parameters whose name contains 'bl_' (presumably the baseline
        # networks) get their own learning rate.
        lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model, air.guide,
              optim.Adam(per_param_optim_args),
              loss=TraceGraph_ELBO())

    # Do inference.
    t0 = time.time()
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # ``elapsed`` is wall-clock hours; the value labelled ``elbo`` is
            # svi.step's returned loss divided by the dataset size.
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        # Guard the interval like progress_every/eval_every so that zero
        # disables visualisation instead of raising ZeroDivisionError.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred objection positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        # Same zero-interval guard for checkpointing.
        if save_path is not None and args.save_every > 0 and i % args.save_every == 0:
            print('Saving parameters...')
            torch.save(air.state_dict(), save_path)
Example #6
0
                  'decoder_output_bias',
                  'decoder_output_use_sigmoid',
                  'baseline_scalar',
                  'encoder_net',
                  'decoder_net',
                  'predict_net',
                  'embed_net',
                  'bl_predict_net',
                  'non_linearity',
                  'fudge_z_pres']
# Forward each optional AIR argument only when it is present on ``args``
# (``key in args``); absent options are not passed at all.
# NOTE(review): ``args`` and ``model_arg_keys`` are defined above this
# excerpt - ``args`` is presumably an argparse.Namespace.
model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
# Required settings are read directly from ``args``; the ``no_*`` flags are
# inverted into ``use_*`` switches.
# NOTE(review): x_size is hard-coded to 50 (presumably the image side
# length) - confirm it matches the loaded data.
air = AIR(
    num_steps=args.model_steps,
    x_size=50,
    use_masking=not args.no_masking,
    use_baselines=not args.no_baselines,
    z_what_size=args.encoder_latent_size,
    use_cuda=args.cuda,
    **model_args
)

if args.verbose:
    # Print the model and the full configuration.
    print(air)
    print(args)

# Restore saved parameters when a ``load`` option is present.
# NOTE(review): an explicit ``load=None`` also passes this test and would be
# handed to torch.load. torch.load is pickle-based, so only load trusted files.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# Visdom client for visualisation; ``visdom_env`` selects the environment.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.