def main():
    """Evaluate the classical fidelity of a checkpointed RNN against a GHZ state.

    Every setting is read from the module-level ``args``.  When a checkpoint
    exists it is restored into the model (and into a freshly built AdamW
    optimizer, which is otherwise unused here); the value returned by
    ``classical_fidelity(model, ghz)`` is printed.

    NOTE(review): this exits when ``last_epoch >= args.epoch`` (i.e. after
    training has finished) and calls ``utils.clear_log()`` when no checkpoint
    exists -- both mirror the training script; confirm that is intended for
    an evaluation run.
    """
    utils.init_out_dir()

    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()

    has_checkpoint = last_epoch >= 0
    if not has_checkpoint:
        utils.clear_log()
    else:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))

    rnn = RNN(
        args.device,
        Number_qubits=args.N,
        charset_length=args.charset_length,
        hidden_size=args.hidden_size,
        num_layers=args.num_layers,
    )
    rnn.train(False)
    print('number of qubits: ', rnn.Number_qubits)
    my_log('Total nparams: {}'.format(utils.get_nparams(rnn)))
    rnn.to(args.device)

    # The optimizer is only passed to load_checkpoint in this function.
    trainable = [p for p in rnn.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable,
                            lr=args.lr,
                            weight_decay=args.weight_decay)
    if has_checkpoint:
        utils.load_checkpoint(last_epoch, rnn, opt)

    # Target quantum state.
    target = GHZ(Number_qubits=args.N)
    fidelity = classical_fidelity(rnn, target)
    # Data-based alternative:
    # fidelity = cfid(rnn, target, './data.txt')
    print(fidelity)
def main():
    """Plot samples from a trained MERA flow driven by a two-temperature prior.

    Latents are drawn from a zero-mean Laplace prior with scale
    ``T_low / sqrt(2)``; on the coarse sub-grid of every
    ``2**level_cutoff``-th site along both spatial axes they are replaced by
    draws with scale ``T_high / sqrt(2)``.  The latents go through
    ``flow.inverse``, a sigmoid is applied, and 16 images are plotted on a
    4x4 grid saved to ``./mix_T.pdf``.
    """
    flow = build_mera()
    utils.load_checkpoint(utils.get_last_checkpoint_step(), flow)
    flow.train(False)

    latent_shape = (16, args.nchannels, args.L, args.L)
    # Low-temperature draw first, then high-temperature, so the RNG stream is
    # consumed in this fixed order.
    z = Laplace(torch.tensor(0.),
                torch.tensor(T_low / sqrt(2))).sample(latent_shape)
    z_hot = Laplace(torch.tensor(0.),
                    torch.tensor(T_high / sqrt(2))).sample(latent_shape)

    stride = 2**level_cutoff
    z[:, :, ::stride, ::stride] = z_hot[:, :, ::stride, ::stride]
    z = z.to(args.device)

    with torch.no_grad():
        x, _ = flow.inverse(z)
    # (N, C, H, W) -> (N, H, W, C) for imshow, then squash with a sigmoid.
    images = x.permute(0, 2, 3, 1).detach().cpu().numpy()
    images = 1 / (1 + np.exp(-images))

    fig, axes = plt.subplots(4, 4, figsize=(4, 4), sharex=True, sharey=True)
    for (row, col), ax in np.ndenumerate(axes):
        # Images fill the grid column by column.
        ax.imshow(images[col * 4 + row])
        ax.axis('off')
    plt.tight_layout()
    plt.savefig('./mix_T.pdf', bbox_inches='tight')
def main():
    """Train a variational autoregressive network on the Ising model.

    All configuration comes from the module-level ``args``.  The network
    (``MADE``, ``PixelCNN`` or ``BernoulliMixture``) and optimizer are built,
    state is restored from ``{args.out_filename}_save/{last_step}.state`` when
    a checkpoint exists, and then steps ``last_step + 1 .. args.max_step`` are
    run.  Each step samples a batch without gradients, computes the energy
    with ``ising.energy`` and minimizes ``log_prob + beta * energy`` with a
    score-function (REINFORCE) gradient.

    Raises:
        ValueError: for an unknown ``args.net`` or ``args.optimizer``.
    """
    start_time = time.time()

    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    # Every network constructor receives the full argument namespace.
    if args.net == 'made':
        net = MADE(**vars(args))
    elif args.net == 'pixelcnn':
        net = PixelCNN(**vars(args))
    elif args.net == 'bernoulli':
        net = BernoulliMixture(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimized and counted.
    params = list(net.parameters())
    params = list(filter(lambda p: p.requires_grad, params))
    nparams = int(sum([np.prod(p.shape) for p in params]))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    # Resume: restore network, and optimizer/scheduler state when saved.
    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    # Accumulated sampling/training seconds since the last print.
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()

        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample)
        # beta is 0 at step 0 and approaches args.beta as step grows
        # (assuming 0 <= args.beta_anneal < 1).
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy = ising.energy(sample, args.ham, args.lattice,
                                  args.boundary)
            # Computed without grad: it is used only as a REINFORCE weight.
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # Score-function gradient; subtracting loss.mean() is a variance-
        # reducing baseline.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            # Per-site quantities; the free energy divides by the target
            # args.beta, not the annealed beta (which is 0 at step 0).
            free_energy_mean = loss.mean() / args.beta / args.L**2
            free_energy_std = loss.std() / args.beta / args.L**2
            entropy_mean = -log_prob.mean() / args.L**2
            energy_mean = energy.mean() / args.L**2
            # Mean over the batch for every site; M and Q log its mean and
            # mean square.
            mag = sample.mean(dim=0)
            mag_mean = mag.mean()
            mag_sqr_mean = (mag**2).mean()
            if step > 0:
                # Report average seconds per step over the print interval.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            my_log(
                'step = {}, F = {:.8g}, F_std = {:.8g}, S = {:.8g}, E = {:.8g}, M = {:.8g}, Q = {:.8g}, lr = {:.3g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    entropy_mean.item(),
                    energy_mean.item(),
                    mag_mean.item(),
                    mag_sqr_mean.item(),
                    optimizer.param_groups[0]['lr'],
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            sample_time = 0
            train_time = 0

            if args.save_sample:
                state = {
                    'sample': sample,
                    'x_hat': x_hat,
                    'log_prob': log_prob,
                    'energy': energy,
                    'loss': loss,
                }
                torch.save(state, '{}_save/{}.sample'.format(
                    args.out_filename, step))

        # Periodic checkpoint (network, optimizer and, if used, scheduler).
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            torchvision.utils.save_image(
                sample,
                '{}_img/{}.png'.format(args.out_filename, step),
                nrow=int(sqrt(sample.shape[0])),
                padding=0,
                normalize=True)

            if args.print_sample:
                x_hat_np = x_hat.view(x_hat.shape[0], -1).cpu().numpy()
                x_hat_std = np.std(x_hat_np, axis=0).reshape([args.L] * 2)

                # Correlation matrix between sites (epsilon avoids 0/0), then
                # for each site the largest |corr| with a lower-indexed site.
                x_hat_cov = np.cov(x_hat_np.T)
                x_hat_cov_diag = np.diag(x_hat_cov)
                # NOTE(review): `sqrt` is applied to an array here; this only
                # works if the module-level `sqrt` is NumPy's (import not
                # visible) -- confirm.
                x_hat_corr = x_hat_cov / (
                    sqrt(x_hat_cov_diag[:, None] * x_hat_cov_diag[None, :]) +
                    args.epsilon)
                x_hat_corr = np.tril(x_hat_corr, -1)
                x_hat_corr = np.max(np.abs(x_hat_corr), axis=1)
                x_hat_corr = x_hat_corr.reshape([args.L] * 2)

                # Histogram of distinct energy values: rows of (value, count).
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nx_hat\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nx_hat_std\n{}\nx_hat_corr\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std,
                        x_hat_corr,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
def main():
    """Train a PixelCNN variational network on the XY model with REINFORCE.

    All settings are read from the module-level ``args``.  The function
    prepares the output directory, builds the network and optimizer, resumes
    from ``{args.out_filename}_save/{last_step}.state`` when a checkpoint
    exists, and trains for steps ``last_step + 1 .. args.max_step``.

    Every ``args.print_step`` steps it logs timings and per-site observables
    (free energy, energy, vortex density).  On those steps, when
    ``args.save_sample`` is set and the step is a multiple of
    ``args.save_step``, it also logs the helicity modulus and writes the
    samples and a list of observables to text files.  Checkpoints are
    written every ``args.save_step`` steps.

    Raises:
        ValueError: for an unknown ``args.net`` or ``args.optimizer``.
    """
    start_time = time.time()

    # Output directory and checkpoint discovery.
    init_out_dir()
    if args.clear_checkpoint:
        clear_checkpoint()
    last_step = get_last_checkpoint_step()
    if last_step >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_step))
    else:
        clear_log()
    print_args()

    if args.net == 'pixelcnn_xy':
        net = PixelCNN(**vars(args))
    else:
        raise ValueError('Unknown net: {}'.format(args.net))
    net.to(args.device)
    my_log('{}\n'.format(net))

    # Only parameters that require gradients are optimized and counted.
    params = [p for p in net.parameters() if p.requires_grad]
    nparams = int(sum(np.prod(p.shape) for p in params))
    my_log('Total number of trainable parameters: {}'.format(nparams))
    named_params = list(net.named_parameters())

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(params, lr=args.lr)
    elif args.optimizer == 'sgdm':
        optimizer = torch.optim.SGD(params, lr=args.lr, momentum=0.9)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(params, lr=args.lr, alpha=0.99)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.9, 0.999))
    elif args.optimizer == 'adam0.5':
        optimizer = torch.optim.Adam(params, lr=args.lr, betas=(0.5, 0.999))
    else:
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    if args.lr_schedule:
        # 0.92**80 ~ 1e-3
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=0.92, patience=100, threshold=1e-4, min_lr=1e-6)

    # Resume: restore network, and optimizer/scheduler state when saved.
    if last_step >= 0:
        state = torch.load('{}_save/{}.state'.format(args.out_filename,
                                                     last_step))
        ignore_param(state['net'], net)
        net.load_state_dict(state['net'])
        if state.get('optimizer'):
            optimizer.load_state_dict(state['optimizer'])
        if args.lr_schedule and state.get('scheduler'):
            scheduler.load_state_dict(state['scheduler'])

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    # Accumulated sampling/training seconds since the last print.
    sample_time = 0
    train_time = 0
    start_time = time.time()
    for step in range(last_step + 1, args.max_step + 1):
        optimizer.zero_grad()  # clear gradients from the previous step

        sample_start_time = time.time()
        with torch.no_grad():
            sample, x_hat = net.sample(args.batch_size)
        assert not sample.requires_grad
        assert not x_hat.requires_grad
        sample_time += time.time() - sample_start_time

        train_start_time = time.time()

        log_prob = net.log_prob(sample, args.batch_size)
        # Annealing against mode collapse: beta is 0 at step 0 and approaches
        # args.beta as step grows (assuming 0 <= args.beta_anneal < 1).
        # 0.998**9000 ~ 1e-8
        beta = args.beta * (1 - args.beta_anneal**step)
        with torch.no_grad():
            energy, vortices = xy.energy(sample, args.ham, args.lattice,
                                         args.boundary)
            # Computed without grad: it is used only as a REINFORCE weight.
            loss = log_prob + beta * energy
        assert not energy.requires_grad
        assert not loss.requires_grad
        # Score-function gradient; subtracting loss.mean() is a variance-
        # reducing baseline.
        loss_reinforce = torch.mean((loss - loss.mean()) * log_prob)
        loss_reinforce.backward()

        if args.clip_grad:
            nn.utils.clip_grad_norm_(params, args.clip_grad)

        optimizer.step()

        if args.lr_schedule:
            scheduler.step(loss.mean())

        train_time += time.time() - train_start_time

        if args.print_step and step % args.print_step == 0:
            # Per-site observables.  The free energy divides by the annealed
            # beta, which is exactly 0 at step 0 (F is then reported as inf).
            free_energy_mean = loss.mean() / beta / (args.L**2)
            free_energy_std = loss.std() / beta / (args.L**2)
            energy_density = energy / (args.L**2)
            energy_mean = energy_density.mean()
            energy_std = energy_density.std()
            vortex_density = vortices.mean() / args.L**2
            # heat_capacity=(((energy/ (args.L**2))**2).mean()- ((energy/ (args.L**2)).mean())**2) *(beta**2)
            # magnetization, (M_x, M_y) = (cos(theta), sin(theta)):
            # mag = torch.cos(sample).sum(dim=(2,3)).mean(dim=0)
            # mag_mean = mag.mean()
            # mag_sqr_mean = (mag**2).mean()
            # sus_mean = mag_sqr_mean/args.L**2

            if step > 0:
                # Report average seconds per step over the print interval.
                sample_time /= args.print_step
                train_time /= args.print_step
            used_time = time.time() - start_time
            # Training hyperparameters and timings.
            my_log(
                'step = {}, lr = {:.3g}, loss={:.8g}, beta = {:.8g}, sample_time = {:.3f}, train_time = {:.3f}, used_time = {:.3f}'
                .format(
                    step,
                    optimizer.param_groups[0]['lr'],
                    loss.mean().item(),
                    beta,
                    sample_time,
                    train_time,
                    used_time,
                ))
            # Observables.
            my_log(
                'F = {:.8g}, F_std = {:.8g}, E = {:.8g}, E_std={:.8g}, v={:.8g}'
                .format(
                    free_energy_mean.item(),
                    free_energy_std.item(),
                    energy_mean.item(),
                    energy_std.item(),
                    vortex_density.item(),
                ))
            sample_time = 0
            train_time = 0

            # Guard save_step like the checkpoint block: `step % 0` would
            # raise ZeroDivisionError.
            if (args.save_sample and args.save_step
                    and step % args.save_step == 0):
                # Helicity modulus, used to locate the phase transition.
                with torch.no_grad():
                    correlations = helicity(sample)
                helicity_modulus = -(energy_density.mean()) - (
                    args.beta * correlations**2 / args.L**2).mean()
                my_log('Rho={:.8g}'.format(helicity_modulus.item()))

                # One flattened configuration per row.
                sample_array = sample.cpu().numpy()
                np.savetxt(
                    '{}_save/sample{}.txt'.format(args.out_filename, step),
                    sample_array.reshape(args.batch_size, -1))
                # Convert the tensors to Python floats: NumPy cannot convert
                # CUDA tensors implicitly.
                np.savetxt(
                    '{}_save/results{}.csv'.format(args.out_filename, step),
                    [
                        beta,
                        step,
                        free_energy_mean.item(),
                        free_energy_std.item(),
                        energy_mean.item(),
                        energy_std.item(),
                        vortex_density.item(),
                        helicity_modulus.item(),
                    ])

        # Periodic checkpoint (network, optimizer and, if used, scheduler).
        if (args.out_filename and args.save_step
                and step % args.save_step == 0):
            state = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            if args.lr_schedule:
                state['scheduler'] = scheduler.state_dict()
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, step))

        if (args.out_filename and args.visual_step
                and step % args.visual_step == 0):
            # torchvision.utils.save_image(
            #     sample,
            #     '{}_img/{}.png'.format(args.out_filename, step),
            #     nrow=int(sqrt(sample.shape[0])),
            #     padding=0,
            #     normalize=True)

            if args.print_sample:
                # Per-site std over the batch of the two x_hat channels
                # (labelled alpha and beta in the log).
                x_hat_alpha = x_hat[:, 0, :, :].view(x_hat.shape[0],
                                                     -1).cpu().numpy()
                x_hat_std1 = np.std(x_hat_alpha, axis=0).reshape([args.L] * 2)
                x_hat_beta = x_hat[:, 1, :, :].view(x_hat.shape[0],
                                                    -1).cpu().numpy()
                x_hat_std2 = np.std(x_hat_beta, axis=0).reshape([args.L] * 2)

                # Histogram of distinct energy values: rows of (value, count).
                energy_np = energy.cpu().numpy()
                energy_count = np.stack(
                    np.unique(energy_np, return_counts=True)).T

                my_log(
                    '\nsample\n{}\nalpha\n{}\nbeta\n{}\nlog_prob\n{}\nenergy\n{}\nloss\n{}\nalpha_std\n{}\nbeta_std\n{}\nenergy_count\n{}\n'
                    .format(
                        sample[:args.print_sample, 0],
                        x_hat[:args.print_sample, 0],
                        x_hat[:args.print_sample, 1],
                        log_prob[:args.print_sample],
                        energy[:args.print_sample],
                        loss[:args.print_sample],
                        x_hat_std1,
                        x_hat_std2,
                        energy_count,
                    ))

            if args.print_grad:
                my_log('grad max_abs min_abs mean std')
                for name, param in named_params:
                    if param.grad is not None:
                        grad = param.grad
                        grad_abs = torch.abs(grad)
                        my_log('{} {:.3g} {:.3g} {:.3g} {:.3g}'.format(
                            name,
                            torch.max(grad_abs).item(),
                            torch.min(grad_abs).item(),
                            torch.mean(grad).item(),
                            torch.std(grad).item(),
                        ))
                    else:
                        my_log('{} None'.format(name))
                my_log('')
def main():
    """Train the RNN on measurement data, early-stopping on classical fidelity.

    Settings come from the module-level ``args``; training data is loaded with
    ``prepare_data(args.N, './data.txt')``.  Each epoch walks the rows in
    order, in ``int(args.Ns / args.batch_size)`` consecutive mini-batches
    (a trailing partial batch is dropped).  It minimizes the mean of
    ``-model.log_prob``.  After each epoch ``classical_fidelity`` against a
    GHZ state is evaluated.  A checkpoint is written only when the fidelity
    improves and ``epoch_idx`` is a multiple of ``args.save_epoch``.
    Training stops after 5 consecutive epochs without improvement.

    Raises:
        ValueError: if ``args.Ns < args.batch_size``.  An epoch would then
            contain no mini-batch, and the per-epoch report would reference
            an unassigned ``loss``.
    """
    start_time = time.time()
    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    n_batches = int(args.Ns / args.batch_size)
    if n_batches < 1:
        raise ValueError(
            'args.Ns ({}) must be at least args.batch_size ({})'.format(
                args.Ns, args.batch_size))

    model = RNN(args.device,
                Number_qubits=args.N,
                charset_length=args.charset_length,
                hidden_size=args.hidden_size,
                num_layers=args.num_layers)
    data = prepare_data(args.N, './data.txt')
    ghz = GHZ(Number_qubits=args.N)

    model.train(True)
    my_log('Total nparams: {}'.format(utils.get_nparams(model)))
    model.to(args.device)

    params = [x for x in model.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, model, optimizer)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    best_fid = 0
    # Consecutive epochs whose fidelity did not beat best_fid.
    trigger = 0
    max_stale_epochs = 4  # stop once trigger exceeds this
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx in range(n_batches):
            optimizer.zero_grad()
            # Consecutive slice of the data (no shuffling).  Random-index
            # alternative:
            # idx = np.random.randint(low=0,high=int(args.Ns-1),size=(args.batch_size,))
            idx = np.arange(args.batch_size) + batch_idx * args.batch_size
            train_data = data[idx]
            loss = -model.log_prob(
                torch.from_numpy(train_data).to(args.device)).mean()
            loss.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

        # `loss` is the last mini-batch's loss of this epoch.
        print('epoch_idx {} current loss {:.8g}'.format(
            epoch_idx, loss.item()))
        print('Evaluating...')
        current_fid = classical_fidelity(model, ghz, print_prob=False)
        if current_fid > best_fid:
            trigger = 0  # reset
            my_log('epoch_idx {} loss {:.8g} fid {} time {:.3f}'.format(
                epoch_idx, loss.item(), current_fid,
                time.time() - start_time))
            best_fid = current_fid
            if (args.out_filename and args.save_epoch
                    and epoch_idx % args.save_epoch == 0):
                state = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(
                    state,
                    '{}_save/{}.state'.format(args.out_filename, epoch_idx))
        else:
            trigger += 1
            if trigger > max_stale_epochs:
                break
def main():
    """Train a MERA normalizing flow by maximum likelihood on an image dataset.

    Settings come from the module-level ``args``.  The flow from
    ``build_mera()`` is optionally wrapped for multi-GPU use and resumed from
    the latest checkpoint.  It is then trained on the training split from
    ``utils.load_dataset()``.  Inputs go through ``utils.logit_transform``,
    whose log-det-Jacobian is added to the flow's log-likelihood, and the
    loss is normalized per dimension (``args.nchannels * args.L**2``).

    Every ``args.save_epoch`` epochs a checkpoint is written.  The previously
    saved checkpoint is then deleted unless its epoch is a multiple of
    ``args.keep_epoch``.  Every ``args.plot_epoch`` epochs ``do_plot`` is
    called.
    """
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    flow = build_mera()
    flow.train(True)
    my_log('nparams in each RG layer: {}'.format(
        [utils.get_nparams(layer) for layer in flow.layers]))
    my_log('Total nparams: {}'.format(utils.get_nparams(flow)))

    # Use multiple GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        flow = utils.data_parallel_wrap(flow)

    params = [x for x in flow.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, flow, optimizer)

    train_split, _, _ = utils.load_dataset()
    train_loader = torch.utils.data.DataLoader(train_split,
                                               args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx, (x, _) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(args.device)
            x, ldj_logit = utils.logit_transform(x)
            log_prob = flow.log_prob(x)
            # Negative log-likelihood per dimension.
            loss = -(log_prob + ldj_logit) / (args.nchannels * args.L**2)
            loss_mean = loss.mean()
            loss_std = loss.std()
            utils.check_nan(loss_mean)

            loss_mean.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

            if args.print_step and batch_idx % args.print_step == 0:
                # Shift by log(256) and convert nats to bits.
                bit_per_dim = (loss_mean.item() + log(256)) / log(2)
                my_log(
                    'epoch {} batch {} bpp {:.8g} loss {:.8g} +- {:.8g} time {:.3f}'
                    .format(
                        epoch_idx,
                        batch_idx,
                        bit_per_dim,
                        loss_mean.item(),
                        loss_std.item(),
                        time.time() - start_time,
                    ))

        if (args.out_filename and args.save_epoch
                and epoch_idx % args.save_epoch == 0):
            state = {
                'flow': flow.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state, '{}_save/{}.state'.format(
                args.out_filename, epoch_idx))

            # Checkpoints are only written every save_epoch epochs, so the
            # previous one is epoch_idx - save_epoch, not epoch_idx - 1.
            # Prune it unless it falls on a keep_epoch boundary.
            prev_epoch = epoch_idx - args.save_epoch
            if prev_epoch >= 0 and prev_epoch % args.keep_epoch != 0:
                try:
                    os.remove('{}_save/{}.state'.format(
                        args.out_filename, prev_epoch))
                except FileNotFoundError:
                    # Best-effort pruning, e.g. after resuming a run that
                    # was saved with a different save_epoch.
                    pass

        if (args.plot_filename and args.plot_epoch
                and epoch_idx % args.plot_epoch == 0):
            with torch.no_grad():
                do_plot(flow, epoch_idx)