示例#1
0
def main(**kwargs):
    """Train an AIR model with SVI, optionally visualising and checkpointing.

    Every option is passed as a keyword argument and wrapped in an
    ``argparse.Namespace``. ``save``, ``load`` and the model options named in
    ``model_arg_keys`` are optional and may be absent from ``kwargs``;
    ``save``/``load`` given as ``None`` are treated the same as absent.

    The periodic actions controlled by ``progress_every``, ``eval_every``,
    ``viz_every`` and ``save_every`` are skipped when that interval is not
    positive.

    Raises:
        RuntimeError: if ``save`` names a file that already exists.
    """

    args = argparse.Namespace(**kwargs)

    # ``getattr`` with a default treats "absent" and "None" alike. A bare
    # ``'save' in args`` test let ``save=None`` reach ``os.path.exists(None)``
    # (TypeError) and ``load=None`` reach ``torch.load(None)``.
    save_path = getattr(args, 'save', None)
    load_path = getattr(args, 'load', None)

    # Fail fast instead of overwriting an existing checkpoint later on.
    if save_path is not None and os.path.exists(save_path):
        raise RuntimeError('Output file "{}" already exists.'.format(
            save_path))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:

        def base_z_pres_prior_p(t):
            # Same configured value at every time step.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing over optimisation steps.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
        return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                     args.anneal_prior_duration, opt_step)

    # Optional model hyper-parameters: forwarded to AIR only when supplied.
    model_arg_keys = [
        'window_size', 'rnn_hidden_size', 'decoder_output_bias',
        'decoder_output_use_sigmoid', 'baseline_scalar', 'encoder_net',
        'decoder_net', 'predict_net', 'embed_net', 'bl_predict_net',
        'non_linearity', 'pos_prior_mean', 'pos_prior_sd', 'scale_prior_mean',
        'scale_prior_sd'
    ]
    model_args = {
        key: getattr(args, key)
        for key in model_arg_keys if key in args
    }
    air = AIR(num_steps=args.model_steps,
              x_size=50,
              use_masking=not args.no_masking,
              use_baselines=not args.no_baselines,
              z_what_size=args.encoder_latent_size,
              use_cuda=args.cuda,
              **model_args)

    if args.verbose:
        print(air)
        print(args)

    if load_path is not None:
        print('Loading parameters...')
        air.load_state_dict(torch.load(load_path))

    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior (first argument of z_pres_prior_p bound to 0).
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name, tags):
        # Baseline-tagged parameters get their own learning rate.
        lr = args.baseline_learning_rate if 'baseline' in tags else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model,
              air.guide,
              optim.Adam(per_param_optim_args),
              loss='ELBO',
              trace_graph=True)

    # Do inference.
    t0 = time.time()
    # Fixed slice of five inputs reused for every visualisation.
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X,
                        args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # epochs = examples processed / dataset size; elapsed is in hours;
            # the printed "elbo" is the step's loss divided by dataset size.
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i, (i * args.batch_size) / X_size, (time.time() - t0) / 3600,
                loss / X_size))

        # ``> 0`` guard (as for progress_every/eval_every) avoids a
        # ZeroDivisionError when the interval is 0.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            # Run the guide on the fixed examples, then replay the prior
            # against that trace to obtain the matching reconstructions.
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior,
                                       trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(
                X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(
                i, acc,
                counts.numpy().tolist()))
            # Show up to five of the examples listed in error_ix.
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]],
                                     tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if (save_path is not None and args.save_every > 0
                and i % args.save_every == 0):
            print('Saving parameters...')
            torch.save(air.state_dict(), save_path)
示例#2
0
          use_cuda=args.cuda,
          **model_args)

# Optionally echo the model and the parsed options.
if args.verbose:
    print(air)
    print(args)

# Restore previously saved parameters when a `load` option is present.
# NOTE(review): presence is tested with `in`, so an explicit load=None would
# reach torch.load(None) — presumably the parser omits unset options; confirm.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# Create the visdom client for the environment named in args.visdom_env.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.
# Draws 5 samples with the first argument of z_pres_prior_p bound to 0 and
# shows them alongside their post-processed latents.
if args.viz:
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, post_process_latents(z)))

# Start timestamp and a fixed slice of five inputs (rows 9-13); presumably
# used later for elapsed-time reporting and visualisation — confirm.
t0 = time.time()
examples_to_viz = X[9:14]


# Do inference.
def per_param_optim_args(module_name, param_name, tags):
    """Return the Adam options for a single parameter.

    Parameters whose ``tags`` contain ``'baseline'`` are trained with a
    learning rate of 1e-3; every other parameter uses 1e-4.
    """
    if 'baseline' in tags:
        rate = 1e-3
    else:
        rate = 1e-4
    return dict(lr=rate)


svi = SVI(air.model,
          air.guide,
          optim.Adam(per_param_optim_args),
          loss='ELBO',
示例#3
0
文件: main.py 项目: lewisKit/pyro
def main(**kwargs):
    """Train an AIR model with SVI, optionally visualising and checkpointing.

    Every option is passed as a keyword argument and wrapped in an
    ``argparse.Namespace``. ``save``, ``load`` and the model options named in
    ``model_arg_keys`` are optional and may be absent from ``kwargs``;
    ``save``/``load`` given as ``None`` are treated the same as absent.

    The periodic actions controlled by ``progress_every``, ``eval_every``,
    ``viz_every`` and ``save_every`` are skipped when that interval is not
    positive.

    Raises:
        RuntimeError: if ``save`` names a file that already exists.
    """

    args = argparse.Namespace(**kwargs)

    # ``getattr`` with a default treats "absent" and "None" alike. A bare
    # ``'save' in args`` test let ``save=None`` reach ``os.path.exists(None)``
    # (TypeError) and ``load=None`` reach ``torch.load(None)``.
    save_path = getattr(args, 'save', None)
    load_path = getattr(args, 'load', None)

    # Fail fast instead of overwriting an existing checkpoint later on.
    if save_path is not None and os.path.exists(save_path):
        raise RuntimeError('Output file "{}" already exists.'.format(save_path))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        def base_z_pres_prior_p(t):
            # Same configured value at every time step.
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing over optimisation steps.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
        return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                     args.anneal_prior_duration, opt_step)

    # Optional model hyper-parameters: forwarded to AIR only when supplied.
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if load_path is not None:
        print('Loading parameters...')
        air.load_state_dict(torch.load(load_path))

    vis = visdom.Visdom(env=args.visdom_env)
    # Viz sample from prior (first argument of z_pres_prior_p bound to 0).
    if args.viz:
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    def per_param_optim_args(module_name, param_name):
        # Parameters whose name contains 'bl_' get the baseline learning rate.
        lr = args.baseline_learning_rate if 'bl_' in param_name else args.learning_rate
        return {'lr': lr}

    svi = SVI(air.model, air.guide,
              optim.Adam(per_param_optim_args),
              loss=TraceGraph_ELBO())

    # Do inference.
    t0 = time.time()
    # Fixed slice of five inputs reused for every visualisation.
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        loss = svi.step(X, args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # epochs = examples processed / dataset size; elapsed is in hours;
            # the printed "elbo" is the step's loss divided by dataset size.
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        # ``> 0`` guard (as for progress_every/eval_every) avoids a
        # ZeroDivisionError when the interval is 0.
        if args.viz and args.viz_every > 0 and i % args.viz_every == 0:
            # Run the guide on the fixed examples, then replay the prior
            # against that trace to obtain the matching reconstructions.
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure accuracy on subset of training data.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            # Show up to five of the examples listed in error_ix.
            if args.viz and error_ix.size(0) > 0:
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if (save_path is not None and args.save_every > 0
                and i % args.save_every == 0):
            print('Saving parameters...')
            torch.save(air.state_dict(), save_path)
示例#4
0
def generate_image(count, tau, gen_batch_size=64):
    """Log reconstructions and from-scratch generations for step ``count``.

    Reconstructs one batch taken from ``train_loader`` with ``model.recon``,
    shows the first five inputs and final reconstructions in visdom together
    with the objects from ``model.get_where_pres()``, stores the per-step
    category codes with ``np.savez_compressed`` and writes image grids under
    ``image/``. It then generates ``gen_batch_size`` images containing the
    fixed categories 0, 1 and 2, each at a random scale in [1.5, 4) and a
    random position in [-0.3, 0.3) on both coordinates.

    Args:
        count: step counter used in output file names; when ``>= 60000``
            the per-step category codes are printed and collected.
        tau: forwarded unchanged to ``model.recon`` (presumably a relaxation
            temperature — confirm against the model).
        gen_batch_size: number of images generated from scratch.
    """
    # ``next()`` is the Python 3 iterator protocol; the ``.next()`` alias is
    # not provided by recent PyTorch DataLoader iterators.
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_x, z_pres, cats = model.recon(data, tau)
    z_zip = model.get_where_pres()
    z_obj = tensor_to_objs(latents_to_tensor(z_zip)[:5])
    recon_x_final = recon_x[-1]
    print(np.min(recon_x_final), np.max(recon_x_final))
    cats_array = []
    z_pres_array = []

    if count >= 60000:
        for i in range(3):
            cats_numpy = cats[i].detach().cpu().numpy()
            index = np.argmax(cats_numpy, axis=-1)

            z_pres_np = np.squeeze(
                torch.round(z_pres[i].detach()).cpu().numpy())
            z_pres_np_raw = np.squeeze(z_pres[i].detach().cpu().numpy())
            cats_array.append(cats_numpy)
            z_pres_array.append(z_pres_np_raw)

            print(index.shape, z_pres_np.shape)
            # Keep the argmax category where rounded z_pres is 1; slots whose
            # rounded z_pres is 0 become -1.
            index = index * z_pres_np + z_pres_np - 1
            print('i=', i, ' cat index:')
            # NOTE(review): the reshape assumes exactly 64 examples per batch.
            print(np.reshape(index, (8, 8)))

    # Always written; the arrays are empty when count < 60000.
    # NOTE(review): despite the ".csv" suffix this is compressed NumPy archive
    # data (NumPy appends ".npz" to the name); "codes/" must already exist.
    np.savez_compressed('codes/recon_code_{}.csv'.format(count),
                        cats=np.array(cats_array),
                        z_pres=np.array(z_pres_array))

    vis.images(draw_many(data[:5].view(-1, A, B), z_obj))
    # Show reconstructions of data.
    vis.images(draw_many(
        torch.tensor(recon_x_final[:5]).view(-1, A, B), z_obj))
    save_image(recon_x, count, 'recon', path='image/')
    save_image_single(data.cpu().numpy(), count, 'origin', path='image/')

    # Generate images from scratch: one fixed category per object slot, each
    # placed at a random scale/position. The numpy draws keep their original
    # order (scale then position, object by object).
    gen_objs = []
    gen_z_wheres = []
    for obj_id in (0, 1, 2):
        one_hot = torch.zeros(gen_batch_size, cat_size).scatter_(
            1,
            torch.tensor(obj_id).expand(gen_batch_size, 1),
            1).float().to(device)
        scale = np.random.uniform(1.5, 4, size=(gen_batch_size, 1))
        pos = np.random.uniform(-0.3, 0.3, size=(gen_batch_size, 2))
        z_where = torch.from_numpy(
            np.concatenate([scale, pos], axis=-1)).float().to(device)
        gen_objs.append(one_hot)
        gen_z_wheres.append(z_where)

    gen_x = model.generate(gen_objs, gen_z_wheres)
    save_image(gen_x, count, 'gen', path='image/')
示例#5
0
文件: main.py 项目: Magica-Chen/pyro
    **model_args
)

# Optionally echo the model and the parsed options.
if args.verbose:
    print(air)
    print(args)

# Restore previously saved parameters when a `load` option is present.
# NOTE(review): presence is tested with `in`, so an explicit load=None would
# reach torch.load(None) — presumably the parser omits unset options; confirm.
if 'load' in args:
    print('Loading parameters...')
    air.load_state_dict(torch.load(args.load))

# Create the visdom client for the environment named in args.visdom_env.
vis = visdom.Visdom(env=args.visdom_env)
# Viz sample from prior.
# Draws 5 samples with the first argument of z_pres_prior_p bound to 0 and
# shows them alongside their post-processed latents.
if args.viz:
    z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
    vis.images(draw_many(x, post_process_latents(z)))

# Start timestamp and a fixed slice of five inputs (rows 9-13); presumably
# used later for elapsed-time reporting and visualisation — confirm.
t0 = time.time()
examples_to_viz = X[9:14]


# Do inference.
def per_param_optim_args(module_name, param_name, tags):
    """Choose Adam options per parameter: 1e-3 for 'baseline'-tagged ones, else 1e-4."""
    baseline_rate, default_rate = 1e-3, 1e-4
    chosen_rate = baseline_rate if 'baseline' in tags else default_rate
    return dict(lr=chosen_rate)


# Build the SVI object from AIR's model/guide pair, optimised with Adam using
# the per-parameter learning rates returned by per_param_optim_args.
# NOTE(review): loss='ELBO' plus trace_graph=True looks like the legacy
# string-based Pyro loss API; newer Pyro releases take loss=TraceGraph_ELBO()
# instead — confirm against the pinned Pyro version.
svi = SVI(air.model, air.guide,
          optim.Adam(per_param_optim_args),
          loss='ELBO',
          trace_graph=True)