def main():
    """Evaluate the classical fidelity of the RNN against a GHZ state.

    Builds the RNN from the command-line ``args``, restores the most recent
    checkpoint when one exists, and prints the value returned by
    ``classical_fidelity(model, ghz)`` for a GHZ state on ``args.N`` qubits.

    Raises:
        SystemExit: when the last checkpoint step is already at or beyond
            ``args.epoch``.
    """
    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        # `exit()` is injected by the `site` module for interactive use and is
        # not guaranteed to exist in every runtime; raise SystemExit directly.
        # NOTE(review): this skips evaluation once training has reached
        # args.epoch -- confirm that is intended for an evaluation script.
        raise SystemExit
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        # No checkpoint on disk: start from an empty log.
        utils.clear_log()

    model = RNN(args.device,
                Number_qubits=args.N,
                charset_length=args.charset_length,
                hidden_size=args.hidden_size,
                num_layers=args.num_layers)

    # Evaluation only: switch the module out of training mode.
    model.train(False)
    print('number of qubits: ', model.Number_qubits)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    # The optimizer is never stepped here; it is built only so that it can be
    # handed to utils.load_checkpoint together with the model.
    params = [x for x in model.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, model, optimizer)

    # Quantum state
    ghz = GHZ(Number_qubits=args.N)
    c_fidelity = classical_fidelity(model, ghz)
    # c_fidelity = cfid(model, ghz, './data.txt')
    print(c_fidelity)
示例#2
0
def main():
    """Draw a 4x4 grid of samples from the trained flow with mixed latents.

    Latents come from a Laplace prior with scale ``T_low / sqrt(2)``; every
    ``2**level_cutoff``-th site along the last two axes is then overwritten
    with a draw from a Laplace prior with scale ``T_high / sqrt(2)``.  The
    latents are pushed through ``flow.inverse``, squashed with a logistic
    sigmoid and written to ``./mix_T.pdf``.
    """
    grid = 4

    flow = build_mera()
    utils.load_checkpoint(utils.get_last_checkpoint_step(), flow)
    flow.train(False)

    latent_shape = (grid * grid, args.nchannels, args.L, args.L)

    # Draw order (low prior first, then high prior) is kept so that a fixed
    # RNG seed produces the same latents.
    cold_prior = Laplace(torch.tensor(0.), torch.tensor(T_low / sqrt(2)))
    latent = cold_prior.sample(latent_shape)
    hot_prior = Laplace(torch.tensor(0.), torch.tensor(T_high / sqrt(2)))
    hot_latent = hot_prior.sample(latent_shape)

    # Replace the strided sub-lattice with the high-temperature draws.
    stride = 2**level_cutoff
    strided = slice(None, None, stride)
    coarse_sites = (slice(None), slice(None), strided, strided)
    latent[coarse_sites] = hot_latent[coarse_sites]
    latent = latent.to(args.device)

    with torch.no_grad():
        images, _ = flow.inverse(latent)

    # Move the channel axis last for imshow, then apply the sigmoid.
    samples = images.permute(0, 2, 3, 1).detach().cpu().numpy()
    samples = 1 / (1 + np.exp(-samples))

    fig, axes = plt.subplots(grid,
                             grid,
                             figsize=(4, 4),
                             sharex=True,
                             sharey=True)
    # Walking the transposed grid visits axes[i, j] at flat index j * 4 + i,
    # i.e. samples fill the figure column by column.
    for idx, ax in enumerate(axes.T.flat):
        ax.imshow(samples[idx])
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('./mix_T.pdf', bbox_inches='tight')
示例#3
0
def main():
    """Train an autoregressive network on the Ising model variationally.

    Each step draws ``args.batch_size`` samples from the network, evaluates
    their log-probability and Ising energy, and minimises
    ``log_prob + beta * energy`` with a score-function (REINFORCE-style)
    gradient.  ``beta`` is annealed from 0 towards ``args.beta``.

    Depending on the ``args`` flags the loop also logs observables, saves
    samples, writes checkpoints, renders sample images and prints per-sample
    and per-parameter gradient diagnostics.  Training resumes from the latest
    checkpoint when one is found.

    Raises:
        ValueError: if ``args.net`` or ``args.optimizer`` is not recognised.
    """
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    # Build the network selected on the command line.
    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimised and counted.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    # Resume from the latest checkpoint.
    if last_step >= 0:
        # map_location lets a checkpoint written on one device (e.g. a GPU)
        # be resumed on whatever device this run uses; without it torch.load
        # tries to restore tensors onto the device they were saved from.
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step),
                           map_location=args.device)
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        # Sampling is done without building an autograd graph.
        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # `loss` was computed under no_grad, so it acts as a constant weight
        # here and gradients flow only through `log_prob`; subtracting the
        # batch mean serves as a baseline.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        # Periodic logging of observables (all divided by L**2).
        if args.print_step and step % args.print_step == 0:
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            if step > 0:
                # Turn the accumulated timings into per-step averages.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        # Periodic checkpoint of network, optimizer and (optional) scheduler.
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        # Periodic visualisation and verbose diagnostics.
        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                # Flatten each x_hat to one row per sample.
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                # Correlation matrix between sites; epsilon guards against a
                # zero denominator.
                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                x_hat_corr = x_hat_cov / (
                    sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                # Keep the strict lower triangle, then take for every site the
                # largest absolute correlation in its row.
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                # Histogram of distinct energy values in the batch.
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
示例#4
0
def main():
    """Train a PixelCNN on the XY model variationally.

    Each step draws ``args.batch_size`` samples from the network, evaluates
    their log-probability and the XY energy (``xy.energy`` also returns a
    vortex count), and minimises ``log_prob + beta * energy`` with a
    score-function (REINFORCE-style) gradient.  ``beta`` is annealed from 0
    towards ``args.beta``.

    Depending on the ``args`` flags the loop also logs observables, writes
    samples and observables to text files, saves checkpoints and prints
    per-sample and per-parameter gradient diagnostics.  Training resumes from
    the latest checkpoint when one is found.

    Raises:
        ValueError: if ``args.net`` or ``args.optimizer`` is not recognised.
    """
    start_time = time.time()
    # initialize output dir
    init_out_dir()
    # check point
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'pixelcnn_xy':
        net = PixelCNN(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimised and counted.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())
    # optimizers
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))
    # learning-rate schedule
    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               factor=0.92,
                                                               patience=100,
                                                               threshold=1e-4,
                                                               min_lr=1e-6)
    # Resume from the latest checkpoint.
    if last_step >= 0:
        # map_location lets a checkpoint written on one device (e.g. a GPU)
        # be resumed on whatever device this run uses; without it torch.load
        # tries to restore tensors onto the device they were saved from.
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step),
                           map_location=args.device)
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    # start training
    my_log('Training...')
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()  # clear gradients of the previous step

        # Sampling is done without building an autograd graph.
        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        # log probabilities
        log_prob = net.log_prob(sample, args.batch_size)

        # 0.998**9000 ~ 1e-8
        # Annealing of beta (original author's note: to avoid mode collapse).
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy, vortices = xy.energy(sample, args.ham, args.lattice,
                                         args.boundary)
            # Per-sample variational free energy (times beta).
            loss = log_prob + beta * energy

        assert not energy.requires_grad
        assert not loss.requires_grad
        # `loss` was computed under no_grad, so it acts as a constant weight
        # here and gradients flow only through `log_prob`; subtracting the
        # batch mean serves as a baseline.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()  # back propagation

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        # export physical observables
        if args.print_step and step % args.print_step == 0:
            # NOTE(review): at step 0 the annealed beta is 0, so F and F_std
            # are non-finite for that step -- confirm this is acceptable.
            free_energy_mean = loss.mean() / beta / (args.L**2)  # density
            free_energy_std = loss.std() / beta / (args.L**2)
            entropy_mean = -log_prob.mean() / (args.L**2)  # entropy density
            energy_mean = (energy / (args.L**2)).mean()  # energy density
            energy_std = (energy / (args.L**2)).std()
            vortices = vortices.mean() / args.L**2  # vortices density

            # heat_capacity=(((energy/ (args.L**2))**2).mean()- ((energy/ (args.L**2)).mean())**2)  *(beta**2)

            # magnetization
            # mag = torch.cos(sample).sum(dim=(2,3)).mean(dim=0) # M_x (M_x,M_y)=(cos(theta), sin(theta))
            # mag_mean = mag.mean()
            # mag_sqr_mean = (mag**2).mean()
            # sus_mean = mag_sqr_mean/args.L**2

            # log
            if step > 0:
                # Turn the accumulated timings into per-step averages.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            # hyperparameters in training
            my_log(
                'step = {}, lr = {:.3g}, loss={:.8g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    optimizer.param_groups[0]['lr'],
                    loss.mean().item(),
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            # observables
            my_log(
                'F = {:.8g}, F_std = {:.8g}, E = {:.8g}, E_std={:.8g}, v={:.8g}'
                .format(
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    energy_mean.item(),
                    energy_std.item(),
                    vortices.item(),
                ))

            sample_time = 0
            train_time = 0
            # save sample
            # `args.save_step` is checked for truthiness first, as in the
            # checkpoint branch below, so save_step == 0 disables this instead
            # of raising ZeroDivisionError.
            if (args.save_sample and args.save_step
                    and step % args.save_step == 0):
                # save training state
                # state = {
                #     'sample': sample,
                #     'x_hat': x_hat,
                #     'log_prob': log_prob,
                #     'energy': energy,
                #     'loss': loss,
                # }
                # torch.save(state, '{}_save/{}.sample'.format(
                #     args.out_filename, step))

                # Recognize the Phase Transition
                # helicity
                with torch.no_grad():
                    correlations = helicity(sample)
                helicity_modulus = -((energy / args.L**2).mean()) - (
                    args.beta * correlations**2 / args.L**2).mean()
                my_log('Rho={:.8g}'.format(helicity_modulus.item()))

                # save configurations
                sample_array = sample.cpu().numpy()
                np.savetxt(
                    '{}_save/sample{}.txt'.format(args.out_filename, step),
                    sample_array.reshape(args.batch_size, -1))
                # save observables
                # Convert the 0-dim tensors to Python floats first: NumPy
                # cannot build an array from tensors living on a CUDA device.
                np.savetxt(
                    '{}_save/results{}.csv'.format(args.out_filename, step), [
                        beta,
                        step,
                        free_energy_mean.item(),
                        free_energy_std.item(),
                        energy_mean.item(),
                        energy_std.item(),
                        vortices.item(),
                        helicity_modulus.item(),
                    ])

        # save net
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, step))

        # visualization in each visual_step
        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            # torchvision.utils.save_image(
            #     sample,
            #     '{}_img/{}.png'.format(args.out_filename, step),
            #     nrow=int(sqrt(sample.shape[0])),
            #     padding=0,
            #     normalize=True)
            # print sample
            if args.print_sample:
                # Channel 0 of x_hat is logged as "alpha", channel 1 as
                # "beta"; each is flattened to one row per sample.
                x_hat_alpha = x_hat[:, 0, :, :].view(x_hat.shape[0],
                                                     -1).cpu().numpy()
                x_hat_std1 = np.std(x_hat_alpha, axis=0).reshape([args.L] * 2)
                x_hat_beta = x_hat[:, 1, :, :].view(x_hat.shape[0],
                                                    -1).cpu().numpy()
                x_hat_std2 = np.std(x_hat_beta, axis=0).reshape([args.L] * 2)

                # Histogram of distinct energy values in the batch.
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nalpha\n{}\nbeta\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nalpha_std\n{}\nbeta_std\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        x_hat[:args.print_sample, 1],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std1,
                        x_hat_std2,
                        energy_count,
                    ))
            # print gradient
            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
示例#5
0
def main():
    """Train the RNN on measurement data by maximum likelihood.

    Every epoch sweeps the data loaded from ``./data.txt`` in consecutive
    batches of ``args.batch_size`` (a trailing partial batch is skipped),
    minimising the mean negative log-probability.  After each epoch the
    classical fidelity against a GHZ state is evaluated; training stops once
    the fidelity has failed to improve on more than four consecutive epochs.
    A checkpoint is written only on an improving epoch that is also a
    multiple of ``args.save_epoch``.

    Raises:
        SystemExit: when the last checkpoint step is already at or beyond
            ``args.epoch``.
        ValueError: when ``args.Ns`` and ``args.batch_size`` yield no batch.
    """
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        # `exit()` is injected by the `site` module for interactive use and is
        # not guaranteed to exist in every runtime; raise SystemExit directly.
        raise SystemExit
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    model = RNN(args.device,
                Number_qubits=args.N,
                charset_length=args.charset_length,
                hidden_size=args.hidden_size,
                num_layers=args.num_layers)
    data = prepare_data(args.N, './data.txt')
    ghz = GHZ(Number_qubits=args.N)

    model.train(True)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    params = [x for x in model.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, model, optimizer)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    # Number of full batches per epoch; it does not change between epochs.
    num_batches = int(args.Ns / args.batch_size)
    if num_batches < 1:
        # Without a single batch `loss` would never be assigned and the
        # per-epoch report below would fail with a NameError.
        raise ValueError(
            'args.Ns ({}) must be at least args.batch_size ({})'.format(
                args.Ns, args.batch_size))

    my_log('Training...')
    start_time = time.time()
    best_fid = 0
    trigger = 0  # consecutive epochs without a fidelity improvement
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx in range(num_batches):
            optimizer.zero_grad()
            # idx = np.random.randint(low=0,high=int(args.Ns-1),size=(args.batch_size,))
            # Consecutive (unshuffled) slice of the data set.
            idx = np.arange(args.batch_size) + batch_idx * args.batch_size
            train_data = data[idx]
            loss = -model.log_prob(
                torch.from_numpy(train_data).to(args.device)).mean()
            loss.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()
        # `loss` is the loss of the last batch of the epoch.
        print('epoch_idx {} current loss {:.8g}'.format(
            epoch_idx, loss.item()))
        print('Evaluating...')
        # Evaluation
        # NOTE(review): the model is still in training mode here -- confirm
        # classical_fidelity does not depend on eval mode.
        current_fid = classical_fidelity(model, ghz, print_prob=False)
        if current_fid > best_fid:
            trigger = 0  # reset
            my_log('epoch_idx {} loss {:.8g} fid {} time {:.3f}'.format(
                epoch_idx, loss.item(), current_fid,
                time.time() - start_time))
            best_fid = current_fid
            if (args.out_filename and args.save_epoch
                    and epoch_idx % args.save_epoch == 0):
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(
                    state, '{}_save/{}.state'.format(args.out_filename,
                                                     epoch_idx))
        else:
            trigger = trigger + 1
        # Early stopping.
        if trigger > 4:
            break
示例#6
0
def main():
    """Train the flow built by ``build_mera`` by maximum likelihood.

    Each batch is passed through ``utils.logit_transform``; the loss is the
    negative of the flow log-probability plus the transform's log-det term,
    divided by ``args.nchannels * args.L**2``.  Depending on the ``args``
    flags the loop logs progress, writes a checkpoint per ``save_epoch``
    (pruning the previously written one unless it falls on a ``keep_epoch``
    multiple) and calls ``do_plot``.

    Raises:
        SystemExit: when the last checkpoint step is already at or beyond
            ``args.epoch``.
    """
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        # `exit()` is injected by the `site` module for interactive use and is
        # not guaranteed to exist in every runtime; raise SystemExit directly.
        raise SystemExit
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    flow = build_mera()
    flow.train(True)
    my_log('nparams in each RG layer: {}'.format(
        [utils.get_nparams(layer) for layer in flow.layers]))
    my_log('Total nparams: {}'.format(utils.get_nparams(flow)))

    # Use multiple GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        flow = utils.data_parallel_wrap(flow)

    params = [x for x in flow.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, flow, optimizer)

    # Only the training split is used in this function.
    train_split, _, _ = utils.load_dataset()
    train_loader = torch.utils.data.DataLoader(train_split,
                                               args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx, (x, _) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(args.device)
            x, ldj_logit = utils.logit_transform(x)
            log_prob = flow.log_prob(x)
            loss = -(log_prob + ldj_logit) / (args.nchannels * args.L**2)
            loss_mean = loss.mean()

            utils.check_nan(loss_mean)

            loss_mean.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

            if args.print_step and batch_idx % args.print_step == 0:
                # The spread is only needed for this log line, so it is not
                # computed on batches that are not reported.
                loss_std = loss.std()
                bit_per_dim = (loss_mean.item() + log(256)) / log(2)
                my_log(
                    'epoch {} batch {} bpp {:.8g} loss {:.8g} +- {:.8g} time {:.3f}'
                    .format(
                        epoch_idx,
                        batch_idx,
                        bit_per_dim,
                        loss_mean.item(),
                        loss_std.item(),
                        time.time() - start_time,
                    ))

        if (args.out_filename and args.save_epoch
                and epoch_idx % args.save_epoch == 0):
            state = {
                'flow': flow.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, epoch_idx))

            # Prune the checkpoint written save_epoch epochs ago unless it is
            # one of the every-keep_epoch snapshots.  Targeting
            # `epoch_idx - 1` (identical when save_epoch == 1) would point at
            # a file that was never written when save_epoch > 1.
            # NOTE(review): assumes args.keep_epoch is non-zero.
            prev_saved = epoch_idx - args.save_epoch
            if prev_saved >= 0 and prev_saved % args.keep_epoch != 0:
                try:
                    os.remove('{}_save/{}.state'.format(
                        args.out_filename, prev_saved))
                except FileNotFoundError:
                    # Already gone (e.g. the run was resumed with different
                    # save settings); there is nothing left to prune.
                    pass

        if (args.plot_filename and args.plot_epoch
                and epoch_idx % args.plot_epoch == 0):
            with torch.no_grad():
                do_plot(flow, epoch_idx)