def JaxDeepConvNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Build a complex deep convolutional network machine in Jax, plus optimizer and sampler.

    Layer stack passed to ``stax.serial``:
        FixSrLayer, InputForConvLayer,
        4 x [Conv1d(alpha, (3,)), ComplexReLu],
        stax.Flatten, Dense(hilbert.size * alpha), ComplexReLu,
        Dense(1), FormatLayer

    Args:
        hilbert (netket.hilbert): hilbert space; its ``size`` times ``alpha``
            sets the width of the first Dense layer.
        hamiltonian (netket.hamiltonian): hamiltonian; only used by the
            fallback MetropolisHamiltonian sampler.
        alpha (int): hidden layer density (first argument of every Conv1d and
            multiplier for the first Dense width).
        optimizer (str): 'Sgd' or 'Adam'; any other value selects AdaMax.
        lr (float): learning rate passed to the optimizer.
        sampler (str): 'Local', 'Exact', 'VBS' or 'Inverse'; any other value
            selects nk.sampler.MetropolisHamiltonian with n_chains=16.

    Returns:
        ma (netket.machine): complex-dtype Jax machine, parameters initialised
            with seed=12, sigma=0.01.
        op (netket.optimizer): optimizer wrapped around ``ma``.
        sa (netket.sampler): sampler bound to ``ma``.
        machine_name (str): 'JaxDeepConvNN', see get_operator.
    """
    print('JaxDeepConvNN is used')
    dense_width = hilbert.size * alpha

    # Four conv + activation pairs; Conv1d is called once per pair so each
    # position gets its own layer object, as in the original literal list.
    conv_stack = []
    for _ in range(4):
        conv_stack.extend((Conv1d(alpha, (3,)), ComplexReLu))

    net_init, net_apply = stax.serial(
        FixSrLayer,
        InputForConvLayer,
        *conv_stack,
        stax.Flatten,
        Dense(dense_width),
        ComplexReLu,
        Dense(1),
        FormatLayer,
    )
    ma = nk.machine.Jax(hilbert, (net_init, net_apply), dtype=complex)
    ma.init_random_parameters(seed=12, sigma=0.01)

    # Optimizer: anything that is not 'Sgd' or 'Adam' falls through to AdaMax.
    optimizer_cls = SgdJax if optimizer == 'Sgd' else AdamJax if optimizer == 'Adam' else AdaMaxJax
    op = Wrap(ma, optimizer_cls(lr))

    # Sampler: factories are lambdas so only the selected one is constructed.
    sampler_factories = {
        'Local': lambda: nk.sampler.MetropolisLocal(machine=ma),
        'Exact': lambda: nk.sampler.ExactSampler(machine=ma),
        'VBS': lambda: my_sampler.getVBSSampler(machine=ma),
        'Inverse': lambda: my_sampler.getInverseSampler(machine=ma),
    }

    def hamiltonian_sampler():
        return nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)

    sa = sampler_factories.get(sampler, hamiltonian_sampler)()

    return ma, op, sa, 'JaxDeepConvNN'
def JaxTransformedFFNN(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Build a complex one-hidden-layer feed-forward machine in Jax, plus optimizer and sampler.

    The input is first passed through TransformedLayer (the transformation of
    10.1103/physrevb.46.3486, per the original description). Layer stack
    passed to ``stax.serial``:
        FixSrLayer, TransformedLayer, Dense(hilbert.size * alpha),
        ComplexReLu, Dense(1), FormatLayer

    Args:
        hilbert (netket.hilbert): hilbert space; its ``size`` times ``alpha``
            sets the hidden layer width.
        hamiltonian (netket.hamiltonian): hamiltonian; only used by the
            fallback MetropolisHamiltonian sampler.
        alpha (int): hidden layer density.
        optimizer (str): 'Sgd' or 'Adam'; any other value selects AdaMax.
        lr (float): learning rate passed to the optimizer.
        sampler (str): 'Local', 'Exact', 'VBS' or 'Inverse'; any other value
            selects nk.sampler.MetropolisHamiltonian with n_chains=16.

    Returns:
        ma (netket.machine): complex-dtype Jax machine, parameters initialised
            with seed=12, sigma=0.01.
        op (netket.optimizer): optimizer wrapped around ``ma``.
        sa (netket.sampler): sampler bound to ``ma``.
        machine_name (str): 'JaxTransformedFFNN', see get_operator.
    """
    print('JaxTransformedFFNN is used')
    hidden_width = hilbert.size * alpha

    net_init, net_apply = stax.serial(
        FixSrLayer,
        TransformedLayer,
        Dense(hidden_width),
        ComplexReLu,
        Dense(1),
        FormatLayer,
    )
    ma = nk.machine.Jax(hilbert, (net_init, net_apply), dtype=complex)
    ma.init_random_parameters(seed=12, sigma=0.01)

    # Optimizer: anything that is not 'Sgd' or 'Adam' falls through to AdaMax.
    optimizer_cls = SgdJax if optimizer == 'Sgd' else AdamJax if optimizer == 'Adam' else AdaMaxJax
    op = Wrap(ma, optimizer_cls(lr))

    # Sampler: factories are lambdas so only the selected one is constructed.
    sampler_factories = {
        'Local': lambda: nk.sampler.MetropolisLocal(machine=ma),
        'Exact': lambda: nk.sampler.ExactSampler(machine=ma),
        'VBS': lambda: my_sampler.getVBSSampler(machine=ma),
        'Inverse': lambda: my_sampler.getInverseSampler(machine=ma),
    }

    def hamiltonian_sampler():
        return nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)

    sa = sampler_factories.get(sampler, hamiltonian_sampler)()

    return ma, op, sa, 'JaxTransformedFFNN'
def JaxUnaryRBM(hilbert, hamiltonian, alpha=1, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Build a complex unary Restricted Boltzmann Machine in Jax, plus optimizer and sampler.

    Layer stack passed to ``stax.serial``:
        FixSrLayer, UnaryLayer, stax.Dense(alpha * hilbert.size),
        LogCoshLayer, SumLayer
    Note that this model uses ``stax.Dense`` rather than the bare ``Dense``
    used by the other factories in this file.

    Args:
        hilbert (netket.hilbert): hilbert space; ``alpha`` times its ``size``
            sets the number of hidden units.
        hamiltonian (netket.hamiltonian): hamiltonian; only used by the
            fallback MetropolisHamiltonian sampler.
        alpha (int): hidden layer density.
        optimizer (str): 'Sgd' or 'Adam'; any other value selects AdaMax.
        lr (float): learning rate passed to the optimizer.
        sampler (str): 'Local', 'Exact', 'VBS' or 'Inverse'; any other value
            selects nk.sampler.MetropolisHamiltonian with n_chains=16.

    Returns:
        ma (netket.machine): complex-dtype Jax machine, parameters initialised
            with seed=12, sigma=0.01.
        op (netket.optimizer): optimizer wrapped around ``ma``.
        sa (netket.sampler): sampler bound to ``ma``.
        machine_name (str): 'JaxUnaryRBM', see get_operator.
    """
    print('JaxUnaryRBM is used')
    hidden_units = alpha * hilbert.size

    rbm_network = stax.serial(
        FixSrLayer,
        UnaryLayer,
        stax.Dense(hidden_units),
        LogCoshLayer,
        SumLayer,
    )
    ma = nk.machine.Jax(hilbert, rbm_network, dtype=complex)
    ma.init_random_parameters(seed=12, sigma=0.01)

    # Optimizer: anything that is not 'Sgd' or 'Adam' falls through to AdaMax.
    optimizer_cls = SgdJax if optimizer == 'Sgd' else AdamJax if optimizer == 'Adam' else AdaMaxJax
    op = Wrap(ma, optimizer_cls(lr))

    # Sampler: factories are lambdas so only the selected one is constructed.
    sampler_factories = {
        'Local': lambda: nk.sampler.MetropolisLocal(machine=ma),
        'Exact': lambda: nk.sampler.ExactSampler(machine=ma),
        'VBS': lambda: my_sampler.getVBSSampler(machine=ma),
        'Inverse': lambda: my_sampler.getInverseSampler(machine=ma),
    }

    def hamiltonian_sampler():
        return nk.sampler.MetropolisHamiltonian(machine=ma, hamiltonian=hamiltonian, n_chains=16)

    sa = sampler_factories.get(sampler, hamiltonian_sampler)()

    return ma, op, sa, 'JaxUnaryRBM'
def load_machine(machine, hamiltonian, optimizer='Sgd', lr=0.1, sampler='Local'):
    """Create an optimizer and sampler for an already-loaded machine.

    This function does not load the machine and does not return it; unlike
    the model factories, it returns only ``(op, sa)``. Intended for Jax
    machines only so far, since the optimizers are the Jax variants.

    Args:
        machine (netket.machine): the loaded machine.
        hamiltonian (netket.hamiltonian): hamiltonian; only used by the
            fallback MetropolisHamiltonian sampler.
        optimizer (str): 'Sgd' or 'Adam'; any other value selects AdaMax.
        lr (float): learning rate passed to the optimizer.
        sampler (str): 'Local', 'Exact', 'VBS' or 'Inverse'; any other value
            selects nk.sampler.MetropolisHamiltonian with n_chains=16.

    Returns:
        op (netket.optimizer): optimizer wrapped around ``machine``.
        sa (netket.sampler): sampler bound to ``machine``.
    """
    # Optimizer: anything that is not 'Sgd' or 'Adam' falls through to AdaMax.
    optimizer_cls = SgdJax if optimizer == 'Sgd' else AdamJax if optimizer == 'Adam' else AdaMaxJax
    op = Wrap(machine, optimizer_cls(lr))

    # Sampler: factories are lambdas so only the selected one is constructed.
    sampler_factories = {
        'Local': lambda: nk.sampler.MetropolisLocal(machine=machine),
        'Exact': lambda: nk.sampler.ExactSampler(machine=machine),
        'VBS': lambda: my_sampler.getVBSSampler(machine=machine),
        'Inverse': lambda: my_sampler.getInverseSampler(machine=machine),
    }

    def hamiltonian_sampler():
        return nk.sampler.MetropolisHamiltonian(machine=machine, hamiltonian=hamiltonian, n_chains=16)

    sa = sampler_factories.get(sampler, hamiltonian_sampler)()

    return op, sa