def save_checkpoints(tag, step, w1, w2, params, history, optimizer_state):
    # Stack the fixed endpoints around the trained bend parameters and persist the
    # optimizer state so the curve search can be resumed.
    bends = jnp.vstack([w1, params, w2])
    expmgr.save_array(f'bends_{tag}.npy', bends)
    qnnops.save_checkpoint(f'checkpoint_{tag}.pkl', step, optimizer_state, history)
grads_all_mean, grads_all_var = jnp.mean(grads).item(), jnp.var(grads).item()
grads_single_mean, grads_single_var = jnp.mean(grads[:, 0]).item(), jnp.var(grads[:, 0]).item()
grads_norm_mean, grads_norm_var = jnp.mean(grad_norms).item(), jnp.var(grad_norms).item()

logging_output = OrderedDict(
    grad_component_all_mean=grads_all_mean,
    grad_component_all_var=grads_all_var,
    grad_component_single_mean=grads_single_mean,
    grad_component_single_var=grads_single_var,
    grad_norm_mean=grads_norm_mean,
    grad_norm_var=grads_norm_var)
expmgr.log(step=n_layers, logging_output=logging_output)
wandb.log(
    dict(grad_component_all=wandb.Histogram(
             np_histogram=jnp.histogram(grads, bins=64, density=True)),
         grad_component_single=wandb.Histogram(
             np_histogram=jnp.histogram(grads[:, 0], bins=64, density=True)),
         grad_norm=wandb.Histogram(
             np_histogram=jnp.histogram(grad_norms, bins=64, density=True))),
    step=n_layers)

suffix = f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size}_g{g}h{h}'
expmgr.save_array(f'params_{suffix}.npy', params)
expmgr.save_array(f'grads_{suffix}.npy', grads)

del params, grads
gc.collect()
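# Downstream illustration (not part of this script): the logged variances are what a
# barren-plateau analysis consumes. Below is a minimal sketch of fitting an exponential
# decay of the single-component gradient variance against the qubit count; the helper
# name and the `variance_by_qubits` data layout are hypothetical placeholders.
def _fit_variance_decay(variance_by_qubits):
    """Fit log(Var[dL/dtheta]) = a * n_qubits + b; a barren plateau shows up as a < 0."""
    import numpy as np
    n_qubits_arr = np.array(sorted(variance_by_qubits))
    log_var = np.log([variance_by_qubits[n] for n in n_qubits_arr])
    slope, intercept = np.polyfit(n_qubits_arr, log_var, deg=1)
    return slope, intercept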
                    help='Omit the time tag from experiment name.')
parser.add_argument('--use-jacfwd', dest='use_jacfwd', action='store_true',
                    help='Enable forward-mode gradient computation (jacfwd).')
parser.add_argument('--version', type=int, default=1, choices=[1, 2],
                    help='qnnops version (Default: 1)')
args = parser.parse_args()

seed = args.seed
n_qubits, n_layers, rot_axis = args.n_qubits, args.n_layers, args.rot_axis
block_size = args.n_qubits
exp_name = args.exp_name or (
    f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size} - S{seed} - LR{args.lr}')

expmgr.init(project='expressibility', name=exp_name, config=args)

target_state = qnnops.create_target_states(n_qubits, 1, seed=seed)
expmgr.log_array(target_state=target_state)
expmgr.save_array('target_state.npy', target_state)

def circuit(params):
    return qnnops.alternating_layer_ansatz(
        params, n_qubits=n_qubits, block_size=block_size, n_layers=n_layers,
        rot_axis=rot_axis)

def loss_fn(params):
    ansatz_state = circuit(params)
    return qnnops.state_norm(ansatz_state - target_state) / (2 ** n_qubits)

rng = jax.random.PRNGKey(seed)
_, init_params = qnnops.initialize_circuit_params(rng, n_qubits, n_layers)
trained_params, _ = qnnops.train_loop(
def main():
    parser = argparse.ArgumentParser('Mode Connectivity')
    parser.add_argument('--n-qubits', type=int, metavar='N', required=True,
                        help='Number of qubits')
    parser.add_argument('--n-layers', type=int, metavar='N', required=True,
                        help='Number of alternating layers')
    parser.add_argument('--rot-axis', type=str, metavar='R', required=True,
                        choices=['x', 'y', 'z'],
                        help='Direction of rotation gates.')
    parser.add_argument('--g', type=float, metavar='M', required=True,
                        help='Transverse magnetic field')
    parser.add_argument('--h', type=float, metavar='M', required=True,
                        help='Longitudinal magnetic field')
    parser.add_argument('--n-bends', type=int, metavar='N', default=3,
                        help='Number of bends between endpoints')
    parser.add_argument('--train-steps', type=int, metavar='N', default=int(1e3),
                        help='Number of training steps. (Default: 1000)')
    parser.add_argument('--batch-size', type=int, metavar='N', default=8,
                        help='Batch size. (Default: 8)')
    parser.add_argument('--lr', type=float, metavar='LR', default=0.05,
                        help='Initial value of the learning rate. (Default: 0.05)')
    parser.add_argument('--log-every', type=int, metavar='N', default=1,
                        help='Log every N steps. (Default: 1)')
    parser.add_argument(
        '--seed', type=int, metavar='N', required=True,
        help='Random seed. The value must be set explicitly for reproducibility.')
    parser.add_argument(
        '--model-seeds', type=int, metavar='N', nargs=2, default=None,
        help='Random seeds used to train the two endpoint models.')
    parser.add_argument('--exp-name', type=str, metavar='NAME', default=None,
                        help='Experiment name.')
    parser.add_argument(
        '--scheduler-name', type=str, metavar='NAME', default='exponential_decay',
        help=f'Scheduler name. Supports: {qnnops.supported_schedulers()} '
             f'(Default: exponential_decay)')
    parser.add_argument(
        '--params-start', type=str, metavar='PATH',
        help='A file path of a checkpoint where the curve starts from')
    parser.add_argument(
        '--params-end', type=str, metavar='PATH',
        help='A file path of a checkpoint where the curve ends')
    parser.add_argument('--no-jit', dest='use_jit', action='store_false',
                        help='Disable the jit option for the loss function.')
    parser.add_argument('--no-time-tag', dest='time_tag', action='store_false',
                        help='Omit the time tag from the experiment name.')
    parser.add_argument('--quiet', action='store_true',
                        help='Quiet mode (no training logs)')
    args = parser.parse_args()

    n_qubits, n_layers = args.n_qubits, args.n_layers
    if args.exp_name is None:
        args.exp_name = f'Q{n_qubits}L{n_layers}_nB{args.n_bends}'

    if args.model_seeds is None:
        params_start, params_end = args.params_start, args.params_end
    else:
        params_start, params_end = download_checkpoints(
            'checkpoints',
            n_qubits=n_qubits,
            n_layers=n_layers,
            lr=0.05,  # fixed lr that we used in training
            seed={'$in': args.model_seeds})

    print('Initializing project')
    expmgr.init('ModeConnectivity', args.exp_name, args)

    print('Loading pretrained models')
    w1 = jnp.load(params_start)
    w2 = jnp.load(params_end)
    expmgr.save_array('endpoint_begin.npy', w1)
    expmgr.save_array('endpoint_end.npy', w2)

    print('Constructing Hamiltonian matrix')
    ham_matrix = qnnops.ising_hamiltonian(n_qubits=n_qubits, g=args.g, h=args.h)

    print('Defining the loss function')

    def loss_fn(params):
        ansatz_state = qnnops.alternating_layer_ansatz(params, n_qubits, n_qubits,
                                                       n_layers, args.rot_axis)
        return qnnops.energy(ham_matrix, ansatz_state)

    find_connected_curve(w1, w2, loss_fn,
                         n_bends=args.n_bends,
                         train_steps=args.train_steps,
                         lr=args.lr,
                         scheduler_name=args.scheduler_name,
                         batch_size=args.batch_size,
                         log_every=args.log_every,
                         seed=args.seed,
                         use_jit=args.use_jit)
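# For reference only: the curve that find_connected_curve optimizes is, conceptually, a
# parametric path through weight space with fixed endpoints w1 and w2 and trainable
# interior bends. The sketch below shows one common formulation (a piecewise-linear
# chain whose loss is averaged over sampled points, in the spirit of Garipov et al.,
# 2018); it is not taken from the qnnops implementation, and the names
# piecewise_linear_point and curve_loss are hypothetical. Assumes jax and
# jax.numpy (as jnp) are imported.
def piecewise_linear_point(t, w1, bends, w2):
    """Point on the chain w1 -> bends[0] -> ... -> bends[-1] -> w2 at t in [0, 1]."""
    knots = jnp.vstack([w1[None], bends, w2[None]])   # (n_bends + 2, n_params)
    n_segments = knots.shape[0] - 1
    s = t * n_segments
    idx = jnp.clip(jnp.floor(s).astype(int), 0, n_segments - 1)
    frac = s - idx
    return (1.0 - frac) * knots[idx] + frac * knots[idx + 1]


def curve_loss(bends, w1, w2, loss_fn, ts):
    """Average loss over sampled curve points; only the bends are trainable.

    Assumes loss_fn can be vmapped over a batch of parameter vectors; otherwise
    evaluate the sampled points in a Python loop.
    """
    points = jax.vmap(lambda t: piecewise_linear_point(t, w1, bends, w2))(ts)
    return jnp.mean(jax.vmap(loss_fn)(points))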
def train_loop(loss_fn, init_params, train_steps=int(1e4), lr=0.01,
               optimizer_name='adam', optimizer_args=None,
               scheduler_name='constant', loss_args=None, early_stopping=False,
               monitor=None, log_every=1, checkpoint_path=None, use_jit=True,
               use_jacfwd=False):
    """Training loop.

    Args:
        loss_fn: callable, loss function whose first argument must be params.
        init_params: jnp.array, initial trainable parameter values.
        train_steps: int, total number of training steps.
        lr: float, initial learning rate.
        optimizer_name: str, name of the optimizer to be used.
        optimizer_args: dict, custom arguments for the optimizer. If None,
            default arguments will be used.
        scheduler_name: str, scheduler name.
        loss_args: dict, additional loss arguments if needed.
        early_stopping: bool, whether to stop early if the train loss value
            does not decrease further. (Not implemented yet)
        monitor: callable -> dict, monitoring function called during training.
        log_every: int, log every N steps.
        checkpoint_path: str, a checkpoint file path to resume from.
        use_jit: bool, whether to use jit compilation.
        use_jacfwd: bool, use the forward-mode jax.jacfwd for the gradient
            computation instead of the reverse-mode jax.grad (jax.jacrev).
            For backward compatibility, this option is disabled by default;
            it may become the default later. (Default: False)

    Returns:
        params: jnp.array, optimized parameters.
        history: dict, training history.
    """
    assert monitor is None or callable(
        monitor), 'the monitoring function must be callable.'
    loss_args = loss_args or {}
    train_steps = int(train_steps)  # to guarantee an integer type value.

    scheduler = get_scheduler(lr, train_steps, scheduler_name)
    init_fun, update_fun, get_params = get_optimizer(optimizer_name,
                                                     optimizer_args, scheduler)

    if checkpoint_path:
        start_step, optimizer_state, history = load_checkpoint(checkpoint_path)
        min_loss = jnp.hstack(history['loss']).min()
    else:
        start_step = 0
        optimizer_state = init_fun(init_params)
        history = {'loss': [], 'grad': [], 'params': []}
        min_loss = float('inf')

    try:
        if use_jacfwd:
            grad_loss_fn = jax.jacfwd(loss_fn)
        else:
            grad_loss_fn = jax.value_and_grad(loss_fn)
        if use_jit:
            grad_loss_fn = jax.jit(grad_loss_fn)

        for step in range(start_step, train_steps):
            params = get_params(optimizer_state)
            if use_jacfwd:
                loss = loss_fn(params, **loss_args)
                grad = grad_loss_fn(params, **loss_args)
            else:
                # For backward compatibility; it will be replaced by jax.grad
                # for consistency with the jax.jacfwd branch.
                loss, grad = grad_loss_fn(params, **loss_args)

            optimizer_state = update_fun(step, grad, optimizer_state)
            updated_params = get_params(optimizer_state)

            grad = onp.array(grad)
            params = onp.array(params)
            updated_params = onp.array(updated_params)
            history['loss'].append(loss)
            history['grad'].append(grad)
            history['params'].append(params)

            if loss < min_loss:
                min_loss = loss
                expmgr.save_array('params_best.npy', updated_params)
                save_checkpoint('checkpoint_best.pkl', step, optimizer_state,
                                history)

            if step % log_every == 0:
                grad_norm = jnp.linalg.norm(grad).item()
                logging_output = OrderedDict(loss=loss.item(),
                                             lr=scheduler(step),
                                             grad_norm=grad_norm)
                if monitor is not None:
                    logging_output.update(monitor(params=params))
                logging_output['min_loss'] = float(min_loss)
                expmgr.log(step, logging_output)

            expmgr.save_array('params_last.npy', updated_params)
            save_checkpoint('checkpoint_last.pkl', step, optimizer_state, history)

            if early_stopping:
                # TODO(jdk): implement early stopping feature.
                pass

            del loss, grad
            gc.collect()

    except Exception as e:
        print(e)
        print('Saving history object...')
        expmgr.save_history('history.npz', history)
        raise e
    else:
        expmgr.save_history('history.npz', history)

    return get_params(optimizer_state), history
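# A minimal usage sketch for train_loop (illustrative only): a toy quadratic loss with a
# known minimum, assuming expmgr has already been initialized so the array and checkpoint
# saves inside the loop have somewhere to go. The helper name is hypothetical.
def _example_train_loop_usage():
    def toy_loss(params):
        return jnp.sum((params - 1.0) ** 2)  # minimized at params == jnp.ones(4)

    params, history = train_loop(toy_loss,
                                 init_params=jnp.zeros(4),
                                 train_steps=200,
                                 lr=0.1,
                                 optimizer_name='adam',
                                 scheduler_name='constant',
                                 log_every=50)
    return params, history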
                    help='qnnops version (Default: 1)')
args = parser.parse_args()

seed, seed_SYK = args.seed, args.seed_SYK
n_qubits, max_n_layers, rot_axis = args.n_qubits, args.max_n_layers, args.rot_axis
block_size = args.n_qubits
sample_size = args.sample_size
if not args.exp_name:
    args.exp_name = (f'SYK4 - Q{n_qubits}R{rot_axis}BS{block_size}'
                     f' - SYK{seed_SYK} - S{seed} - SN{sample_size}')

expmgr.init(project='SYK4BP', name=args.exp_name, config=args)

# Construct the Hamiltonian matrix of the SYK4 model.
ham_matrix = qnnops.SYK_hamiltonian(jax.random.PRNGKey(args.seed_SYK), n_qubits)
expmgr.save_array('hamiltonian_matrix.npy', ham_matrix, upload_to_wandb=True)

# Set of random seeds for parameter sampling.
rng = jax.random.PRNGKey(seed)

M = int(log2(max_n_layers))
for i in range(1, M + 1):
    n_layers = 2 ** i
    print(f'{n_qubits} Qubits & {n_layers} Layers ({i}/{M})')

    def loss(_params):
        ansatz_state = qnnops.alternating_layer_ansatz(_params, n_qubits,
                                                       block_size, n_layers,
                                                       rot_axis)
        return qnnops.energy(ham_matrix, ansatz_state)

    grad_fn = jax.grad(loss)
    temp = jnp.kron(temp, qnnops.PauliBasis[0])
    gamma_matrices.append(temp)

# Number of SYK4 interaction terms: C(n_gamma, 4).
n_terms = factorial(n_gamma) // (factorial(4) * factorial(n_gamma - 4))

# SYK4 random couplings.
couplings = jax.random.normal(key=jax.random.PRNGKey(args.seed_SYK),
                              shape=(n_terms,),
                              dtype=jnp.float64) * jnp.sqrt(6 / (n_gamma ** 3))

ham_matrix = 0
for idx, (x, y, w, z) in enumerate(combinations(range(n_gamma), 4)):
    ham_matrix += (couplings[idx] / 4) * jnp.linalg.multi_dot(
        [gamma_matrices[x], gamma_matrices[y], gamma_matrices[w], gamma_matrices[z]])
expmgr.save_array('hamiltonian_matrix.npy', ham_matrix, upload_to_wandb=False)

eigval, eigvec = jnp.linalg.eigh(ham_matrix)
# Transpose so that eigvec[i] is an eigenvector, rather than eigvec[:, i].
eigvec = eigvec.T
ground_state = eigvec[0]
next_to_ground_state = eigvec[1]

print('The lowest eigenvalues (energy) and corresponding eigenvectors (state)')
for i in range(min(5, len(eigval))):
    print(f'| {i}-th state energy={eigval[i]:.4f}')
    print(f'| {i}-th state vector={eigvec[i]}')

expmgr.log_array(
    eigenvalues=eigval,
    ground_state=ground_state,
    next_to_ground_state=next_to_ground_state,
)
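# Optional sanity check (illustrative, not in the original script): the Majorana
# operators built above should satisfy a Clifford algebra, {gamma_i, gamma_j}
# proportional to delta_ij * I. The proportionality constant (2 or 1) depends on the
# normalization convention used when constructing gamma_matrices, which is not fully
# shown here; adjust `scale` accordingly. The helper name is hypothetical.
def _check_clifford_algebra(gamma_matrices, scale=2.0, atol=1e-6):
    from itertools import combinations_with_replacement
    dim = gamma_matrices[0].shape[0]
    identity = jnp.eye(dim)
    for i, j in combinations_with_replacement(range(len(gamma_matrices)), 2):
        anticomm = (gamma_matrices[i] @ gamma_matrices[j]
                    + gamma_matrices[j] @ gamma_matrices[i])
        expected = scale * identity if i == j else jnp.zeros_like(identity)
        assert jnp.allclose(anticomm, expected, atol=atol), \
            f'Clifford check failed for ({i}, {j})'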
# Set of random seeds for parameter sampling.
rng = jax.random.PRNGKey(seed)

# Collect the gradients at randomly initialized parameters.
grads = []
for step in range(sample_size):
    rng, param_rng = jax.random.split(rng)
    _, init_params = qnnops.initialize_circuit_params(param_rng, n_qubits, n_layers)
    grads.append(jax.grad(loss)(init_params))
    wandb.log({'step': step})
grads = jnp.vstack(grads)

grads_mean = jnp.mean(grads, axis=0)
grads_var = jnp.var(grads, axis=0)
grads_norm = jnp.linalg.norm(grads, axis=1)

expmgr.save_array(expmgr.get_result_path('grads_mean.npy'), grads_mean)
expmgr.save_array(expmgr.get_result_path('grads_var.npy'), grads_var)
expmgr.save_array(expmgr.get_result_path('grads_norm.npy'), grads_norm)
wandb.config.grads_mean = str(grads_mean)
wandb.config.grads_var = str(grads_var)
wandb.config.grads_norm = str(grads_norm)
wandb.log({'means_mean': jnp.mean(grads_mean).item(),
           'means_var': jnp.var(grads_mean).item(),
           'vars_mean': jnp.mean(grads_var).item(),
           'vars_var': jnp.var(grads_var).item(),
           'norms_mean': jnp.mean(grads_norm).item(),
           'norms_var': jnp.var(grads_norm).item()})
# Collect the gradients at randomly initialized parameters.
grads = []
for step in range(sample_size):
    rng, param_rng = jax.random.split(rng)
    _, init_params = qnnops.initialize_circuit_params(param_rng, n_qubits, n_layers)
    grads.append(jax.grad(loss)(init_params))
    wandb.log({'step': step})
grads = jnp.vstack(grads)

grads_mean = jnp.mean(grads, axis=0)
grads_var = jnp.var(grads, axis=0)
grads_norm = jnp.linalg.norm(grads, axis=1)

expmgr.save_array('grads_mean.npy', grads_mean)
expmgr.save_array('grads_var.npy', grads_var)
expmgr.save_array('grads_norm.npy', grads_norm)
wandb.config.grads_mean = str(grads_mean)
wandb.config.grads_var = str(grads_var)
wandb.config.grads_norm = str(grads_norm)
wandb.log({'means_mean': jnp.mean(grads_mean).item(),
           'means_var': jnp.var(grads_mean).item(),
           'vars_mean': jnp.mean(grads_var).item(),
           'vars_var': jnp.var(grads_var).item(),
           'norms_mean': jnp.mean(grads_norm).item(),
           'norms_var': jnp.var(grads_norm).item()})
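# Possible alternative (illustrative, not how the script currently runs): if `loss` and
# the ansatz are traceable under vmap and memory allows batching, the per-sample Python
# loop above can be replaced by a single vectorized gradient evaluation. The helper name
# sample_grads is hypothetical; parameter initialization stays in Python in case
# qnnops.initialize_circuit_params is not vmap-traceable.
def sample_grads(rng, sample_size):
    param_rngs = jax.random.split(rng, sample_size)
    batched_params = jnp.stack([
        qnnops.initialize_circuit_params(r, n_qubits, n_layers)[1]
        for r in param_rngs])
    return jax.vmap(jax.grad(loss))(batched_params)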