Example #1
def save_checkpoints(tag, step, w1, w2, params, history, optimizer_state):
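    # Stack the start endpoint, the bend parameters, and the end endpoint of
    # the curve into a single array.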
    bends = jnp.vstack([w1, params, w2])
    expmgr.save_array(f'bends_{tag}.npy', bends)
    qnnops.save_checkpoint(f'checkpoint_{tag}.pkl', step, optimizer_state,
                           history)
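
# Illustrative call (the concrete argument values below are hypothetical):
# save_checkpoints('best', step, w1, w2, bend_params, history, optimizer_state)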
Example #2
    grads_all_mean, grads_all_var = jnp.mean(grads).item(), jnp.var(
        grads).item()
    grads_single_mean, grads_single_var = jnp.mean(
        grads[:, 0]).item(), jnp.var(grads[:, 0]).item()
    grads_norm_mean, grads_norm_var = jnp.mean(grad_norms).item(), jnp.var(
        grad_norms).item()

    logging_output = OrderedDict(grad_component_all_mean=grads_all_mean,
                                 grad_component_all_var=grads_all_var,
                                 grad_component_single_mean=grads_single_mean,
                                 grad_component_single_var=grads_single_var,
                                 grad_norm_mean=grads_norm_mean,
                                 grad_norm_var=grads_norm_var)

    expmgr.log(step=n_layers, logging_output=logging_output)

    wandb.log(dict(
        grad_component_all=wandb.Histogram(
            np_histogram=jnp.histogram(grads, bins=64, density=True)),
        grad_component_single=wandb.Histogram(
            np_histogram=jnp.histogram(grads[:, 0], bins=64, density=True)),
        grad_norm=wandb.Histogram(
            np_histogram=jnp.histogram(grad_norms, bins=64, density=True))),
              step=n_layers)

    suffix = f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size}_g{g}h{h}'
    expmgr.save_array(f'params_{suffix}.npy', params)
    expmgr.save_array(f'grads_{suffix}.npy', grads)

    del params, grads
    gc.collect()
Example #3
parser.add_argument('--no-time-tag', dest='time_tag', action='store_false',
                    help='Omit the time tag from experiment name.')
parser.add_argument('--use-jacfwd', dest='use_jacfwd', action='store_true',
                    help='Enable the forward mode gradient computation (jacfwd).')
parser.add_argument('--version', type=int, default=1, choices=[1, 2],
                    help='qnnops version (Default: 1)')

args = parser.parse_args()
seed = args.seed
n_qubits, n_layers, rot_axis = args.n_qubits, args.n_layers, args.rot_axis
block_size = args.n_qubits
exp_name = args.exp_name or f'Q{n_qubits}L{n_layers}R{rot_axis}BS{block_size} - S{seed} - LR{args.lr}'
expmgr.init(project='expressibility', name=exp_name, config=args)

target_state = qnnops.create_target_states(n_qubits, 1, seed=seed)
expmgr.log_array(target_state=target_state)
expmgr.save_array('target_state.npy', target_state)


def circuit(params):
    return qnnops.alternating_layer_ansatz(
        params, n_qubits=n_qubits, block_size=block_size, n_layers=n_layers, rot_axis=rot_axis)


def loss_fn(params):
    ansatz_state = circuit(params)
    return qnnops.state_norm(ansatz_state - target_state) / (2 ** n_qubits)


rng = jax.random.PRNGKey(seed)
_, init_params = qnnops.initialize_circuit_params(rng, n_qubits, n_layers)
trained_params, _ = qnnops.train_loop(
Example #4
def main():
    parser = argparse.ArgumentParser('Mode Connectivity')
    parser.add_argument('--n-qubits',
                        type=int,
                        metavar='N',
                        required=True,
                        help='Number of qubits')
    parser.add_argument('--n-layers',
                        type=int,
                        metavar='N',
                        required=True,
                        help='Number of alternating layers')
    parser.add_argument('--rot-axis',
                        type=str,
                        metavar='R',
                        required=True,
                        choices=['x', 'y', 'z'],
                        help='Direction of rotation gates.')
    parser.add_argument('--g',
                        type=float,
                        metavar='M',
                        required=True,
                        help='Transverse magnetic field')
    parser.add_argument('--h',
                        type=float,
                        metavar='M',
                        required=True,
                        help='Longitudinal magnetic field')
    parser.add_argument('--n-bends',
                        type=int,
                        metavar='N',
                        default=3,
                        help='Number of bends between endpoints')
    parser.add_argument('--train-steps',
                        type=int,
                        metavar='N',
                        default=int(1e3),
                        help='Number of training steps. (Default: 1000)')
    parser.add_argument('--batch-size',
                        type=int,
                        metavar='N',
                        default=8,
                        help='Batch size. (Default: 8)')
    parser.add_argument('--lr',
                        type=float,
                        metavar='LR',
                        default=0.05,
                        help='Initial value of learning rate. (Default: 0.05)')
    parser.add_argument('--log-every',
                        type=int,
                        metavar='N',
                        default=1,
                        help='Logging every N steps. (Default: 1)')
    parser.add_argument(
        '--seed',
        type=int,
        metavar='N',
        required=True,
        help='Random seed. For reproducibility, the value is set explicitly.')
    parser.add_argument('--model-seeds',
                        type=int,
                        metavar='N',
                        nargs=2,
                        default=None,
                        help='Random seeds of the two endpoint models.')
    parser.add_argument('--exp-name',
                        type=str,
                        metavar='NAME',
                        default=None,
                        help='Experiment name.')
    parser.add_argument(
        '--scheduler-name',
        type=str,
        metavar='NAME',
        default='exponential_decay',
        help=f'Scheduler name. Supports: {qnnops.supported_schedulers()} '
        f'(Default: exponential_decay)')
    parser.add_argument(
        '--params-start',
        type=str,
        metavar='PATH',
        help='A file path of a checkpoint where the curve starts from')
    parser.add_argument(
        '--params-end',
        type=str,
        metavar='PATH',
        help='A file path of a checkpoint where the curve ends')
    parser.add_argument('--no-jit',
                        dest='use_jit',
                        action='store_false',
                        help='Disable JIT compilation of the loss function.')
    parser.add_argument('--no-time-tag',
                        dest='time_tag',
                        action='store_false',
                        help='Omit the time tag from experiment name.')
    parser.add_argument('--quiet',
                        action='store_true',
                        help='Quiet mode (no training logs).')
    args = parser.parse_args()

    n_qubits, n_layers = args.n_qubits, args.n_layers
    if args.exp_name is None:
        args.exp_name = f'Q{n_qubits}L{n_layers}_nB{args.n_bends}'

    if args.model_seeds is None:
        params_start, params_end = args.params_start, args.params_end
    else:
        params_start, params_end = download_checkpoints(
            'checkpoints',
            n_qubits=n_qubits,
            n_layers=n_layers,
            lr=0.05,  # fixed lr that we used in training
            seed={'$in': args.model_seeds})

    print('Initializing project')
    expmgr.init('ModeConnectivity', args.exp_name, args)

    print('Loading pretrained models')
    w1 = jnp.load(params_start)
    w2 = jnp.load(params_end)
    expmgr.save_array('endpoint_begin.npy', w1)
    expmgr.save_array('endpoint_end.npy', w2)

    print('Constructing Hamiltonian matrix')
    ham_matrix = qnnops.ising_hamiltonian(n_qubits=n_qubits,
                                          g=args.g,
                                          h=args.h)

    print('Defining the loss function')

    def loss_fn(params):
        ansatz_state = qnnops.alternating_layer_ansatz(params, n_qubits,
                                                       n_qubits, n_layers,
                                                       args.rot_axis)
        return qnnops.energy(ham_matrix, ansatz_state)

    find_connected_curve(w1,
                         w2,
                         loss_fn,
                         n_bends=args.n_bends,
                         train_steps=args.train_steps,
                         lr=args.lr,
                         scheduler_name=args.scheduler_name,
                         batch_size=args.batch_size,
                         log_every=args.log_every,
                         seed=args.seed,
                         use_jit=args.use_jit)
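

if __name__ == '__main__':
    # Standard script entry point (assumed).
    main()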
Example #5
def train_loop(loss_fn,
               init_params,
               train_steps=int(1e4),
               lr=0.01,
               optimizer_name='adam',
               optimizer_args=None,
               scheduler_name='constant',
               loss_args=None,
               early_stopping=False,
               monitor=None,
               log_every=1,
               checkpoint_path=None,
               use_jit=True,
               use_jacfwd=False):
    """ Training loop.

    Args:
        loss_fn: callable, loss function whose first argument must be params.
        init_params: jnp.array, initial trainable parameter values
        train_steps: int, total number of training steps
        lr: float, initial learning rate
        optimizer_name: str, optimizer name to be used.
        optimizer_args: dict, custom arguments for the optimizer.
            If None, default arguments will be used.
        scheduler_name: str, scheduler name.
        loss_args: dict, additional loss arguments if needed.
        early_stopping: bool, whether to stop early when the training loss
            stops decreasing. (Not implemented yet)
        monitor: callable -> dict, monitoring function on training.
        log_every: int, logging every N steps.
        checkpoint_path: str, a checkpoint file path to resume.
        use_jit: bool, whether to use jit compilation.
        use_jacfwd: bool, use forward-mode jax.jacfwd for the gradient
            computation instead of reverse-mode jax.value_and_grad (jax.jacrev).
            For backward compatibility this option is disabled by default,
            but it may become the default later. (Default: False)
    Returns:
        params: jnp.array, optimized parameters
        history: dict, training history.
    """

    assert monitor is None or callable(
        monitor), 'the monitoring function must be callable.'

    loss_args = loss_args or {}
    train_steps = int(train_steps)  # to guarantee an integer type value.
    scheduler = get_scheduler(lr, train_steps, scheduler_name)
    init_fun, update_fun, get_params = get_optimizer(optimizer_name,
                                                     optimizer_args, scheduler)
    if checkpoint_path:
        start_step, optimizer_state, history = load_checkpoint(checkpoint_path)
        min_loss = jnp.hstack(history['loss']).min()
    else:
        start_step = 0
        optimizer_state = init_fun(init_params)
        history = {'loss': [], 'grad': [], 'params': []}
        min_loss = float('inf')

    try:
        if use_jacfwd:
            grad_loss_fn = jax.jacfwd(loss_fn)
        else:
            grad_loss_fn = jax.value_and_grad(loss_fn)
        if use_jit:
            grad_loss_fn = jax.jit(grad_loss_fn)
        for step in range(start_step, train_steps):
            params = get_params(optimizer_state)
            if use_jacfwd:
                loss = loss_fn(params, **loss_args)
                grad = grad_loss_fn(params, **loss_args)
            else:
                # For backward compatibility; this branch may later switch to
                # jax.grad for consistency with the jax.jacfwd path.
                loss, grad = grad_loss_fn(params, **loss_args)
            optimizer_state = update_fun(step, grad, optimizer_state)
            updated_params = get_params(optimizer_state)

            grad = onp.array(grad)
            params = onp.array(params)
            updated_params = onp.array(updated_params)
            history['loss'].append(loss)
            history['grad'].append(grad)
            history['params'].append(params)
            if loss < min_loss:
                min_loss = loss
                expmgr.save_array('params_best.npy', updated_params)
                save_checkpoint('checkpoint_best.pkl', step, optimizer_state,
                                history)

            if step % log_every == 0:
                grad_norm = jnp.linalg.norm(grad).item()
                logging_output = OrderedDict(loss=loss.item(),
                                             lr=scheduler(step),
                                             grad_norm=grad_norm)
                if monitor is not None:
                    logging_output.update(monitor(params=params))
                logging_output['min_loss'] = float(min_loss)
                expmgr.log(step, logging_output)
                expmgr.save_array('params_last.npy', updated_params)
                save_checkpoint('checkpoint_last.pkl', step, optimizer_state,
                                history)

            if early_stopping:
                # TODO(jdk): implement early stopping feature.
                pass
            del loss, grad
            gc.collect()

    except Exception as e:
        print(e)
        print('Saving history object...')
        expmgr.save_history('history.npz', history)
        raise e
    else:
        expmgr.save_history('history.npz', history)
    return get_params(optimizer_state), history
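

# Minimal usage sketch (illustrative only): the loss mirrors the
# alternating-layer-ansatz/energy pattern used in the other examples, and the
# `qnnops` helper names are assumed to be importable here.
#
#     ham_matrix = qnnops.ising_hamiltonian(n_qubits=4, g=1.0, h=0.5)
#
#     def loss_fn(params):
#         state = qnnops.alternating_layer_ansatz(
#             params, n_qubits=4, block_size=4, n_layers=2, rot_axis='y')
#         return qnnops.energy(ham_matrix, state)
#
#     rng = jax.random.PRNGKey(0)
#     _, init_params = qnnops.initialize_circuit_params(rng, 4, 2)
#     final_params, history = train_loop(loss_fn, init_params,
#                                        train_steps=500, lr=0.05,
#                                        scheduler_name='exponential_decay')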
Example #6
parser.add_argument('--version', type=int, default=1, choices=[1, 2],
                    help='qnnops version (Default: 1)')
args = parser.parse_args()

seed, seed_SYK = args.seed, args.seed_SYK
n_qubits, max_n_layers, rot_axis = args.n_qubits, args.max_n_layers, args.rot_axis
block_size = args.n_qubits
sample_size = args.sample_size

if not args.exp_name:
    args.exp_name = f'SYK4 - Q{n_qubits}R{rot_axis}BS{block_size} - SYK{seed_SYK} - S{seed} - SN{sample_size}'
expmgr.init(project='SYK4BP', name=args.exp_name, config=args)

# Construct the Hamiltonian matrix of the SYK model.
ham_matrix = qnnops.SYK_hamiltonian(jax.random.PRNGKey(args.seed_SYK),
                                    n_qubits)
expmgr.save_array('hamiltonian_matrix.npy', ham_matrix, upload_to_wandb=True)

rng = jax.random.PRNGKey(seed)  # Set of random seeds for parameter sampling

M = int(log2(max_n_layers))
for i in range(1, M + 1):
    n_layers = 2**i
    print(f'{n_qubits} Qubits & {n_layers} Layers ({i}/{M})')

    def loss(_params):
        ansatz_state = qnnops.alternating_layer_ansatz(_params, n_qubits,
                                                       block_size, n_layers,
                                                       rot_axis)
        return qnnops.energy(ham_matrix, ansatz_state)

    grad_fn = jax.grad(loss)
Example #7
        temp = jnp.kron(temp, qnnops.PauliBasis[0])
        
    gamma_matrices.append(temp)

# Number of SYK4 interaction terms: the binomial coefficient C(n_gamma, 4).
n_terms = int(factorial(n_gamma) / factorial(4) / factorial(n_gamma - 4))

# SYK4 random coupling
couplings = jax.random.normal(
    key=jax.random.PRNGKey(args.seed_SYK), shape=(n_terms, ),
    dtype=jnp.float64) * jnp.sqrt(6 / (n_gamma ** 3))

ham_matrix = 0
for idx, (x, y, w, z) in enumerate(combinations(range(n_gamma), 4)):
    ham_matrix += (couplings[idx] / 4) * jnp.linalg.multi_dot([
        gamma_matrices[x], gamma_matrices[y], gamma_matrices[w],
        gamma_matrices[z]])

expmgr.save_array('hamiltonian_matrix.npy', ham_matrix, upload_to_wandb=False)

eigval, eigvec = jnp.linalg.eigh(ham_matrix)
eigvec = eigvec.T  # Transpose so that eigvec[i] is an eigenvector rather than eigvec[:, i].
ground_state = eigvec[0]
next_to_ground_state = eigvec[1]

print("The lowest eigenvalues (energy) and corresponding eigenvectors (state)")
for i in range(min(5, len(eigval))):
    print(f'| {i}-th state energy={eigval[i]:.4f}')
    print(f'| {i}-th state vector={eigvec[i]}')
expmgr.log_array(
    eigenvalues=eigval,
    ground_state=ground_state,
    next_to_ground_state=next_to_ground_state,
)
# Set of random seeds for parameter sampling
rng = jax.random.PRNGKey(seed)

# Collect the norms of gradients
grads = []

for step in range(sample_size):
    rng, param_rng = jax.random.split(rng)
    _, init_params = qnnops.initialize_circuit_params(param_rng, n_qubits,
                                                      n_layers)
    grads.append(jax.grad(loss)(init_params))
    wandb.log({'step': step})

grads = jnp.vstack(grads)

grads_mean, grads_var, grads_norm = jnp.mean(grads, axis=0), jnp.var(
    grads, axis=0), jnp.linalg.norm(grads, axis=1)
expmgr.save_array('grads_mean.npy', grads_mean)
expmgr.save_array('grads_var.npy', grads_var)
expmgr.save_array('grads_norm.npy', grads_norm)

wandb.config.grads_mean = str(grads_mean)
wandb.config.grads_var = str(grads_var)
wandb.config.grads_norm = str(grads_norm)

wandb.log({
    'means_mean': jnp.mean(grads_mean).item(),
    'means_var': jnp.var(grads_mean).item(),
    'vars_mean': jnp.mean(grads_var).item(),
    'vars_var': jnp.var(grads_var).item(),
    'norms_mean': jnp.mean(grads_norm).item(),
    'norms_var': jnp.var(grads_norm).item()
})