def main():
    """Load the RNN model (resuming from a checkpoint if any) and print its fidelity.

    Reads all configuration from the module-level ``args``. The process exits
    early when the latest checkpoint step is already at or beyond
    ``args.epoch``. The model is put in non-training mode, an AdamW optimizer
    is built only so that ``utils.load_checkpoint`` can restore into it, and
    the value returned by ``classical_fidelity(model, ghz)`` is printed.
    """
    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()

    resuming = last_epoch >= 0
    if resuming:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        # No checkpoint: start from a clean log.
        utils.clear_log()

    model = RNN(
        args.device,
        Number_qubits=args.N,
        charset_length=args.charset_length,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
    )
    model.train(False)
    print('number of qubits: ', model.Number_qubits)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    if resuming:
        utils.load_checkpoint(last_epoch, model, optimizer)

    # Reference quantum state the model is compared against.
    ghz = GHZ(Number_qubits=args.N)
    # A data-file based alternative, cfid(model, ghz, './data.txt'), was left
    # disabled by the original author.
    print(classical_fidelity(model, ghz))
# Example #2
# 0
def main():
    """Train an autoregressive network against the Ising energy function.

    All configuration is read from the module-level ``args``. The function

    1. prepares the output directory and resumes from the latest checkpoint
       when one exists;
    2. builds the network selected by ``args.net`` and the optimizer selected
       by ``args.optimizer`` (plus an optional ``ReduceLROnPlateau``
       scheduler);
    3. runs steps ``last_step + 1 .. args.max_step`` inclusive, each of which
       samples a batch from the network, evaluates
       ``loss = log_prob + beta * energy`` and back-propagates the
       REINFORCE-style surrogate ``mean((loss - loss.mean()) * log_prob)``;
    4. periodically logs observables, saves samples / checkpoints / images
       and optionally prints sample and gradient statistics.

    Raises:
        ValueError: if ``args.net`` or ``args.optimizer`` is not one of the
            names handled below.
    """
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    # A negative last_step means "no checkpoint": start with a fresh log.
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    # Every network class receives the full argument namespace as keywords.
    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimized and counted.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    # Kept with names for the optional gradient report further down.
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    # `scheduler` only exists when args.lr_schedule is truthy; every later
    # use is guarded by the same flag.
    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    # Resume: restore network weights and, when present in the saved state,
    # the optimizer and scheduler states.
    # NOTE(review): torch.load is called without map_location, so tensors are
    # restored to the device they were saved from -- confirm this is intended
    # when resuming on a different device.
    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    # Timing accumulators, reset after every log line.
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        # Sampling is done without autograd; gradients come only from the
        # separate log_prob evaluation below.
        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # 0.998**9000 ~ 1e-8
        # beta is 0 at step 0; presumably it approaches args.beta as step
        # grows (requires 0 < args.beta_anneal < 1 -- TODO confirm).
        beta = args.beta * (1 - args.beta_anneal**step)
        # loss is built under no_grad so that it acts as a constant weight in
        # the surrogate objective below.
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # Surrogate objective: the batch mean of loss is subtracted as a
        # baseline; only log_prob carries gradient.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            # Observables are normalized by args.L**2. The free energy is
            # divided by the target args.beta, not by the annealed beta used
            # in this step's loss.
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            # Turn the accumulated times into per-step averages over the
            # print interval (skipped at step 0).
            if step > 0:
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            # Optional dump of the raw batch for this logging step.
            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        # Periodic checkpoint; this is the file format the resume code above
        # reads back.
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            # Write the batch as one image grid with int(sqrt(batch size))
            # images per row.
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                # Per-site standard deviation of x_hat over the batch,
                # reshaped to an L x L grid.
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                # Site-site correlation of x_hat; args.epsilon guards the
                # division. Only the strict lower triangle is kept, then the
                # largest absolute value per row is reported on the grid.
                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                x_hat_corr = x_hat_cov / (
                    sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                # Table of (distinct energy value, count) rows for the batch.
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                # One line of gradient statistics per named parameter.
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
# Example #3
# 0
def main():
    """Train a MADE network on a Hopfield or SK Hamiltonian over a beta sweep.

    All configuration is read from the module-level ``args`` (``args.n`` is
    overwritten with 60). The Hamiltonian is chosen by ``args.ham`` and the
    optimizer by ``args.optimizer``. Starting at ``args.beta``, the network is
    trained for ``args.max_step`` steps at each value of beta, then beta is
    increased by ``args.beta_inc`` until it exceeds ``args.beta_anneal_to``.
    The network and optimizer are created once and carried across beta values.

    After each beta value one summary line is appended to ``args.fname``; for
    the Hopfield model the last batch of samples and their log-probabilities
    are also written as text files.

    Raises:
        ValueError: if ``args.ham`` or ``args.optimizer`` is not one of the
            names handled below.
    """
    start_time = time.time()

    init_out_dir()
    print_args()
    args.n = 60  # system size is hard-coded here, overriding the argument

    if args.ham == 'hop':
        ham = HopfieldModel(args.n, args.beta, args.device, seed=args.seed)
    elif args.ham == 'sk':
        ham = SKModel(args.n, args.beta, args.device, seed=args.seed)
    else:
        raise ValueError('Unknown ham: {}'.format(args.ham))
    # The couplings are fixed; they must never receive gradients.
    ham.J.requires_grad = False

    net = MADE(**vars(args))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimized and counted.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    # Timing accumulators, reset after every log line.
    sample_time = 0
    train_time = 0
    start_time = time.time()
    # Guarantee at least one pass of the sweep at args.beta.
    if args.beta_anneal_to < args.beta:
        args.beta_anneal_to = args.beta
    beta = args.beta
    # NOTE(review): the sweep only terminates if args.beta_inc > 0.
    while beta <= args.beta_anneal_to:
        for step in range(args.max_step):
            optimizer.zero_grad()

            # Sampling is done without autograd; gradients come only from
            # the separate log_prob evaluation below.
            sample_start_time = time.time()
            with torch.no_grad():
                sample, x_hat = net.sample(args.batch_size)
            assert not sample.requires_grad
            assert not x_hat.requires_grad
            sample_time += time.time() - sample_start_time

            train_start_time = time.time()

            log_prob = net.log_prob(sample)
            # loss is built under no_grad so that it acts as a constant
            # weight in the surrogate objective below.
            with torch.no_grad():
                energy = ham.energy(sample)
                loss = log_prob + beta * energy
            assert not energy.requires_grad
            assert not loss.requires_grad
            # Surrogate objective: the batch mean of loss is subtracted as a
            # baseline; only log_prob carries gradient.
            loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            loss_reinforce.backward()

            if args.clip_grad > 0:
                # Hand-rolled variant of nn.utils.clip_grad_norm_. Unlike the
                # library function there is no "only if too large" test: the
                # gradient is always rescaled so that its global 2-norm is
                # (up to args.epsilon) args.clip_grad.
                parameters = [p for p in params if p.grad is not None]
                max_norm = float(args.clip_grad)
                norm_type = 2
                # First accumulate the global norm over ALL parameters ...
                total_norm = 0
                for p in parameters:
                    param_norm = p.grad.data.norm(norm_type)
                    total_norm += param_norm.item()**norm_type
                total_norm = total_norm**(1 / norm_type)
                # ... then rescale every gradient exactly once. Previously
                # the root, the coefficient and the rescale were nested
                # inside the accumulation loop, so gradients were rescaled
                # once per parameter using partial norms.
                clip_coef = max_norm / (total_norm + args.epsilon)
                for p in parameters:
                    p.grad.data.mul_(clip_coef)

            optimizer.step()

            train_time += time.time() - train_start_time

            if args.print_step and step % args.print_step == 0:
                # Observables are normalized by args.n.
                free_energy_mean = loss.mean() / beta / args.n
                free_energy_std = loss.std() / beta / args.n
                entropy_mean = -log_prob.mean() / args.n
                energy_mean = energy.mean() / args.n
                mag = sample.mean(dim=0)
                mag_mean = mag.mean()
                # Turn the accumulated times into per-step averages over the
                # print interval (skipped at step 0).
                if step > 0:
                    sample_time /= args.print_step
                    train_time /= args.print_step
                used_time = time.time() - start_time
                my_log(
                    'beta = {:.3g}, # {}, F = {:.8g}, F_std = {:.8g}, S = {:.5g}, E = {:.5g}, M = {:.5g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                    .format(
                        beta,
                        step,
                        free_energy_mean.item(),
                        free_energy_std.item(),
                        entropy_mean.item(),
                        energy_mean.item(),
                        mag_mean.item(),
                        sample_time,
                        train_time,
                        used_time,
                    ))
                sample_time = 0
                train_time = 0

        # Summary line for this beta.
        # NOTE(review): the observables written here are those of the most
        # recent logging step, not necessarily the final training step, and
        # they are undefined if args.print_step is 0 -- confirm intended.
        with open(args.fname, 'a', newline='\n') as f:
            f.write('{} {} {:.3g} {:.8g} {:.8g} {:.8g} {:.8g}\n'.format(
                args.n,
                args.seed,
                beta,
                free_energy_mean.item(),
                free_energy_std.item(),
                energy_mean.item(),
                entropy_mean.item(),
            ))

        # For the Hopfield model, keep the last batch and its log-probs.
        if args.ham == 'hop':
            ensure_dir(args.out_filename + '_sample/')
            np.savetxt('{}_sample/sample{:.2f}.txt'.format(
                args.out_filename, beta),
                       sample.cpu().numpy(),
                       delimiter=' ',
                       fmt='%d')
            np.savetxt('{}_sample/log_prob{:.2f}.txt'.format(
                args.out_filename, beta),
                       log_prob.cpu().detach().numpy(),
                       delimiter=' ',
                       fmt='%.5f')

        beta += args.beta_inc
# Example #4
# 0
def main():
    """Train the network returned by ``get_net()`` with Adam, resuming if possible.

    Configuration comes from the module-level ``args``. Each training step is
    one call to the jitted ``update`` closure, which samples spins, evaluates
    ``log_q`` and the energy (both divided by ``args.L**2``) and applies one
    optimizer update. Progress is logged every ``args.print_step`` steps and
    parameters are checkpointed every ``args.save_step`` steps.
    """
    start_time = time()
    init_out_dir()

    last_step = get_last_ckpt_step()
    resumed = last_step >= 0
    if resumed:
        my_log(f'\nCheckpoint found: {last_step}\n')
    else:
        clear_log()
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    # One key is consumed for parameter initialization; the other is threaded
    # through the training loop.
    rng, rng_net = jrand.split(jrand.PRNGKey(args.seed))
    in_shape = (args.batch_size, args.L, args.L, 1)
    _, params_init = net_init(rng_net, in_shape)
    _, cache_init = net_init_cache(params_init, jnp.zeros(in_shape), (-1, -1))

    # Sampling goes through the cached "fast" apply function; log-probabilities
    # use the plain one (get_sample_fun(net_apply, None) is the alternative the
    # original author left disabled).
    sample_fun = get_sample_fun(net_apply_fast, cache_init)
    log_q_fun = get_log_q_fun(net_apply)

    opt_init, opt_update, get_params = optimizers.adam(args.lr)
    need_beta_anneal = args.beta_anneal_step > 0

    @jit
    def update(step, opt_state, rng):
        """Run one sample/gradient/optimizer step and return the new state."""
        params = get_params(opt_state)
        rng, rng_sample = jrand.split(rng)
        spins = sample_fun(args.batch_size, params, rng_sample)
        log_q = log_q_fun(params, spins) / args.L**2
        energy = energy_fun(spins) / args.L**2

        def neg_log_Z_fun(params, spins):
            # With annealing enabled, beta grows linearly with step up to
            # args.beta; otherwise it is args.beta throughout.
            if need_beta_anneal:
                beta = args.beta * jnp.minimum(step / args.beta_anneal_step, 1)
            else:
                beta = args.beta
            scaled_log_q = log_q_fun(params, spins) / args.L**2
            scaled_energy = energy_fun(spins) / args.L**2
            return scaled_log_q + beta * scaled_energy

        objective = partial(
            expect,
            log_q_fun,
            neg_log_Z_fun,
            mean_grad_expected_is_zero=True,
        )
        grads = grad(objective)(params, spins, spins)
        new_opt_state = opt_update(step, grads, opt_state)
        return spins, log_q, energy, new_opt_state, rng

    # On resume, the freshly initialized parameters are replaced by the saved
    # ones before the optimizer state is created.
    if resumed:
        params_init = load_ckpt(last_step)
    opt_state = opt_init(params_init)

    my_log('Training...')
    for step in range(last_step + 1, args.max_step + 1):
        spins, log_q, energy, opt_state, rng = update(step, opt_state, rng)

        should_log = args.print_step and step % args.print_step == 0
        if should_log:
            # Use the final beta, not the annealed beta
            free_energy = log_q / args.beta + energy
            fields = [
                f'step = {step}',
                f'F = {free_energy.mean():.8g}',
                f'F_std = {free_energy.std():.8g}',
                f'S = {-log_q.mean():.8g}',
                f'E = {energy.mean():.8g}',
                f'time = {time() - start_time:.3f}',
            ]
            my_log(', '.join(fields))

        should_save = args.save_step and step % args.save_step == 0
        if should_save:
            save_ckpt(get_params(opt_state), step)
# Example #5
# 0
def main():
    """Train a PixelCNN against the XY-model energy function.

    All configuration is read from the module-level ``args``. The function
    prepares the output directory, resumes from the latest checkpoint when
    one exists, builds the network / optimizer / optional LR scheduler, and
    then runs steps ``last_step + 1 .. args.max_step`` inclusive. Each step
    samples a batch, evaluates ``loss = log_prob + beta * energy`` and
    back-propagates the REINFORCE-style surrogate
    ``mean((loss - loss.mean()) * log_prob)``. Observables, samples,
    checkpoints and diagnostics are written at the intervals configured in
    ``args``.

    Raises:
        ValueError: if ``args.net`` is not ``'pixelcnn_xy'`` or
            ``args.optimizer`` is not one of the names handled below.
    """

    start_time = time.time()
    # initialize output dir
    init_out_dir()
    # checkpoint handling: a negative last_step means "no checkpoint"
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    # The network class receives the full argument namespace as keywords.
    if args.net == 'pixelcnn_xy':
        net = PixelCNN(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # parameters of the network
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad,
                         params))  # keep only parameters that require gradients
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    # Kept with names for the optional gradient report further down.
    named_params = list(net.named_parameters())
    # optimizers
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    # learning-rate schedule; `scheduler` only exists when args.lr_schedule
    # is truthy, and every later use is guarded by the same flag
    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               factor=0.92,
                                                               patience=100,
                                                               threshold=1e-4,
                                                               min_lr=1e-6)
    # Resume: restore network weights and, when present in the saved state,
    # the optimizer and scheduler states.
    # NOTE(review): torch.load is called without map_location -- confirm this
    # is intended when resuming on a different device.
    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    # start training
    my_log('Training...')
    # Timing accumulators, reset after every log line.
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()  # clear gradients from the previous step

        # Sampling is done without autograd; gradients come only from the
        # separate log_prob evaluation below.
        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(
                args.batch_size
            )  # draw args.batch_size configurations from the network
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        # log probabilities
        log_prob = net.log_prob(sample, args.batch_size)

        # 0.998**9000 ~ 1e-8
        # beta is 0 at step 0; presumably it approaches args.beta as step
        # grows (requires 0 < args.beta_anneal < 1 -- TODO confirm).
        beta = args.beta * (1 - args.beta_anneal**step
                            )  # annealing, per the original author to avoid mode collapse
        # loss is built under no_grad so that it acts as a constant weight in
        # the surrogate objective below.
        with torch.no_grad():
            energy, vortices = xy.energy(sample, args.ham, args.lattice,
                                         args.boundary)
            loss = log_prob + beta * energy  # per-sample loss (free-energy estimate)

        assert not energy.requires_grad
        assert not loss.requires_grad
        # Surrogate objective: the batch mean of loss is subtracted as a
        # baseline; only log_prob carries gradient.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()  # back propagation

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        # export physical observables
        if args.print_step and step % args.print_step == 0:
            # NOTE(review): these divide by the annealed beta, which is 0 at
            # step 0, so the step-0 free-energy values are presumably
            # inf/nan -- confirm whether args.beta was intended.
            free_energy_mean = loss.mean() / beta / (args.L**2
                                                     )  # free energy density
            free_energy_std = loss.std() / beta / (args.L**2)
            entropy_mean = -log_prob.mean() / (args.L**2)  # entropy density
            energy_mean = (energy / (args.L**2)).mean()  # energy density
            energy_std = (energy / (args.L**2)).std()
            # Rebinds `vortices` from the per-sample values to their density.
            vortices = vortices.mean() / args.L**2  # vortices density

            # heat_capacity=(((energy/ (args.L**2))**2).mean()- ((energy/ (args.L**2)).mean())**2)  *(beta**2)

            # magnetization
            # mag = torch.cos(sample).sum(dim=(2,3)).mean(dim=0) # M_x (M_x,M_y)=(cos(theta), sin(theta))
            # mag_mean = mag.mean()
            # mag_sqr_mean = (mag**2).mean()
            # sus_mean = mag_sqr_mean/args.L**2

            # Turn the accumulated times into per-step averages over the
            # print interval (skipped at step 0).
            if step > 0:
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            # hyperparameters in training
            my_log(
                'step = {}, lr = {:.3g}, loss={:.8g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    optimizer.param_groups[0]['lr'],
                    loss.mean(),
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            # observables
            my_log(
                'F = {:.8g}, F_std = {:.8g}, E = {:.8g}, E_std={:.8g}, v={:.8g}'
                .format(
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    energy_mean.item(),
                    energy_std.item(),
                    vortices.item(),
                ))

            sample_time = 0
            train_time = 0
            # save sample
            # NOTE(review): `step % args.save_step` raises ZeroDivisionError
            # when args.save_sample is set but args.save_step is 0.
            if args.save_sample and step % args.save_step == 0:
                # save training state
                # state = {
                #     'sample': sample,
                #     'x_hat': x_hat,
                #     'log_prob': log_prob,
                #     'energy': energy,
                #     'loss': loss,
                # }
                # torch.save(state, '{}_save/{}.sample'.format(
                #     args.out_filename, step))

                # Recognize the Phase Transition
                # helicity
                with torch.no_grad():
                    correlations = helicity(sample)
                # Uses the target args.beta, not the annealed beta.
                helicity_modulus = -((energy / args.L**2).mean()) - (
                    args.beta * correlations**2 / args.L**2).mean()
                my_log('Rho={:.8g}'.format(helicity_modulus.item()))

                # save configurations, one flattened sample per row
                sample_array = sample.cpu().numpy()
                np.savetxt(
                    '{}_save/sample{}.txt'.format(args.out_filename, step),
                    sample_array.reshape(args.batch_size, -1))
                # save observables
                # NOTE(review): most entries are 0-dim torch tensors; this
                # presumably relies on implicit tensor-to-NumPy conversion,
                # which would not work for CUDA tensors -- confirm.
                np.savetxt(
                    '{}_save/results{}.csv'.format(args.out_filename, step), [
                        beta,
                        step,
                        free_energy_mean,
                        free_energy_std,
                        energy_mean,
                        energy_std,
                        vortices,
                        helicity_modulus,
                    ])

        # save net; this is the file format the resume code above reads back
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, step))

        # diagnostics every args.visual_step steps
        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            # torchvision.utils.save_image(
            #     sample,
            #     '{}_img/{}.png'.format(args.out_filename, step),
            #     nrow=int(sqrt(sample.shape[0])),
            #     padding=0,
            #     normalize=True)
            # print sample
            if args.print_sample:
                # Channels 0 and 1 of x_hat are labelled alpha and beta by
                # the original author; per-site standard deviation over the
                # batch is reshaped to an L x L grid for each.
                x_hat_alpha = x_hat[:, 0, :, :].view(x_hat.shape[0],
                                                     -1).cpu().numpy()  # alpha
                x_hat_std1 = np.std(x_hat_alpha, axis=0).reshape([args.L] * 2)
                x_hat_beta = x_hat[:, 1, :, :].view(x_hat.shape[0],
                                                    -1).cpu().numpy()  # beta
                x_hat_std2 = np.std(x_hat_beta, axis=0).reshape([args.L] * 2)

                # Table of (distinct energy value, count) rows for the batch.
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nalpha\n{}\nbeta\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nalpha_std\n{}\nbeta_std\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        x_hat[:args.print_sample, 1],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std1,
                        x_hat_std2,
                        energy_count,
                    ))
            # print gradient statistics, one line per named parameter
            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
# Example #6
# 0
def main():
    """Train the RNN on ``./data.txt`` and stop early when fidelity stalls.

    Configuration comes from the module-level ``args``. The process exits
    immediately when the latest checkpoint step is already at or beyond
    ``args.epoch``. Each epoch walks the data in consecutive, non-shuffled
    batches and minimizes the mean negative log-probability. After each epoch
    the value of ``classical_fidelity(model, ghz, print_prob=False)`` is
    compared with the best seen so far; training stops once it has failed to
    improve for more than four epochs in a row. A checkpoint is written only
    on an improving epoch that is also a multiple of ``args.save_epoch``.
    """
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()

    resuming = last_epoch >= 0
    if resuming:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    model = RNN(
        args.device,
        Number_qubits=args.N,
        charset_length=args.charset_length,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
    )
    data = prepare_data(args.N, './data.txt')
    ghz = GHZ(Number_qubits=args.N)

    model.train(True)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    if resuming:
        utils.load_checkpoint(last_epoch, model, optimizer)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    best_fid = 0
    stale_epochs = 0  # consecutive epochs without a new best fidelity
    batches_per_epoch = int(args.Ns / args.batch_size)
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx in range(batches_per_epoch):
            optimizer.zero_grad()
            # Consecutive slice of the data set (a random-index variant was
            # left disabled by the original author).
            offset = batch_idx * args.batch_size
            idx = np.arange(offset, offset + args.batch_size)
            batch = torch.from_numpy(data[idx]).to(args.device)
            loss = -model.log_prob(batch).mean()
            loss.backward()
            if args.clip_grad:
                clip_grad_norm_(trainable, args.clip_grad)
            optimizer.step()

        # `loss` here is the loss of the last batch of the epoch.
        print('epoch_idx {} current loss {:.8g}'.format(
            epoch_idx, loss.item()))
        print('Evaluating...')
        current_fid = classical_fidelity(model, ghz, print_prob=False)

        if current_fid > best_fid:
            stale_epochs = 0
            elapsed = time.time() - start_time
            my_log('epoch_idx {} loss {:.8g} fid {} time {:.3f}'.format(
                epoch_idx, loss.item(), current_fid, elapsed))
            best_fid = current_fid
            want_save = (args.out_filename and args.save_epoch
                         and epoch_idx % args.save_epoch == 0)
            if want_save:
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                save_path = '{}_save/{}.state'.format(args.out_filename,
                                                      epoch_idx)
                torch.save(state, save_path)
        else:
            stale_epochs += 1

        if stale_epochs > 4:
            break
# Example #7
# 0
def main():
    """Train the flow returned by ``build_mera`` with periodic checkpoints.

    Workflow, as driven by the module-level ``args``:

    1. Prepare the output directory and look up the last checkpoint step;
       exit immediately if that step already reaches ``args.epoch``.
    2. Build the flow, optionally wrap it for multi-GPU use, create an
       AdamW optimizer and restore both from the checkpoint if one exists.
    3. For every remaining epoch, minimise the mean of
       ``-(log_prob + ldj_logit) / (args.nchannels * args.L**2)`` over the
       training loader, logging every ``args.print_step`` batches.
    4. Every ``args.save_epoch`` epochs save model + optimizer state and
       prune the previously saved checkpoint unless its epoch is a multiple
       of ``args.keep_epoch``; every ``args.plot_epoch`` epochs call
       ``do_plot``.
    """
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        # The requested number of epochs has already been checkpointed.
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    flow = build_mera()
    flow.train(True)
    my_log('nparams in each RG layer: {}'.format(
        [utils.get_nparams(layer) for layer in flow.layers]))
    my_log('Total nparams: {}'.format(utils.get_nparams(flow)))

    # Use multiple GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        flow = utils.data_parallel_wrap(flow)

    # Only parameters that require gradients are optimised and clipped.
    params = [x for x in flow.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, flow, optimizer)

    # Only the training split is used in this function.
    train_split, val_split, data_info = utils.load_dataset()
    train_loader = torch.utils.data.DataLoader(train_split,
                                               args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()  # from here on, measures training time only
    checkpoint_fmt = '{}_save/{}.state'
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx, (x, _) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(args.device)
            # logit_transform returns the transformed batch together with a
            # second term that is added to the flow's log-probability below.
            x, ldj_logit = utils.logit_transform(x)
            log_prob = flow.log_prob(x)
            # Per-sample loss, normalised by nchannels * L**2.
            loss = -(log_prob + ldj_logit) / (args.nchannels * args.L**2)
            loss_mean = loss.mean()
            loss_std = loss.std()

            utils.check_nan(loss_mean)

            loss_mean.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

            if args.print_step and batch_idx % args.print_step == 0:
                bit_per_dim = (loss_mean.item() + log(256)) / log(2)
                my_log(
                    'epoch {} batch {} bpp {:.8g} loss {:.8g} +- {:.8g} time {:.3f}'
                    .format(
                        epoch_idx,
                        batch_idx,
                        bit_per_dim,
                        loss_mean.item(),
                        loss_std.item(),
                        time.time() - start_time,
                    ))

        if (args.out_filename and args.save_epoch
                and epoch_idx % args.save_epoch == 0):
            state = {
                'flow': flow.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state,
                       checkpoint_fmt.format(args.out_filename, epoch_idx))

            # Prune the checkpoint written by the previous save, unless its
            # epoch is a multiple of keep_epoch. Checkpoints are only written
            # every save_epoch epochs, so the previous one is at
            # epoch_idx - save_epoch (not epoch_idx - 1, which does not exist
            # when save_epoch > 1).
            prev_epoch = epoch_idx - args.save_epoch
            if prev_epoch >= 0 and prev_epoch % args.keep_epoch != 0:
                try:
                    os.remove(
                        checkpoint_fmt.format(args.out_filename, prev_epoch))
                except FileNotFoundError:
                    # Already pruned or never written (e.g. save_epoch was
                    # changed between runs); nothing to clean up.
                    pass

        if (args.plot_filename and args.plot_epoch
                and epoch_idx % args.plot_epoch == 0):
            with torch.no_grad():
                do_plot(flow, epoch_idx)
Пример #8
0
def main():
    """Fit the couplings ``sk.J`` so the model reproduces a reference correlation matrix.

    Steps:

    1. Load a reference model ``sk_true`` from a pickle file named after
       ``args.n``, ``args.beta`` and ``args.seed`` and take its ``C_model``
       as the target correlation matrix ``C_data``.
    2. Compute three closed-form estimates of J from ``C_data`` (``J_nmf``,
       ``J_IP``, ``J_SM``) and log their ``J_diff`` against ``sk_true``.
       (Presumably naive mean-field / independent-pair / Sessak-Monasson
       estimators -- confirm against the accompanying paper.)
    3. Starting from ``J = 0``, alternate between
       (a) training the autoregressive network ``net`` on samples drawn
       from itself, using ``log_prob + beta * energy`` as a detached
       per-sample weight, and
       (b) one Adam step on ``sk.J`` with a manually assigned gradient
       built from the mismatch between ``C_data`` and the correlations of
       ``10**6`` samples from ``net``.
    4. Append the final ``J_diff`` to ``args.fname``.
    """
    start_time = time.time()

    init_out_dir()
    print_args()

    fsave_name = 'n{}b{:.2f}D{}.pickle'.format(args.n, args.beta, args.seed)
    # NOTE(review): pickle.load can execute arbitrary code from the file;
    # only point this at pickles produced by trusted runs.
    with open(fsave_name, 'rb') as f:
        sk_true = pickle.load(f)

    sk_true.J = sk_true.J.to(args.device)
    assert args.n == sk_true.n
    assert args.beta == sk_true.beta

    sk = SKModel(args.n, args.beta, args.device, seed=args.seed)
    my_log('diff = {:.8g}'.format(sk_true.J_diff(sk.J)))

    C_data = sk_true.C_model.to(args.device, default_dtype_torch)
    C = C_data
    C_inv = torch.inverse(C)

    # Closed-form estimate 1: (I - C^-1) / beta, diagonal zeroed.
    J_nmf = (torch.eye(args.n).to(C.device) - C_inv) / args.beta
    idx = range(args.n)
    J_nmf[idx, idx] = 0
    diff_nmf = sk_true.J_diff(J_nmf.to(args.device))
    my_log('diff_NMF = {:.8g}'.format(diff_nmf))

    # Closed-form estimate 2: element-wise log-ratio of (1 + C) / (1 - C);
    # epsilon keeps the denominator away from zero.
    J_IP = 1 / 2 * torch.log((1 + C) / (1 - C + args.epsilon)) / args.beta
    idx = range(args.n)
    J_IP[idx, idx] = 0
    diff_IP = sk_true.J_diff(J_IP.to(args.device))
    my_log('diff_IP = {:.8g}'.format(diff_IP))

    # Closed-form estimate 3: combines the two terms above with a
    # C / (1 - C^2) correction.
    J_SM = -C_inv + J_IP * args.beta - C / (1 - C**2 + args.epsilon)
    J_SM /= args.beta
    idx = range(args.n)
    J_SM[idx, idx] = 0
    diff_SM = sk_true.J_diff(J_SM.to(args.device))
    my_log('diff_SM = {:.8g}'.format(diff_SM))

    # Start the iterative fit from zero couplings.
    sk.J.data[:] = 0

    net = MADE(**vars(args))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # `params` must be the network's parameters: the inner loop calls
    # backward() on a loss that only depends on `net`, so an optimizer bound
    # to anything else would never zero or apply those gradients.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    # The couplings are updated by a separate optimizer in the outer loop.
    params2 = [sk.J]

    optimizer = torch.optim.Adam(params, lr=args.lr)
    optimizer2 = torch.optim.Adam(params2, lr=0.01)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    diff = 10000
    for step2 in range(5000):
        optimizer2.zero_grad()

        # Inner loop: train `net` against the energy of the current sk.J.
        for step in range(args.max_step + 1):
            optimizer.zero_grad()

            sample_start_time = time.time()
            with torch.no_grad():
                sample, x_hat = net.sample(args.batch_size)
            assert not sample.requires_grad
            assert not x_hat.requires_grad
            sample_time += time.time() - sample_start_time

            train_start_time = time.time()

            log_prob = net.log_prob(sample)
            with torch.no_grad():
                # Computed without grad: `loss` only acts as a per-sample
                # weight on log_prob below.
                energy = sk.energy(sample)
                loss = log_prob + args.beta * energy
            assert not energy.requires_grad
            assert not loss.requires_grad
            # Subtracting the batch mean of `loss` before weighting log_prob.
            loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
            loss_reinforce.backward()

            optimizer.step()

            train_time += time.time() - train_start_time

        # Outer step: estimate the model correlations from a large sample.
        with torch.no_grad():
            sample, x_hat = net.sample(10**6)
        assert not sample.requires_grad
        assert not x_hat.requires_grad

        m_model = torch.mean(sample, 0).view(sample.shape[1], 1)
        C_model = sample.t() @ sample / sample.shape[0] - m_model @ m_model.t()

        C_diff = C_data - C_model
        C_norm = C_diff.norm(2)

        # Hand-written gradient for J: the normalised correlation mismatch.
        sk.J.grad = -C_diff / (C_norm + args.epsilon)
        optimizer2.step()

        diff = sk_true.J_diff(sk.J)
        my_log('# {}, diff_J = {}, diff_C = {}'.format(
            step2, diff, torch.sqrt(torch.mean(C_diff**2))))

        # Zero the diagonal of J after each update.
        idx = range(args.n)
        sk.J.data[idx, idx] = 0

    with open(args.fname, 'a', newline='\n') as f:
        f.write('{} {} {:.3g} {:.8g}\n'.format(args.n, args.seed, args.beta,
                                               diff))