예제 #1
0
    def loss(_params):
        """Return ``qnnops.energy`` of ``ham_matrix`` on the ansatz state for ``_params``.

        The ansatz is ``qnnops.alternating_layer_ansatz`` configured with the
        enclosing scope's ``n_qubits``, ``block_size``, ``n_layers`` and
        ``rot_axis``.
        """
        state = qnnops.alternating_layer_ansatz(
            _params, n_qubits, block_size, n_layers, rot_axis)
        energy = qnnops.energy(ham_matrix, state)
        return energy

    # Gradient of the loss w.r.t. the circuit parameters. When JIT is
    # requested, rebind `grad_fn` itself so the sampling loop below actually
    # runs the compiled function (storing it only under a different name
    # would leave the loop calling the un-jitted version).
    grad_fn = jax.grad(loss)
    if args.use_jit:
        grad_fn = jax.jit(grad_fn)
        # Alias kept so any later reference to `grad_loss_fn` still resolves.
        grad_loss_fn = grad_fn

    # Draw `sample_size` random parameter sets (fresh PRNG subkey each time)
    # and record the loss gradient at each one.
    params, grads = [], []
    for _ in tqdm(range(sample_size)):
        rng, param_rng = jax.random.split(rng)
        _, param = qnnops.initialize_circuit_params(param_rng, n_qubits,
                                                    n_layers)
        params.append(param)
        grads.append(grad_fn(param))

    # Stack the samples row-wise. Assumes each `param` / `grad` is a flat
    # 1-D vector so the result is (sample_size, n_params) — TODO confirm.
    params = jnp.vstack(params)
    grads = jnp.vstack(grads)
    grad_norms = jnp.linalg.norm(grads, axis=1)

    # Summary statistics as Python floats:
    #   *_all_*    over every gradient entry of every sample,
    #   *_single_* over the gradient of the first parameter (column 0) only,
    #   *_norm_*   over the per-sample gradient norms.
    grads_all_mean = jnp.mean(grads).item()
    grads_all_var = jnp.var(grads).item()
    grads_single_mean = jnp.mean(grads[:, 0]).item()
    grads_single_var = jnp.var(grads[:, 0]).item()
    grads_norm_mean = jnp.mean(grad_norms).item()
    grads_norm_var = jnp.var(grad_norms).item()
예제 #2
0
# Record the target state with the experiment manager: log it under the
# keyword `target_state` and save it as the array artifact 'target_state.npy'.
expmgr.log_array(target_state=target_state)
expmgr.save_array('target_state.npy', target_state)


def circuit(params):
    """Build the alternating-layer ansatz state for ``params``.

    Circuit shape (qubits, block size, layers, rotation axis) comes from the
    module-level configuration.
    """
    ansatz_config = dict(
        n_qubits=n_qubits,
        block_size=block_size,
        n_layers=n_layers,
        rot_axis=rot_axis,
    )
    return qnnops.alternating_layer_ansatz(params, **ansatz_config)


def loss_fn(params):
    """Distance between the ansatz state for ``params`` and ``target_state``.

    Computed as ``qnnops.state_norm`` of the difference, divided by
    ``2 ** n_qubits``.
    """
    residual = circuit(params) - target_state
    hilbert_dim = 2 ** n_qubits
    return qnnops.state_norm(residual) / hilbert_dim


# Initial circuit parameters drawn from a PRNG key seeded with `seed`.
rng = jax.random.PRNGKey(seed)
_, init_params = qnnops.initialize_circuit_params(rng, n_qubits, n_layers)

# Optimizer / scheduler / execution options taken from the parsed arguments.
train_options = dict(
    optimizer_name=args.optimizer_name,
    optimizer_args=args.optimizer_args,
    scheduler_name=args.scheduler_name,
    checkpoint_path=args.checkpoint_path,
    use_jit=args.use_jit,
    use_jacfwd=args.use_jacfwd,
)
# Only the first return value (the trained parameters) is kept.
trained_params, _ = qnnops.train_loop(
    loss_fn, init_params, args.train_steps, args.lr, **train_options)

# Evaluate the ansatz at the trained parameters, then log and save it.
optimized_state = circuit(trained_params)
expmgr.log_array(optimized_state=optimized_state)
expmgr.save_array('optimized_state.npy', optimized_state)

# Hand the run configuration (`args`) to the experiment manager.
expmgr.save_config(args)