Code Example #1
def do_plot(flow, epoch_idx):
    flow.train(False)

    # When using multiple GPUs, each GPU samples batch_size / device_count
    sample, _ = flow.sample(args.batch_size // args.device_count)
    my_log('plot min {:.3g} max {:.3g} mean {:.3g} std {:.3g}'.format(
        sample.min().item(),
        sample.max().item(),
        sample.mean().item(),
        sample.std().item(),
    ))
    sample, _ = utils.logit_transform(sample, inverse=True)
    sample = torch.clamp(sample, 0, 1)
    sample = sample.permute(0, 2, 3, 1).detach().cpu().numpy()

    fig, axes = plot_samples_np(sample)

    fig.suptitle('{}/{}/epoch{}'.format(args.data, args.net_name, epoch_idx))
    my_tight_layout(fig)
    plot_filename = '{}/epoch{}.pdf'.format(args.plot_filename, epoch_idx)
    utils.ensure_dir(plot_filename)
    fig.savefig(plot_filename, bbox_inches='tight')
    fig.clf()
    plt.close()

    flow.train(True)
Code Example #2
def main():
    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()

    model = RNN(args.device,
                Number_qubits=args.N,
                charset_length=args.charset_length,
                hidden_size=args.hidden_size,
                num_layers=args.num_layers)

    model.train(False)
    print('number of qubits: ', model.Number_qubits)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)
    params = [x for x in model.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, model, optimizer)

    # Quantum state
    ghz = GHZ(Number_qubits=args.N)
    c_fidelity = classical_fidelity(model, ghz)
    # c_fidelity = cfid(model, ghz, './data.txt')
    print(c_fidelity)
Code Example #3
File: check.py  Project: lyutyuh/dlpkuhole2
def check_file(filename):
    post_list = list(reversed(read_posts(filename)))
    if not post_list:
        return None, None

    last_pid = None
    for post in post_list:
        pid = post['pid']
        time_str = datetime.fromtimestamp(
            post['timestamp']).strftime('%Y-%m-%d %H:%M:%S')
        if last_pid and pid > last_pid + 1 and pid < last_pid + max_missed_pid:
            for i in range(last_pid + 1, pid):
                my_log('{} {} REALLY MISSED'.format(i, time_str))
        last_pid = pid
        first_line = post['text'].splitlines()[0]
        if first_line == '#DELETED':
            my_log('{} {} DELETED'.format(pid, time_str))
        if first_line == '#MISSED':
            my_log('{} {} MISSED'.format(pid, time_str))
        if default_reply is not False and (
                post['reply'] != default_reply
                and post['reply'] != len(post['comments'])):
            my_log('{} {} REPLY NOT MATCH {} {}'.format(
                pid, time_str, post['reply'], len(post['comments'])))

    oldest_pid = post_list[0]['pid']
    newest_pid = post_list[-1]['pid']
    return oldest_pid, newest_pid
Code Example #4
def validation(load_path, result_dir, generate_num, seed):
    # make results dir
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    files = os.listdir(result_dir)
    results_number = len(files)

    #set dir name
    desc = 'DCGAN'
    desc += '_generate'
    save_dir = os.path.join(result_dir, '%04d_' % (results_number) + desc)
    os.mkdir(save_dir)

    # set my logger
    log = utils.my_log(os.path.join(save_dir, 'results.txt'))
    log.logging(utils.SPLIT_LINE)
    log.logging('< Validation >')

    # set seed
    np.random.seed(seed)  #default 22222

    # validation set
    z_dim = 100

    # input placeholder
    g_in = tf.keras.layers.Input(shape=(z_dim), name='G_input')

    # set model
    GAN = tf.keras.models.load_model(load_path)
    G = GAN.layers[1]
    log.logging('[' + load_path + '] model loaded !!')

    for idx in range(generate_num):
        z = np.random.normal(size=[1, z_dim])
        fake_img = G.predict(z)[0]

        # pixel range : -1 ~ 1 -> 0 ~ 255
        fake_img = ((fake_img + 1) * 127.5).astype(np.uint8)
        log.logging('[%d/%d] Generate image!! [seed:%d]' %
                    (idx, generate_num, seed))
        cv2.imwrite(
            os.path.join(save_dir, 'fake%04d_seed%d.png' % (idx, seed)),
            fake_img)
Code Example #5
    post_list1 = read_posts(filename.replace(input_dir2, input_dir1))
    post_list2 = read_posts(filename)
    out_list = []
    i = 0
    j = 0
    while i < len(post_list1) and j < len(post_list2):
        if post_list1[i]['pid'] > post_list2[j]['pid']:
            out_list.append(post_list1[i])
            i += 1
        elif post_list1[i]['pid'] < post_list2[j]['pid']:
            out_list.append(post_list2[j])
            j += 1
        else:
            if cmp(post_list1[i], post_list2[j]):
                out_list.append(post_list1[i])
            else:
                out_list.append(post_list2[j])
            i += 1
            j += 1
    out_list += post_list1[i:]
    out_list += post_list2[j:]
    write_posts(filename.replace(input_dir2, output_dir), out_list)


if __name__ == '__main__':
    for root, dirs, files in os.walk(input_dir2):
        for file in sorted(files):
            filename = os.path.join(root, file)
            my_log(filename)
            merge_file(filename)
Code Example #6
def main():
    start_time = time.time()

    init_out_dir()
    print_args()
    args.n = 60

    if args.ham == 'hop':
        ham = HopfieldModel(args.n, args.beta, args.device, seed=args.seed)
    elif args.ham == 'sk':
        ham = SKModel(args.n, args.beta, args.device, seed=args.seed)
    else:
        raise ValueError('Unknown ham: {}'.format(args.ham))
    ham.J.requires_grad = False

    net = MADE(**vars(args))
    net.to(args.device)
    my_log('{}\n'.format(net))

    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    if args.beta_anneal_to < args.beta:
        args.beta_anneal_to = args.beta
    beta = args.beta
    while beta <= args.beta_anneal_to:
        for step in range(args.max_step):
            optimizer.zero_grad()

            sample_start_time = time.time()
            with torch.no_grad():
                sample, x_hat = net.sample(args.batch_size)
            assert not sample.requires_grad
            assert not x_hat.requires_grad
            sample_time += time.time() - sample_start_time

            train_start_time = time.time()

            log_prob = net.log_prob(sample)
            with torch.no_grad():
                energy = ham.energy(sample)
                loss = log_prob + beta * energy
            assert not energy.requires_grad
            assert not loss.requires_grad
            loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            loss_reinforce.backward()

            if args.clip_grad > 0:
                # Manual equivalent of nn.utils.clip_grad_norm_(params, args.clip_grad):
                # accumulate the total gradient norm first, then rescale once.
                parameters = list(filter(lambda p: p.grad is not None, params))
                max_norm = float(args.clip_grad)
                norm_type = 2
                total_norm = 0
                for p in parameters:
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
                total_norm = total_norm**(1 / norm_type)
                clip_coef = max_norm / (total_norm + args.epsilon)
                if clip_coef < 1:
                    for p in parameters:
                        p.grad.data.mul_(clip_coef)

            optimizer.step()

            train_time += time.time() - train_start_time

            if args.print_step and step % args.print_step == 0:
                free_energy_mean = loss.mean() / beta / args.n
                free_energy_std = loss.std() / beta / args.n
                entropy_mean = -log_prob.mean() / args.n
                energy_mean = energy.mean() / args.n
                mag = sample.mean(dim=0)
                mag_mean = mag.mean()
                if step > 0:
                    sample_time /= args.print_step
                    train_time /= args.print_step
                used_time = time.time() - start_time
                my_log(
                    'beta = {:.3g}, # {}, F = {:.8g}, F_std = {:.8g}, S = {:.5g}, E = {:.5g}, M = {:.5g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                    .format(
                        beta,
                        step,
                        free_energy_mean.item(),
                        free_energy_std.item(),
                        entropy_mean.item(),
                        energy_mean.item(),
                        mag_mean.item(),
                        sample_time,
                        train_time,
                        used_time,
                    ))
                sample_time = 0
                train_time = 0

        with open(args.fname, 'a', newline='\n') as f:
            f.write('{} {} {:.3g} {:.8g} {:.8g} {:.8g} {:.8g}\n'.format(
                args.n,
                args.seed,
                beta,
                free_energy_mean.item(),
                free_energy_std.item(),
                energy_mean.item(),
                entropy_mean.item(),
            ))

        if args.ham == 'hop':
            ensure_dir(args.out_filename + '_sample/')
            np.savetxt('{}_sample/sample{:.2f}.txt'.format(
                args.out_filename, beta),
                       sample.cpu().numpy(),
                       delimiter=' ',
                       fmt='%d')
            np.savetxt('{}_sample/log_prob{:.2f}.txt'.format(
                args.out_filename, beta),
                       log_prob.cpu().detach().numpy(),
                       delimiter=' ',
                       fmt='%.5f')

        beta += args.beta_inc
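Note on the update rule above: loss_reinforce = torch.mean((loss - loss.mean()) * log_prob) is a REINFORCE (score-function) gradient estimator for the expectation of the variational loss, with the batch mean of the loss used as a baseline. A minimal self-contained sketch of the same pattern follows; the factorized Bernoulli "network" and toy energy are illustrative assumptions, not part of the original code.

import torch

# Toy factorized "network": independent Bernoulli spins parametrized by logits.
logits = torch.zeros(4, requires_grad=True)
dist = torch.distributions.Bernoulli(logits=logits)

sample = dist.sample((128,))                 # sampling happens outside the autograd graph
log_prob = dist.log_prob(sample).sum(dim=1)  # log q(s) for each sampled configuration
energy = sample.sum(dim=1)                   # toy energy function (assumption)
beta = 1.0

with torch.no_grad():
    loss = log_prob + beta * energy          # per-sample loss, detached as in the example

# Subtracting the batch mean acts as a baseline: it leaves the expected
# gradient unchanged while reducing its variance.
loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
loss_reinforce.backward()                    # gradients accumulate in logits.grad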
Code Example #7
def main():
    start_time = time()
    last_step = get_last_ckpt_step()
    assert last_step >= 0
    my_log(f'Checkpoint found: {last_step}\n')
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    params = load_ckpt(last_step)
    in_shape = (args.batch_size, args.L, args.L, 1)
    _, cache_init = net_init_cache(params, jnp.zeros(in_shape), (-1, -1))

    # sample_fun = get_sample_fun(net_apply, None)
    sample_fun = get_sample_fun(net_apply_fast, cache_init)
    log_q_fun = get_log_q_fun(net_apply)

    def sample_energy_fun(rng):
        spins = sample_fun(args.batch_size, params, rng)
        log_q = log_q_fun(params, spins)
        energy = energy_fun(spins)
        return spins, log_q, energy

    @jit
    def update(spins_old, log_q_old, energy_old, step, energy_mean,
               energy_var_sum, rng):
        rng, rng_sample = jrand.split(rng)
        spins, log_q, energy = sample_energy_fun(rng_sample)
        mag = spins.mean(axis=(1, 2, 3))

        step += 1
        energy_per_spin = energy / args.L**2
        energy_mean, energy_var_sum = welford_update(energy_per_spin.mean(),
                                                     step, energy_mean,
                                                     energy_var_sum)

        return (spins, log_q, energy, mag, step, energy_mean, energy_var_sum,
                rng)

    rng, rng_init = jrand.split(jrand.PRNGKey(args.seed))
    spins, log_q, energy = sample_energy_fun(rng_init)

    step = 0
    energy_mean = 0
    energy_var_sum = 0

    data_filename = args.log_filename.replace('.log', '.hdf5')
    writer_proto = [
        # Uncomment to save all the sampled spins
        # ('spins', bool, (args.L, args.L)),
        ('log_q', np.float32, None),
        ('energy', np.int32, None),
        ('mag', np.float32, None),
    ]
    ensure_dir(data_filename)
    with ChunkedDataWriter(data_filename, writer_proto,
                           args.save_step * args.batch_size) as writer:
        my_log('Sampling...')
        while step < args.max_step:
            (spins, log_q, energy, mag, step, energy_mean, energy_var_sum,
             rng) = update(spins, log_q, energy, step, energy_mean,
                           energy_var_sum, rng)
            # Uncomment to save all the sampled spins
            # writer.write_batch(spins[:, :, :, 0] > 0, log_q, energy, mag)
            writer.write_batch(log_q, energy, mag)

            if args.print_step and step % args.print_step == 0:
                energy_std = jnp.sqrt(energy_var_sum / step)
                my_log(', '.join([
                    f'step = {step}',
                    f'E = {energy_mean:.8g}',
                    f'E_std = {energy_std:.8g}',
                    f'time = {time() - start_time:.3f}',
                ]))
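The running energy statistics above are maintained with a welford_update helper defined elsewhere in the project. A minimal sketch of its conventional form, an assumption consistent with how energy_mean, energy_var_sum and energy_std = sqrt(energy_var_sum / step) are used:

def welford_update(x, step, mean, var_sum):
    # One step of Welford's online algorithm for a running mean and variance.
    # `step` is the 1-based count of values seen so far, including `x`.
    delta = x - mean
    mean = mean + delta / step
    var_sum = var_sum + delta * (x - mean)
    return mean, var_sum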
Code Example #8
#!/usr/bin/python3

import os

from check import check_file, max_missed_pid
from utils import my_log

cdname = os.path.dirname(__file__)
archive_dir = os.path.join(cdname, 'archive', '201810')

if __name__ == '__main__':
    last_pid = None
    for root, dirs, files in os.walk(archive_dir):
        for file in sorted(files):
            filename = os.path.join(root, file)
            oldest_pid, newest_pid = check_file(filename)

            # Check missed posts between files
            if oldest_pid:
                if (last_pid and oldest_pid > last_pid + 1
                        and oldest_pid < last_pid + max_missed_pid):
                    for i in range(last_pid + 1, oldest_pid):
                        my_log('{} {} MISSED BETWEEN FILES'.format(i, file))
                last_pid = oldest_pid
Code Example #9
def main():
    start_time = time()
    init_out_dir()
    last_step = get_last_ckpt_step()
    if last_step >= 0:
        my_log(f'\nCheckpoint found: {last_step}\n')
    else:
        clear_log()
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    rng, rng_net = jrand.split(jrand.PRNGKey(args.seed))
    in_shape = (args.batch_size, args.L, args.L, 1)
    out_shape, params_init = net_init(rng_net, in_shape)

    _, cache_init = net_init_cache(params_init, jnp.zeros(in_shape), (-1, -1))

    # sample_fun = get_sample_fun(net_apply, None)
    sample_fun = get_sample_fun(net_apply_fast, cache_init)
    log_q_fun = get_log_q_fun(net_apply)

    need_beta_anneal = args.beta_anneal_step > 0

    opt_init, opt_update, get_params = optimizers.adam(args.lr)

    @jit
    def update(step, opt_state, rng):
        params = get_params(opt_state)
        rng, rng_sample = jrand.split(rng)
        spins = sample_fun(args.batch_size, params, rng_sample)
        log_q = log_q_fun(params, spins) / args.L**2
        energy = energy_fun(spins) / args.L**2

        def neg_log_Z_fun(params, spins):
            log_q = log_q_fun(params, spins) / args.L**2
            energy = energy_fun(spins) / args.L**2
            beta = args.beta
            if need_beta_anneal:
                beta *= jnp.minimum(step / args.beta_anneal_step, 1)
            neg_log_Z = log_q + beta * energy
            return neg_log_Z

        loss_fun = partial(expect,
                           log_q_fun,
                           neg_log_Z_fun,
                           mean_grad_expected_is_zero=True)
        grads = grad(loss_fun)(params, spins, spins)
        opt_state = opt_update(step, grads, opt_state)

        return spins, log_q, energy, opt_state, rng

    if last_step >= 0:
        params_init = load_ckpt(last_step)

    opt_state = opt_init(params_init)

    my_log('Training...')
    for step in range(last_step + 1, args.max_step + 1):
        spins, log_q, energy, opt_state, rng = update(step, opt_state, rng)

        if args.print_step and step % args.print_step == 0:
            # Use the final beta, not the annealed beta
            free_energy = log_q / args.beta + energy
            my_log(', '.join([
                f'step = {step}',
                f'F = {free_energy.mean():.8g}',
                f'F_std = {free_energy.std():.8g}',
                f'S = {-log_q.mean():.8g}',
                f'E = {energy.mean():.8g}',
                f'time = {time() - start_time:.3f}',
            ]))

        if args.save_step and step % args.save_step == 0:
            params = get_params(opt_state)
            save_ckpt(params, step)
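For reference, the quantity minimized through neg_log_Z_fun above is (up to the per-spin 1/L^2 normalization and the annealed beta) the variational free energy of the model distribution q, which upper-bounds the true free energy F:

    F_q = \frac{1}{\beta}\,\mathbb{E}_{s \sim q}\bigl[\ln q(s) + \beta E(s)\bigr]
        = \mathbb{E}_{s \sim q}\!\left[\frac{\ln q(s)}{\beta} + E(s)\right] \ge F .

The per-sample estimate free_energy = log_q / args.beta + energy logged in the print step is exactly the bracketed quantity, evaluated with the final (non-annealed) beta.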
Code Example #10
cdname = os.path.dirname(__file__)
filename = os.path.join(cdname, 'pkuhole.txt')
archive_dir = os.path.join(cdname, 'archive')
archive_basename = 'pkuhole'
archive_extname = '.txt'

day_count = 3

if __name__ == '__main__':
    out_date = date.today() - timedelta(day_count)
    archive_filename = os.path.join(
        archive_dir, out_date.strftime('%Y%m'),
        archive_basename + out_date.strftime('%Y%m%d') + archive_extname)
    if os.path.exists(archive_filename):
        my_log('Archive file exists')
        exit()

    my_log('Archive {}'.format(archive_filename))
    try:
        max_timestamp = int(
            datetime.combine(out_date + timedelta(1),
                             datetime.min.time()).timestamp())
        write_posts(
            archive_filename,
            map(
                get_comment,
                filter(lambda post: post['timestamp'] < max_timestamp,
                       read_posts(filename))))
    except Exception as e:
        my_log('Error: {}'.format(e))
Code Example #11
def main():
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    model = RNN(args.device,
                Number_qubits=args.N,
                charset_length=args.charset_length,
                hidden_size=args.hidden_size,
                num_layers=args.num_layers)
    data = prepare_data(args.N, './data.txt')
    ghz = GHZ(Number_qubits=args.N)

    model.train(True)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    params = [x for x in model.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, model, optimizer)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    best_fid = 0
    trigger = 0  # once current fid is less than best fid, trigger+=1
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx in range(int(args.Ns / args.batch_size)):
            optimizer.zero_grad()
            # idx = np.random.randint(low=0,high=int(args.Ns-1),size=(args.batch_size,))
            idx = np.arange(args.batch_size) + batch_idx * args.batch_size
            train_data = data[idx]
            loss = -model.log_prob(
                torch.from_numpy(train_data).to(args.device)).mean()
            loss.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()
        print('epoch_idx {} current loss {:.8g}'.format(
            epoch_idx, loss.item()))
        print('Evaluating...')
        # Evaluation
        current_fid = classical_fidelity(model, ghz, print_prob=False)
        if current_fid > best_fid:
            trigger = 0  # reset
            my_log('epoch_idx {} loss {:.8g} fid {} time {:.3f}'.format(
                epoch_idx, loss.item(), current_fid,
                time.time() - start_time))
            best_fid = current_fid
            if (args.out_filename and args.save_epoch
                    and epoch_idx % args.save_epoch == 0):
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(
                    state, '{}_save/{}.state'.format(args.out_filename,
                                                     epoch_idx))
        else:
            trigger = trigger + 1
        if trigger > 4:
            break
Code Example #12
cdname = os.path.dirname(__file__)
filename = os.path.join(cdname, 'pkuhole.txt')
filename_bak = os.path.join(cdname, 'pkuholebak.txt')


def internet_on():
    try:
        requests.get('https://www.baidu.com', timeout=5)
    except ConnectionError:
        return False
    return True


if __name__ == '__main__':
    if not internet_on():
        my_log('No internet')
        exit()

    if os.path.exists(os.path.join(cdname, 'update.flag')):
        my_log('Update is already running')
        exit()

    with open(os.path.join(cdname, 'update.flag'), 'w', encoding='utf-8') as g:
        g.write(str(os.getpid()))

    my_log('Begin read posts')
    post_dict = read_posts_dict(filename)
    my_log('End read posts')

    my_log('Begin write bak')
    write_posts(filename_bak, post_dict_to_list(post_dict))
Code Example #13
def main():
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    flow = build_mera()
    flow.train(True)
    my_log('nparams in each RG layer: {}'.format(
        [utils.get_nparams(layer) for layer in flow.layers]))
    my_log('Total nparams: {}'.format(utils.get_nparams(flow)))

    # Use multiple GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        flow = utils.data_parallel_wrap(flow)

    params = [x for x in flow.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, flow, optimizer)

    train_split, val_split, data_info = utils.load_dataset()
    train_loader = torch.utils.data.DataLoader(train_split,
                                               args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx, (x, _) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(args.device)
            x, ldj_logit = utils.logit_transform(x)
            log_prob = flow.log_prob(x)
            loss = -(log_prob + ldj_logit) / (args.nchannels * args.L**2)
            loss_mean = loss.mean()
            loss_std = loss.std()

            utils.check_nan(loss_mean)

            loss_mean.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

            if args.print_step and batch_idx % args.print_step == 0:
                bit_per_dim = (loss_mean.item() + log(256)) / log(2)
                my_log(
                    'epoch {} batch {} bpp {:.8g} loss {:.8g} +- {:.8g} time {:.3f}'
                    .format(
                        epoch_idx,
                        batch_idx,
                        bit_per_dim,
                        loss_mean.item(),
                        loss_std.item(),
                        time.time() - start_time,
                    ))

        if (args.out_filename and args.save_epoch
                and epoch_idx % args.save_epoch == 0):
            state = {
                'flow': flow.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, epoch_idx))

            if epoch_idx > 0 and (epoch_idx - 1) % args.keep_epoch != 0:
                os.remove('{}_save/{}.state'.format(args.out_filename,
                                                    epoch_idx - 1))

        if (args.plot_filename and args.plot_epoch
                and epoch_idx % args.plot_epoch == 0):
            with torch.no_grad():
                do_plot(flow, epoch_idx)
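In the training log above, loss_mean is the negative log-likelihood per dimension in nats for inputs scaled to [0, 1], and bit_per_dim converts it to the usual bits-per-dimension figure on the 8-bit pixel scale. A self-contained restatement of that conversion (the helper name is hypothetical):

from math import log

def bits_per_dim(nll_per_dim_nats):
    # log(256) accounts for rescaling 8-bit pixel values from [0, 256) to [0, 1];
    # dividing by log(2) converts nats to bits.
    return (nll_per_dim_nats + log(256)) / log(2)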
Code Example #14
def main():
    start_time = time()
    last_step = get_last_ckpt_step()
    assert last_step >= 0
    my_log(f'Checkpoint found: {last_step}\n')
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    params = load_ckpt(last_step)
    in_shape = (args.batch_size, args.L, args.L, 1)
    _, cache_init = net_init_cache(params, jnp.zeros(in_shape), (-1, -1))

    # sample_raw_fun = get_sample_fun(net_apply, None)
    sample_raw_fun = get_sample_fun(net_apply_fast, cache_init)
    # sample_k_fun = get_sample_k_fun(net_apply, None)
    sample_k_fun = get_sample_k_fun(net_apply_fast, net_init_cache)
    log_q_fun = get_log_q_fun(net_apply)

    @jit
    def update(spins_old, log_q_old, energy_old, step, accept_count,
               energy_mean, energy_var_sum, rng):
        rng, rng_k, rng_sample, rng_accept = jrand.split(rng, 4)
        k = get_k(rng_k)
        spins = sample_k_fun(k, params, spins_old, rng_sample)
        log_q = log_q_fun(params, spins)
        energy = energy_fun(spins)

        log_uniform = jnp.log(jrand.uniform(rng_accept, (args.batch_size, )))
        accept = log_uniform < (log_q_old - log_q + args.beta *
                                (energy_old - energy))

        spins = jnp.where(jnp.expand_dims(accept, axis=(1, 2, 3)), spins,
                          spins_old)
        log_q = jnp.where(accept, log_q, log_q_old)
        energy = jnp.where(accept, energy, energy_old)
        mag = spins.mean(axis=(1, 2, 3))

        step += 1
        accept_count += accept.sum()
        energy_per_spin = energy / args.L**2
        energy_mean, energy_var_sum = welford_update(energy_per_spin.mean(),
                                                     step, energy_mean,
                                                     energy_var_sum)

        return (spins, log_q, energy, mag, accept, k, step, accept_count,
                energy_mean, energy_var_sum, rng)

    rng, rng_init = jrand.split(jrand.PRNGKey(args.seed))
    # Sample initial configurations from the network
    spins = sample_raw_fun(args.batch_size, params, rng_init)
    log_q = log_q_fun(params, spins)
    energy = energy_fun(spins)

    step = 0
    accept_count = 0
    energy_mean = 0
    energy_var_sum = 0

    data_filename = args.log_filename.replace('.log', '.hdf5')
    writer_proto = [
        # Uncomment to save all the sampled spins
        # ('spins', bool, (args.batch_size, args.L, args.L)),
        ('log_q', np.float32, (args.batch_size, )),
        ('energy', np.int32, (args.batch_size, )),
        ('mag', np.float32, (args.batch_size, )),
        ('accept', bool, (args.batch_size, )),
        ('k', np.int32, None),
    ]
    ensure_dir(data_filename)
    with ChunkedDataWriter(data_filename, writer_proto,
                           args.save_step) as writer:
        my_log('Sampling...')
        while step < args.max_step:
            (spins, log_q, energy, mag, accept, k, step, accept_count,
             energy_mean, energy_var_sum,
             rng) = update(spins, log_q, energy, step, accept_count,
                           energy_mean, energy_var_sum, rng)
            # Uncomment to save all the sampled spins
            # writer.write(spins[:, :, :, 0] > 0, log_q, energy, mag, accept, k)
            writer.write(log_q, energy, mag, accept, k)

            if args.print_step and step % args.print_step == 0:
                accept_rate = accept_count / (step * args.batch_size)
                energy_std = jnp.sqrt(energy_var_sum / step)
                my_log(', '.join([
                    f'step = {step}',
                    f'P = {accept_rate:.8g}',
                    f'E = {energy_mean:.8g}',
                    f'E_std = {energy_std:.8g}',
                    f'time = {time() - start_time:.3f}',
                ]))
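The accept test above, log_uniform < (log_q_old - log_q + args.beta * (energy_old - energy)), is a Metropolis-Hastings acceptance step with the network distribution q as the proposal and the Boltzmann distribution p(s) ∝ exp(-beta E(s)) as the target. Written out, a proposed configuration is accepted with probability

    A(s_{\mathrm{old}} \to s_{\mathrm{new}}) = \min\!\left(1,\;
        \frac{q(s_{\mathrm{old}})}{q(s_{\mathrm{new}})}\,
        e^{-\beta\,[E(s_{\mathrm{new}}) - E(s_{\mathrm{old}})]}\right).

Because q is autoregressive and sample_k_fun resamples only the last k spins while the prefix stays fixed, the ratio of full probabilities q(s_old)/q(s_new) equals the ratio of proposal probabilities for the resampled block, so the chain targets the Boltzmann distribution.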
Code Example #15
def train(dataset_dir, result_dir, batch_size, epochs, lr, load_path,
          save_term):
    # make results dir
    if not os.path.exists(result_dir):
        os.mkdir(result_dir)

    files = os.listdir(result_dir)
    results_number = len(files)

    #set dir name
    desc = 'DCGAN'
    desc += '_batch-%d' % batch_size
    desc += '_epoch-%d' % epochs

    save_dir = os.path.join(result_dir, '%04d_' % (results_number) + desc)
    os.mkdir(save_dir)

    #ckpt dir
    ckpt_dir = os.path.join(save_dir, 'ckpt')
    os.mkdir(ckpt_dir)

    # set my logger
    log = utils.my_log(os.path.join(save_dir, 'results.txt'))
    log.logging('< Info >')

    # load data
    imgs = utils.load_images_in_folder(dataset_dir)
    imgs_num = imgs.shape[0]
    log.logging('dataset path : ' + dataset_dir)
    log.logging('results path : ' + save_dir)
    log.logging('load model path : ' + str(load_path))
    log.logging('load images num : %d' % (imgs_num))
    log.logging('image shape : (%d, %d, %d)' %
                (imgs.shape[1], imgs.shape[2], imgs.shape[3]))

    # images preprocessing [-1 , 1]
    imgs = (imgs - 127.5) / 127.5

    ### train setting
    np.random.seed(2222)
    z_dim = 100
    beta_1 = 0.5
    log.logging('z dim : %d' % z_dim)

    # input placeholder
    g_in = tf.keras.layers.Input(shape=(z_dim), name='G_input')
    d_in = tf.keras.layers.Input(shape=(imgs.shape[1], imgs.shape[2],
                                        imgs.shape[3]),
                                 name='D_input')

    y_true = tf.keras.layers.Input(shape=(1), name='y_true')

    # set model

    G = tf.keras.Model(g_in, model.Generator(g_in), name='Generator')
    D = tf.keras.Model(d_in, model.Discriminator(d_in), name='Discriminator')
    GAN = tf.keras.Model(g_in, D(G(g_in)), name='DCGAN')

    #G.summary()
    #D.summary()
    #GAN.summary()

    # set optimizer
    G_opt = tf.keras.optimizers.Adam(learning_rate=lr,
                                     beta_1=beta_1,
                                     name='Adam_G')
    D_opt = tf.keras.optimizers.Adam(learning_rate=lr,
                                     beta_1=beta_1,
                                     name='Adam_D')

    log.logging('G opt : Adam(lr:%f, beta_1:%f)' % (lr, beta_1))
    log.logging('D opt : Adam(lr:%f, beta_1:%f)' % (lr, beta_1))
    log.logging('total epoch : %d' % epochs)
    log.logging('batch size  : %d' % batch_size)

    log.logging(utils.SPLIT_LINE, log_only=True)

    # train
    log.logging('< train >')

    train_start_time = time.time()

    # test noise z
    z_test = np.random.normal(size=(batch_size, z_dim))

    # load model
    if load_path is not None:
        GAN.load_weights(load_path)
        log.logging('[' + load_path + '] model loaded !!')

    for epoch in range(1, epochs + 1):
        start_epoch_time = time.time()
        epoch_g_loss = 0
        epoch_d_loss = 0
        # remaining data is not used
        for step in range(imgs_num // batch_size):

            # training D
            with tf.GradientTape() as tape_fake, tf.GradientTape() as tape_real:
                z = np.random.normal(size=[batch_size, z_dim])
                real_imgs = imgs[step * batch_size:(step + 1) * batch_size]
                fake_imgs = G(z, training=True)

                fake_logits = D(fake_imgs, training=True)
                real_logits = D(real_imgs, training=True)

                d_loss_fake = loss.loss(np.zeros((batch_size, 1)), fake_logits)
                d_loss_real = loss.loss(np.ones((batch_size, 1)), real_logits)

            d_gradient_fake = tape_fake.gradient(d_loss_fake,
                                                 D.trainable_variables)
            d_gradient_real = tape_real.gradient(d_loss_real,
                                                 D.trainable_variables)
            D_opt.apply_gradients(zip(d_gradient_fake, D.trainable_variables))
            D_opt.apply_gradients(zip(d_gradient_real, D.trainable_variables))

            # training G
            with tf.GradientTape() as tape:
                z = np.random.normal(size=[batch_size, z_dim])
                fake_imgs = G(z, training=True)
                g_logits = D(fake_imgs, training=True)
                g_loss = loss.loss(np.ones((batch_size, 1)), g_logits)

            g_gradient = tape.gradient(g_loss, G.trainable_variables)
            G_opt.apply_gradients(zip(g_gradient, G.trainable_variables))

            # calculate loss
            loss_g = g_loss

            loss_d = d_loss_real + d_loss_fake

            epoch_g_loss += loss_g
            epoch_d_loss += loss_d

            print('%03d / %03d loss (G : %f || D : %f ) detail:(f:%f r:%f)' %
                  (step, imgs_num // batch_size, loss_g, loss_d, d_loss_fake,
                   d_loss_real))

        epoch_g_loss /= (imgs_num // batch_size)
        epoch_d_loss /= (imgs_num // batch_size)

        log.logging(
            '[%d/%d] epoch << G loss: %.5f || D loss: %.5f >>  time: %.1f sec'
            % (epoch, epochs, epoch_g_loss, epoch_d_loss,
               time.time() - start_epoch_time))

        # make fake imgs
        fake_imgs = get_grid_imgs(G, batch_size, z_dim, z_test)
        cv2.imwrite(os.path.join(save_dir, 'fake %05depoch.png' % epoch),
                    fake_imgs)

        # model save
        if epoch % save_term == 0:
            save_name = os.path.join(ckpt_dir, 'model-%d.h5' % (epoch))
            GAN.save(save_name)

    log.logging(
        '\n[%d] epochs finished! << final G loss: %.5f || final D loss: %.5f >>  total time: %.1f sec'
        % (epochs, epoch_g_loss, epoch_d_loss, time.time() - train_start_time))

    save_name = os.path.join(save_dir, 'model-%d.h5' % (epoch))
    GAN.save(save_name)
    print('train finished!')
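Assuming loss.loss is a binary cross-entropy on the discriminator outputs (the loss module is project-specific and not shown here), the discriminator and generator updates above implement the standard non-saturating GAN objective:

    L_D = -\mathbb{E}_{x}\bigl[\log D(x)\bigr] - \mathbb{E}_{z}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr],
    \qquad
    L_G = -\mathbb{E}_{z}\bigl[\log D(G(z))\bigr].

Computing d_loss_fake and d_loss_real under separate GradientTapes and calling apply_gradients twice performs two optimizer updates per step; summing the two terms under a single tape and applying one update would be the more common variant.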
Code Example #16
File: splitall.py  Project: lyutyuh/dlpkuhole2
import os
from datetime import date

from utils import my_log, post_dict_to_list, read_posts, write_posts

cdname = os.path.dirname(__file__)
filename = os.path.join(cdname, 'pkuhole.txt')
archive_dir = os.path.join(cdname, 'archivebak')
archive_basename = 'pkuhole'
archive_extname = '.txt'

if __name__ == '__main__':
    post_list = read_posts(filename)
    last_date = date.fromtimestamp(post_list[0]['timestamp'])
    now_post_dict = {}
    for post in post_list:
        now_date = date.fromtimestamp(post['timestamp'])
        if now_date < last_date:
            archive_filename = os.path.join(
                archive_dir, last_date.strftime('%Y%m'), archive_basename +
                last_date.strftime('%Y%m%d') + archive_extname)
            write_posts(archive_filename, post_dict_to_list(now_post_dict))
            last_date = now_date
            now_post_dict = {}
            my_log(last_date.strftime('%Y%m%d'))
        now_post_dict[post['pid']] = post
    archive_filename = os.path.join(
        archive_dir, last_date.strftime('%Y%m'),
        archive_basename + last_date.strftime('%Y%m%d') + archive_extname)
    write_posts(archive_filename, post_dict_to_list(now_post_dict))
Code Example #17
File: main_xy.py  Project: Hikikomori-1117/CAN2XY
def main():

    start_time = time.time()
    # initialize output dir
    init_out_dir()
    # check point
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'pixelcnn_xy':
        net = PixelCNN(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # parameters of networks
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad,
                         params))  # parameters with gradients
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())
    # optimizers
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    # learning rates
    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               factor=0.92,
                                                               patience=100,
                                                               threshold=1e-4,
                                                               min_lr=1e-6)
    # read last step
    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    # start training
    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()  # clear last step

        sample_start_time = time.time()
        with torch.no_grad():
            # Sample from the network; batch_size defaults to 10**3
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        # log probabilities
        log_prob = net.log_prob(sample, args.batch_size)

        # 0.998**9000 ~ 1e-8
        # Annealing schedule to avoid mode collapse
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy, vortices = xy.energy(sample, args.ham, args.lattice,
                                         args.boundary)
            loss = log_prob + beta * energy  # construct the loss function (free energy) from the sampled configurations

        assert not energy.requires_grad
        assert not loss.requires_grad
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()  # back propagation

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        # export physical observables
        if args.print_step and step % args.print_step == 0:
            free_energy_mean = loss.mean() / beta / (args.L**2
                                                     )  # free energy density
            free_energy_std = loss.std() / beta / (args.L**2)
            entropy_mean = -log_prob.mean() / (args.L**2)  # entropy density
            energy_mean = (energy / (args.L**2)).mean()  # energy density
            energy_std = (energy / (args.L**2)).std()
            vortices = vortices.mean() / args.L**2  # vortices density

            # heat_capacity=(((energy/ (args.L**2))**2).mean()- ((energy/ (args.L**2)).mean())**2)  *(beta**2)

            # magnetization
            # mag = torch.cos(sample).sum(dim=(2,3)).mean(dim=0) # M_x (M_x,M_y)=(cos(theta), sin(theta))
            # mag_mean = mag.mean()
            # mag_sqr_mean = (mag**2).mean()
            # sus_mean = mag_sqr_mean/args.L**2

            # log
            if step > 0:
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            # hyperparameters in training
            my_log(
                'step = {}, lr = {:.3g}, loss={:.8g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    optimizer.param_groups[0]['lr'],
                    loss.mean(),
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            # observables
            my_log(
                'F = {:.8g}, F_std = {:.8g}, E = {:.8g}, E_std={:.8g}, v={:.8g}'
                .format(
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    energy_mean.item(),
                    energy_std.item(),
                    vortices.item(),
                ))

            sample_time = 0
            train_time = 0
            # save sample
            if args.save_sample and step % args.save_step == 0:
                # save training state
                # state = {
                #     'sample': sample,
                #     'x_hat': x_hat,
                #     'log_prob': log_prob,
                #     'energy': energy,
                #     'loss': loss,
                # }
                # torch.save(state, '{}_save/{}.sample'.format(
                #     args.out_filename, step))

                # Recognize the Phase Transition
                # helicity
                with torch.no_grad():
                    correlations = helicity(sample)
                helicity_modulus = -((energy / args.L**2).mean()) - (
                    args.beta * correlations**2 / args.L**2).mean()
                my_log('Rho={:.8g}'.format(helicity_modulus.item()))

                # save configurations
                sample_array = sample.cpu().numpy()
                np.savetxt(
                    '{}_save/sample{}.txt'.format(args.out_filename, step),
                    sample_array.reshape(args.batch_size, -1))
                # save observables
                np.savetxt(
                    '{}_save/results{}.csv'.format(args.out_filename, step), [
                        beta,
                        step,
                        free_energy_mean,
                        free_energy_std,
                        energy_mean,
                        energy_std,
                        vortices,
                        helicity_modulus,
                    ])

        # save net
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, step))

        # visualization in each visual_step
        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            # torchvision.utils.save_image(
            #     sample,
            #     '{}_img/{}.png'.format(args.out_filename, step),
            #     nrow=int(sqrt(sample.shape[0])),
            #     padding=0,
            #     normalize=True)
            # print sample
            if args.print_sample:
                x_hat_alpha = x_hat[:, 0, :, :].view(x_hat.shape[0],
                                                     -1).cpu().numpy()  # alpha
                x_hat_std1 = np.std(x_hat_alpha, axis=0).reshape([args.L] * 2)
                x_hat_beta = x_hat[:, 1, :, :].view(x_hat.shape[0],
                                                    -1).cpu().numpy()  # beta
                x_hat_std2 = np.std(x_hat_beta, axis=0).reshape([args.L] * 2)

                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nalpha\n{}\nbeta\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nalpha_std\n{}\nbeta_std\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        x_hat[:args.print_sample, 1],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std1,
                        x_hat_std2,
                        energy_count,
                    ))
            # print gradient
            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
Code Example #18
File: main.py  Project: wdphy16/stat-mech-van
def main():
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            if step > 0:
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                x_hat_corr = x_hat_cov / (
                    sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
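The inverse temperature is annealed as beta = args.beta * (1 - args.beta_anneal**step), so training starts at beta = 0 (a flat target distribution) and approaches the target beta geometrically. A small sketch of the schedule, assuming the value beta_anneal = 0.998 from the comment above:

beta_target = 1.0
beta_anneal = 0.998
for step in (0, 100, 1000, 9000):
    beta = beta_target * (1 - beta_anneal**step)
    print(step, round(beta, 6))
# step 0 gives beta = 0; by step ~9000 the schedule has effectively converged
# to beta_target, since 0.998**9000 ~ 1e-8.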
Code Example #19
def main():
    start_time = time.time()

    init_out_dir()
    print_args()

    fsave_name = 'n{}b{:.2f}D{}.pickle'.format(args.n, args.beta, args.seed)
    with open(fsave_name, 'rb') as f:
        sk_true = pickle.load(f)

    sk_true.J = sk_true.J.to(args.device)
    assert args.n == sk_true.n
    assert args.beta == sk_true.beta

    sk = SKModel(args.n, args.beta, args.device, seed=args.seed)
    my_log('diff = {:.8g}'.format(sk_true.J_diff(sk.J)))

    C_data = sk_true.C_model.to(args.device, default_dtype_torch)
    C = C_data
    C_inv = torch.inverse(C)

    J_nmf = (torch.eye(args.n).to(C.device) - C_inv) / args.beta
    idx = range(args.n)
    J_nmf[idx, idx] = 0
    diff_nmf = sk_true.J_diff(J_nmf.to(args.device))
    my_log('diff_NMF = {:.8g}'.format(diff_nmf))

    J_IP = 1 / 2 * torch.log((1 + C) / (1 - C + args.epsilon)) / args.beta
    idx = range(args.n)
    J_IP[idx, idx] = 0
    diff_IP = sk_true.J_diff(J_IP.to(args.device))
    my_log('diff_IP = {:.8g}'.format(diff_IP))

    J_SM = -C_inv + J_IP * args.beta - C / (1 - C**2 + args.epsilon)
    J_SM /= args.beta
    idx = range(args.n)
    J_SM[idx, idx] = 0
    diff_SM = sk_true.J_diff(J_SM.to(args.device))
    my_log('diff_SM = {:.8g}'.format(diff_SM))

    sk.J.data[:] = 0

    net = MADE(**vars(args))
    net.to(args.device)
    my_log('{}\n'.format(net))

    params = list(net.parameters())
    params = [sk.J]
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    params2 = [sk.J]

    optimizer = torch.optim.Adam(params, lr=args.lr)
    optimizer2 = torch.optim.Adam(params2, lr=0.01)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    diff = 10000
    for step2 in range(5000):
        optimizer2.zero_grad()
        for step in range(args.max_step + 1):
            optimizer.zero_grad()

            sample_start_time = time.time()
            with torch.no_grad():
                sample, x_hat = net.sample(args.batch_size)
            assert not sample.requires_grad
            assert not x_hat.requires_grad
            sample_time += time.time() - sample_start_time

            train_start_time = time.time()

            log_prob = net.log_prob(sample)
            with torch.no_grad():
                energy = sk.energy(sample)
                loss = log_prob + args.beta * energy
            assert not energy.requires_grad
            assert not loss.requires_grad
            loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            loss_reinforce.backward()

            optimizer.step()

            train_time += time.time() - train_start_time

        with torch.no_grad():
            sample, x_hat = net.sample(10**6)
        assert not sample.requires_grad
        assert not x_hat.requires_grad

        m_model = torch.mean(sample, 0).view(sample.shape[1], 1)
        C_model = sample.t() @ sample / sample.shape[0] - m_model @ m_model.t()

        C_diff = C_data - C_model
        C_norm = C_diff.norm(2)

        sk.J.grad = -C_diff / (C_norm + args.epsilon)
        optimizer2.step()

        diff_old = diff
        diff = sk_true.J_diff(sk.J)
        my_log('# {}, diff_J = {}, diff_C = {}'.format(
            step2, diff, torch.sqrt(torch.mean(C_diff**2))))

        idx = range(args.n)
        sk.J.data[idx, idx] = 0

    with open(args.fname, 'a', newline='\n') as f:
        f.write('{} {} {:.3g} {:.8g}\n'.format(args.n, args.seed, args.beta,
                                               diff))
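For reference, the closed-form coupling estimates computed above from the connected correlation matrix C (diagonals zeroed, args.epsilon acting as a numerical regularizer) are, as written in the code,

    J^{\mathrm{nMF}}_{ij} = \frac{1}{\beta}\left(\delta_{ij} - (C^{-1})_{ij}\right),
    \qquad
    J^{\mathrm{IP}}_{ij} = \frac{1}{2\beta}\,\ln\frac{1 + C_{ij}}{1 - C_{ij}},
    \qquad
    J^{\mathrm{SM}}_{ij} = J^{\mathrm{IP}}_{ij}
        - \frac{1}{\beta}\left[(C^{-1})_{ij} + \frac{C_{ij}}{1 - C_{ij}^{2}}\right],

i.e. (inferring the names from the variables J_nmf, J_IP, J_SM) the naive mean-field, independent-pair, and Sessak-Monasson-style estimates of the couplings, used here as baselines for the diff_J reported during training.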