# Standard imports needed by this snippet; project-local helpers
# (cmd_args, GibbsSampler, float2bin, get_samples, learn_score,
# prepare_diff_pos, plot_heat, plot_sampler) are assumed to be in scope.
import os
import sys

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_scatter import scatter_add
from tqdm import tqdm


def main_loop(db, bm, inv_bm, score_func, sampler, proposal_dist, fn_sample_init, fn_log_prob, proposal_opt=None):
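    # Alternating training loop: fit the energy model (score_func) against
    # sampler-generated negatives, then fit the edit-based sampler on an
    # importance-weighted objective over its own proposed edit trajectories.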
    gibbs_sampler = GibbsSampler(2, cmd_args.discrete_dim, cmd_args.device)

    # uniform random bit-strings, kept only as a scoring baseline; cast to
    # float to match the dtype score_func sees from float2bin
    rand_samples = torch.randint(2, (1000, cmd_args.discrete_dim)).to(cmd_args.device).float()

    if cmd_args.energy_model_dump is not None:
        print('loading score_func from', cmd_args.energy_model_dump)
        score_func.load_state_dict(torch.load(cmd_args.energy_model_dump, map_location=cmd_args.device))
    if cmd_args.sampler_model_dump is not None:
        print('loading sampler from', cmd_args.sampler_model_dump)
        sampler.load_state_dict(torch.load(cmd_args.sampler_model_dump, map_location=cmd_args.device))

    # sanity check: average score the current model assigns to real data
    samples = float2bin(db.gen_batch(1000), bm)
    samples = torch.from_numpy(samples).to(cmd_args.device)
    print('true score: %.4f' % torch.mean(score_func(samples)).item())

    opt_score = optim.Adam(score_func.parameters(), lr=cmd_args.learning_rate * cmd_args.f_lr_scale)
    opt_sampler = optim.Adam(sampler.parameters(), lr=cmd_args.learning_rate)

    if cmd_args.phase == 'plot':
        cmd_args.plot_size = 4.1
        plot_heat(score_func, bm, out_file=os.path.join(cmd_args.save_dir, '%s-heat.pdf' % cmd_args.data))
        plot_sampler(0, fn_sample_init, inv_bm, os.path.join(cmd_args.save_dir, '%s-init.pdf' % cmd_args.data))
        plot_sampler(cmd_args.num_q_steps, fn_sample_init, inv_bm, os.path.join(cmd_args.save_dir, '%s-edit.pdf' % cmd_args.data))
        sys.exit()

    for epoch in range(cmd_args.num_epochs):
        pbar = tqdm(range(cmd_args.iter_per_epoch))
        for it in pbar:
            samples = float2bin(db.gen_batch(cmd_args.batch_size), bm)
            samples = torch.from_numpy(samples).to(cmd_args.device)

            if cmd_args.learn_mode == 'ebm':
                neg_samples, avg_steps = get_samples(fn_sample_init, gibbs_sampler, score_func, proposal_opt)
                f_loss = learn_score(samples, score_func, opt_score, neg_samples=neg_samples)
            else:
                neg_samples = samples
                f_loss = 0.0
                avg_steps = 0.0  # no sampler steps taken; keeps the progress bar well-defined

            for q_it in range(cmd_args.q_iter):
                opt_sampler.zero_grad()
                with torch.no_grad():
                    # repeat each negative num_importance_samples times and let
                    # the proposal edit every copy independently
                    neg_samples = neg_samples.repeat(cmd_args.num_importance_samples, 1)
                    cur_samples, proposal_logprob = proposal_dist(cmd_args.num_q_steps, neg_samples)
                    # positions where the proposal actually flipped a bit
                    diff_pos = (cur_samples - neg_samples).nonzero()
                    rep_rows = diff_pos[:, 0].view(-1)
                    rep_init = torch.index_select(cur_samples, 0, rep_rows)
                    # reconstruct the intermediate state before each edit step,
                    # plus the column edited at that step
                    row_ids, col_ids, col_target, traj_lens = prepare_diff_pos(diff_pos)
                    rep_init[row_ids, col_ids] = 1 - rep_init[row_ids, col_ids]
                    rep_target = torch.LongTensor(col_target).to(cmd_args.device).view(-1, 1)

                init_prob = sampler.base_logprob(cur_samples)
                if rep_rows.shape[0]:
                    # per-edit log-probs, summed over each sample's trajectory
                    traj_prob, _ = sampler.forward_onestep(init_samples=rep_init, target_pos=rep_target)
                    traj_prob = scatter_add(traj_prob, rep_rows, dim=0, dim_size=cur_samples.shape[0])
                else:
                    traj_prob = 0
                log_prob = fn_log_prob(init_prob, traj_prob, rep_init, rep_rows, neg_samples)  # TODO: should this really be neg_samples?

                if cmd_args.lb_type == 'is':  # self-normalized importance weights
                    with torch.no_grad():
                        log_ratio = (log_prob - proposal_logprob).view(cmd_args.num_importance_samples, -1)
                        weight = F.softmax(log_ratio, dim=0).view(log_prob.shape)
                else:
                    weight = 1.0 / cmd_args.num_importance_samples
                log_prob = log_prob * weight

                loss = -torch.mean(log_prob) * cmd_args.num_importance_samples
                loss.backward()
                if cmd_args.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(sampler.parameters(), max_norm=cmd_args.grad_clip)
                opt_sampler.step()

                if q_it + 1 < cmd_args.q_iter:
                    # draw fresh negatives for the next inner sampler update
                    neg_samples, _ = get_samples(fn_sample_init, gibbs_sampler, score_func, proposal_opt=None)

            g_loss = loss.item() if cmd_args.q_iter > 0 else 0.0
            true_score = torch.mean(score_func(samples)).item()
            fake_score = torch.mean(score_func(neg_samples)).item()
            rand_score = torch.mean(score_func(rand_samples)).item()
            pbar.set_description('epoch: %d, f: %.2f, g: %.2f, n: %.2f, true: %.2f, fake: %.2f, rand: %.2f' % (epoch, f_loss, g_loss, avg_steps, true_score, fake_score, rand_score))

        if epoch and epoch % cmd_args.epoch_save == 0:
            torch.save(score_func.state_dict(), os.path.join(cmd_args.save_dir, 'score_func-%d.ckpt' % epoch))
            torch.save(sampler.state_dict(), os.path.join(cmd_args.save_dir, 'sampler-%d.ckpt' % epoch))

            plot_heat(score_func, bm, out_file=os.path.join(cmd_args.save_dir, 'heat-%d.pdf' % epoch))
            for n_step in [0, cmd_args.num_q_steps]:
                plot_sampler(n_step, fn_sample_init, inv_bm, os.path.join(cmd_args.save_dir, 'rand-samples-%d-%d.pdf' % (epoch, n_step)))
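
# The project-local helper prepare_diff_pos is not shown on this page. The
# sketch below is a hypothetical reconstruction of its assumed contract, not
# the original implementation: given diff_pos, the (row, col) indices where
# cur_samples and neg_samples disagree, emit one trajectory entry per
# differing bit. The entry for edit step i of a row carries the flips that
# undo edits i..k-1, so rep_init[entry] becomes the state just *before*
# step i, and col_target[entry] is the column edited at that step. Edits are
# ordered left-to-right here; the original may use a different ordering.
def prepare_diff_pos_sketch(diff_pos):
    entries = diff_pos.cpu().tolist()           # sorted by row by nonzero()
    row_ids, col_ids, col_target, traj_lens = [], [], [], []
    i = 0
    while i < len(entries):
        j = i
        while j < len(entries) and entries[j][0] == entries[i][0]:
            j += 1                              # entries i..j-1 share one row
        cols = [c for _, c in entries[i:j]]
        traj_lens.append(len(cols))
        for step in range(len(cols)):
            for c in cols[step:]:               # undo edits step..k-1
                row_ids.append(i + step)        # global entry index
                col_ids.append(c)
            col_target.append(cols[step])
        i = j
    return row_ids, col_ids, col_target, traj_lens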
Example #2
            try:
                samples = next(train_gen)
            except StopIteration:
                # the loader is exhausted: restart it and draw again
                train_gen = iter(train_load)
                samples = next(train_gen)
            pos_list, hex_stream = samples
            pos_list = pos_list.to(cmd_args.device)
            hex_stream = hex_stream.to(cmd_args.device)
            # propose edited samples with the learned sampler, then refine
            # them with a few rounds of Gibbs sampling under the current energy
            with torch.no_grad():
                init_samples, n_steps, _, _ = sampler(cmd_args.num_q_steps,
                                                      pos_list)
                cur_score_fn = lambda s: score_func(pos_list, s)
                neg_samples = gibbs_sampler(cur_score_fn,
                                            cmd_args.gibbs_rounds,
                                            init_samples=init_samples)

            f_loss, true_scores = learn_score(pos_list, hex_stream, score_func,
                                              opt_score, neg_samples)

            with torch.no_grad():
                rand_scores = torch.mean(score_func(pos_list,
                                                    rand_samples)).item()
                neg_scores = torch.mean(score_func(pos_list,
                                                   neg_samples)).item()

            neg_samples = neg_samples.repeat(cmd_args.num_importance_samples,
                                             1)
            pos_list = pos_list.repeat(cmd_args.num_importance_samples, 1)
            for q_it in range(cmd_args.q_iter):
                opt_sampler.zero_grad()
                if cmd_args.num_q_steps:
                    with torch.no_grad():
                        cur_samples, proposal_logprob = proposal_dist(
                            cmd_args.num_q_steps, neg_samples)
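
# Both examples share the same self-normalized importance-sampling step when
# cmd_args.lb_type == 'is'. Below is a standalone toy of just that weighting,
# with made-up shapes and names (an illustration, not code from this project):
# each sample is repeated num_is times, the log-ratio log p - log q is
# softmax-normalized across the repeats, and the weighted log-likelihood is
# scaled back by num_is so the loss magnitude matches the unweighted case.
import torch
import torch.nn.functional as F

num_is, batch = 4, 3
log_p = torch.randn(num_is * batch)  # model log-prob of each repeated sample
log_q = torch.randn(num_is * batch)  # proposal log-prob of the same samples
with torch.no_grad():                # weights carry no gradient, as above
    log_ratio = (log_p - log_q).view(num_is, batch)
    weight = F.softmax(log_ratio, dim=0).view(-1)
loss = -(weight * log_p).mean() * num_is
print('toy weighted sampler loss: %.4f' % loss.item())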