Example no. 1
def JaxDeepConvNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Complex deep convolutional Neural Network Machine implemented in Jax.
        Conv1d, complexReLU, Conv1d, complexReLU, Conv1d, complexReLU,
        Conv1d, complexReLU, Dense, complexReLU, Dense

            Args:
                hilbert (netket.hilbert) : hilbert space
                hamiltonian (netket.hamiltonian) : hamiltonian
                alpha (int) : hidden layer density
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                ma (netket.machine) : machine
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
                machine_name (str) : name of the machine, see get_operator
                                                    """
    print('JaxDeepConvNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(FixSrLayer, InputForConvLayer, Conv1d(alpha, (3,)), ComplexReLu,
                                      Conv1d(alpha, (3,)), ComplexReLu, Conv1d(alpha, (3,)), ComplexReLu,
                                      Conv1d(alpha, (3,)), ComplexReLu, stax.Flatten,
                                      Dense(input_size * alpha), ComplexReLu, Dense(1), FormatLayer)
    ma = nk.machine.Jax(
        hilbert,
        (init_fun, apply_fun), dtype=complex
    )
    ma.init_random_parameters(seed=12, sigma=0.01)
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif (sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    machine_name = 'JaxDeepConvNN'
    return ma, op, sa, machine_name
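
A hedged usage sketch for the factory above (not part of the original module): it assumes a NetKet 2.1-style API, which is what nk.machine.Jax and the machine= sampler keyword imply, and a placeholder spin-1/2 Heisenberg chain as the problem; the driver name and run() signature may differ in other NetKet versions.

import netket as nk

# Hypothetical problem setup: spin-1/2 Heisenberg chain with 16 sites.
g = nk.graph.Hypercube(length=16, n_dim=1, pbc=True)
hi = nk.hilbert.Spin(s=0.5, graph=g)
ha = nk.operator.Heisenberg(hilbert=hi)

# Machine, optimizer, and sampler from the factory above.
ma, op, sa, machine_name = JaxDeepConvNN(hi, ha, alpha=2, optimizer='Adam', lr=0.01, sampler='Local')

# Variational Monte Carlo driver (NetKet 2.1-style, assumed).
gs = nk.Vmc(hamiltonian=ha, sampler=sa, optimizer=op, n_samples=1000)
gs.run(out=machine_name, n_iter=300)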
Example no. 2
def JaxTransformedFFNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Complex Feed Forward Neural Network (fully connected) Machine implemented in Jax. One hidden layer.

        The input data is transformed in the beginning by the transformation 10.1103/physrevb.46.3486
        Dense, ComplexReLU, Dense

            Args:
                hilbert (netket.hilbert) : hilbert space
                hamiltonian (netket.hamiltonian) : hamiltonian
                alpha (int) : hidden layer density
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                ma (netket.machine) : machine
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
                machine_name (str) : name of the machine, see get_operator
                                                """
    print('JaxTransformedFFNN is used')
    input_size = hilbert.size
    init_fun, apply_fun = stax.serial(FixSrLayer, TransformedLayer,
        Dense(input_size * alpha), ComplexReLu,
        Dense(1), FormatLayer)
    ma = nk.machine.Jax(
        hilbert,
        (init_fun, apply_fun), dtype=complex
    )
    ma.init_random_parameters(seed=12, sigma=0.01)
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif (sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    machine_name = 'JaxTransformedFFNN'
    return ma, op, sa, machine_name
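
Every factory in this file repeats the same optimizer and sampler dispatch; a small pair of helpers could remove that duplication. The names make_optimizer and make_sampler are hypothetical, not part of the original module, and their bodies are copied from the branches above.

def make_optimizer(ma, optimizer='Sgd', lr=0.1):
    # 'Sgd' and 'Adam' are handled explicitly; anything else falls back to AdaMax.
    if optimizer == 'Sgd':
        return Wrap(ma, SgdJax(lr))
    if optimizer == 'Adam':
        return Wrap(ma, AdamJax(lr))
    return Wrap(ma, AdaMaxJax(lr))

def make_sampler(ma, hamiltonian, sampler='Local'):
    # Unrecognized names fall back to MetropolisHamiltonian, as in the factories above.
    if sampler == 'Local':
        return nk.sampler.MetropolisLocal(machine=ma)
    if sampler == 'Exact':
        return nk.sampler.ExactSampler(machine=ma)
    if sampler == 'VBS':
        return my_sampler.getVBSSampler(machine=ma)
    if sampler == 'Inverse':
        return my_sampler.getInverseSampler(machine=ma)
    return nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)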
Example no. 3
def JaxUnaryRBM(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Complex unary Restricted Boltzmann Machine implemented in Jax.
        UnaryLayer, Dense, LogCosh, Sum

            Args:
                hilbert (netket.hilbert) : hilbert space
                hamiltonian (netket.hamiltonian) : hamiltonian
                alpha (int) : hidden layer density
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                ma (netket.machine) : machine
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
                machine_name (str) : name of the machine, see get_operator

                                                """
    print('JaxUnaryRBM is used')
    input_size = hilbert.size
    ma = nk.machine.Jax(
        hilbert,
        stax.serial(FixSrLayer, UnaryLayer, stax.Dense(alpha * input_size), LogCoshLayer, SumLayer),
        dtype=complex
    )
    ma.init_random_parameters(seed=12, sigma=0.01)
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif (sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    machine_name = 'JaxUnaryRBM'
    return ma, op, sa, machine_name
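
Unlike the first two factories, JaxUnaryRBM hands the result of stax.serial directly to nk.machine.Jax instead of unpacking it first. Both forms are equivalent, because stax.serial already returns the (init_fun, apply_fun) pair the machine constructor expects. A minimal illustration with plain stax layers (assuming jax.experimental.stax, where older JAX releases targeted by this code keep the stax module):

from jax.experimental import stax

# stax.serial returns an (init_fun, apply_fun) pair ...
init_fun, apply_fun = stax.serial(stax.Dense(8), stax.Relu, stax.Dense(1))

# ... so passing the unpacked pair and passing the serial(...) result directly are the same:
# nk.machine.Jax(hilbert, (init_fun, apply_fun), dtype=complex)
# nk.machine.Jax(hilbert, stax.serial(...), dtype=complex)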
Example no. 4
def load_machine(machine, hamiltonian, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Function to get an operator and sampler for a loaded machine. The machine is not loaded in this method!
        The machine is not returned -> Syntax is a bit different than in the other functions.
        Only works with Jax-machines so far.

            Args:
                machine (netket.machine) : loaded machine
                hamiltonian (netket.hamiltonian) : hamiltonian
                optimizer (str) : possible choices are 'Sgd', 'Adam', or 'AdaMax'
                lr (float) : learning rate
                sampler (str) : possible choices are 'Local', 'Exact', 'VBS', 'Inverse'

            Returns:
                op (netket.optimizer) : optimizer
                sa (netket.sampler) : sampler
    """
    ma = machine
    # Optimizer
    if (optimizer == 'Sgd'):
        op = Wrap(ma, SgdJax(lr))
    elif (optimizer == 'Adam'):
        op = Wrap(ma, AdamJax(lr))
    else:
        op = Wrap(ma, AdaMaxJax(lr))
    # Sampler
    if (sampler == 'Local'):
        sa = nk.sampler.MetropolisLocal(machine=ma)
    elif (sampler == 'Exact'):
        sa = nk.sampler.ExactSampler(machine=ma)
    elif (sampler == 'VBS'):
        sa = my_sampler.getVBSSampler(machine=ma)
    elif (sampler == 'Inverse'):
        sa = my_sampler.getInverseSampler(machine=ma)
    else:
        sa = nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)
    return op, sa
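
A hedged sketch of resuming a run with load_machine (not part of the original module): it rebuilds the architecture with one of the factories above, restores parameters from a hypothetical checkpoint file, and assumes the NetKet 2.1-style machine save()/load() methods and nk.Vmc driver used in the earlier sketch; hi and ha are the Hilbert space and Hamiltonian defined there.

# Rebuild the same architecture, then overwrite its parameters from disk.
ma, _, _, _ = JaxDeepConvNN(hi, ha, alpha=2)
ma.load('machine_checkpoint.wf')  # hypothetical checkpoint written earlier with ma.save()

# Fresh optimizer/sampler pair for the restored machine.
op, sa = load_machine(ma, ha, optimizer='Sgd', lr=0.05, sampler='Local')

# Continue the optimization (NetKet 2.1-style driver assumed, as above).
gs = nk.Vmc(hamiltonian=ha, sampler=sa, optimizer=op, n_samples=1000)
gs.run(out='resumed_run', n_iter=100)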