def loss(_params):
    """Variational energy of the alternating-layer ansatz at ``_params``.

    Builds the ansatz state from the circuit parameters and returns its
    energy expectation value with respect to ``ham_matrix``.
    """
    ansatz_state = qnnops.alternating_layer_ansatz(
        _params, n_qubits, block_size, n_layers, rot_axis)
    return qnnops.energy(ham_matrix, ansatz_state)


grad_fn = jax.grad(loss)
if args.use_jit:
    # BUG FIX: the original code bound the jitted function to a separate
    # name (`grad_loss_fn`) but the sampling loop below kept calling the
    # un-jitted `grad_fn`, so JIT compilation was silently never used.
    # Rebind `grad_fn` so the loop actually runs the compiled version;
    # `grad_loss_fn` is kept as an alias for backward compatibility.
    grad_loss_fn = grad_fn = jax.jit(grad_fn)

# Sample random parameter initializations and collect per-sample gradients
# (e.g. to study how gradient magnitudes scale — barren-plateau analysis).
params, grads = [], []
for step in tqdm(range(sample_size)):
    rng, param_rng = jax.random.split(rng)
    _, param = qnnops.initialize_circuit_params(param_rng, n_qubits, n_layers)
    grad = grad_fn(param)
    params.append(param)
    grads.append(grad)

# Stack into (sample_size, n_params) arrays and summarize.
params = jnp.vstack(params)
grads = jnp.vstack(grads)
grad_norms = jnp.linalg.norm(grads, axis=1)

# Scalar statistics: over all gradient entries, over a single fixed
# component (index 0), and over the per-sample gradient norms.
grads_all_mean, grads_all_var = jnp.mean(grads).item(), jnp.var(grads).item()
grads_single_mean, grads_single_var = (
    jnp.mean(grads[:, 0]).item(), jnp.var(grads[:, 0]).item())
grads_norm_mean, grads_norm_var = (
    jnp.mean(grad_norms).item(), jnp.var(grad_norms).item())
# Record the target state both in the experiment log and on disk.
expmgr.log_array(target_state=target_state)
expmgr.save_array('target_state.npy', target_state)


def circuit(params):
    """Evaluate the alternating-layer ansatz for the given parameters."""
    return qnnops.alternating_layer_ansatz(
        params,
        n_qubits=n_qubits,
        block_size=block_size,
        n_layers=n_layers,
        rot_axis=rot_axis,
    )


def loss_fn(params):
    """Normalized distance between the ansatz state and the target state."""
    delta = circuit(params) - target_state
    return qnnops.state_norm(delta) / (2 ** n_qubits)


# Random initial circuit parameters, then optimize to match the target.
rng = jax.random.PRNGKey(seed)
_, init_params = qnnops.initialize_circuit_params(rng, n_qubits, n_layers)

trained_params, _ = qnnops.train_loop(
    loss_fn,
    init_params,
    args.train_steps,
    args.lr,
    optimizer_name=args.optimizer_name,
    optimizer_args=args.optimizer_args,
    scheduler_name=args.scheduler_name,
    checkpoint_path=args.checkpoint_path,
    use_jit=args.use_jit,
    use_jacfwd=args.use_jacfwd,
)

# Persist the optimized state and the run configuration.
optimized_state = circuit(trained_params)
expmgr.log_array(optimized_state=optimized_state)
expmgr.save_array('optimized_state.npy', optimized_state)
expmgr.save_config(args)