def test_ferminet_size(self, size):
  """Check that FermiNet constructs tensors of the expected shapes."""
  atoms = [
      system.Atom(symbol='H', coords=(0, 0, -1.0)),
      system.Atom(symbol='H', coords=(0, 0, 1.0)),
  ]
  ne = (size, size)
  nout = 1
  batch_size = 3
  ndet = 4
  hidden_units = [[8, 8], [8, 8]]
  after_det = [8, 8, nout]
  x = tf.random_normal([batch_size, 3 * sum(ne)])
  ferminet = networks.FermiNet(
      atoms=atoms,
      nelectrons=ne,
      slater_dets=ndet,
      hidden_units=hidden_units,
      after_det=after_det)
  y = ferminet(x)
  for i in range(len(ne)):
    self.assertEqual(ferminet._dets[i].shape.as_list(), [batch_size, ndet])
    self.assertEqual(ferminet._orbitals[i].shape.as_list(),
                     [batch_size, ndet, size, size])
  self.assertEqual(y.shape.as_list(), [batch_size, nout])
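# The shape assertions above follow from the input convention: each walker is
# a flat vector of 3 * sum(ne) coordinates, spin-up electrons first. A minimal
# sketch of recovering per-electron positions from that layout (illustrative
# helper, not part of the test suite):
def _split_walker_coords(x, ne):
  """Reshapes flat walkers [batch, 3 * sum(ne)] into per-spin positions."""
  batch_size = tf.shape(x)[0]
  positions = tf.reshape(x, [batch_size, sum(ne), 3])  # [walker, electron, xyz]
  # Returns ([batch, ne[0], 3], [batch, ne[1], 3]).
  return tf.split(positions, list(ne), axis=1)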
def test_ferminet(self, hidden_units, after_det):
  """Check that FermiNet is actually antisymmetric."""
  atoms = [
      system.Atom(symbol='H', coords=(0, 0, -1.0)),
      system.Atom(symbol='H', coords=(0, 0, 1.0)),
  ]
  ne = (3, 2)
  x1 = tf.random_normal([3, 3 * sum(ne)])
  xs = tf.split(x1, sum(ne), axis=1)
  # Swap indices to test antisymmetry: exchange the first two spin-up
  # electrons (x2) and the first two spin-down electrons (x3).
  x2 = tf.concat([xs[1], xs[0]] + xs[2:], axis=1)
  x3 = tf.concat(
      xs[:ne[0]] + [xs[ne[0] + 1], xs[ne[0]]] + xs[ne[0] + 2:], axis=1)
  ferminet = networks.FermiNet(
      atoms=atoms,
      nelectrons=ne,
      slater_dets=4,
      hidden_units=hidden_units,
      after_det=after_det)
  y1 = ferminet(x1)
  y2 = ferminet(x2)
  y3 = ferminet(x3)
  with tf.train.MonitoredSession() as session:
    out1, out2, out3 = session.run([y1, y2, y3])
  np.testing.assert_allclose(out1, -out2, rtol=4.e-5, atol=1.e-6)
  np.testing.assert_allclose(out1, -out3, rtol=4.e-5, atol=1.e-6)
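# Why the assertions above should hold: FermiNet builds its output from Slater
# determinants, and exchanging two same-spin electrons permutes two rows of
# each determinant, which flips its sign. A numpy sketch of that identity
# (illustrative only, not part of the test suite):
def _det_sign_flip_sketch():
  rng = np.random.RandomState(0)
  orbitals = rng.randn(4, 4)        # rows: electrons, columns: orbitals
  swapped = orbitals[[1, 0, 2, 3]]  # exchange electrons 0 and 1
  np.testing.assert_allclose(np.linalg.det(swapped),
                             -np.linalg.det(orbitals))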
def test_ferminet_pretrain(self, hidden_units, after_det):
  """Check that FermiNet pretraining runs."""
  atoms = [
      system.Atom(symbol='Li', coords=(0, 0, -1.0)),
      system.Atom(symbol='Li', coords=(0, 0, 1.0)),
  ]
  ne = (4, 2)
  x = tf.random_normal([1, 10, 3 * sum(ne)])

  strategy = tf.distribute.get_strategy()
  with strategy.scope():
    ferminet = networks.FermiNet(
        atoms=atoms,
        nelectrons=ne,
        slater_dets=4,
        hidden_units=hidden_units,
        after_det=after_det,
        pretrain_iterations=10)

  # Test Hartree-Fock pretraining - no change of position.
  hf_approx = scf.Scf(atoms, nelectrons=ne)
  hf_approx.run()
  pretrain_op_hf = networks.pretrain_hartree_fock(ferminet, x, strategy,
                                                  hf_approx)
  self.assertEqual(ferminet.pretrain_iterations, 10)
  with tf.train.MonitoredSession() as session:
    for _ in range(ferminet.pretrain_iterations):
      session.run(pretrain_op_hf)
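# Pretraining fits the network orbitals to the Hartree-Fock orbitals before
# variational optimisation starts. Conceptually it minimises a squared error
# between the two sets of orbitals; a hedged sketch of such a loss (the exact
# form used by networks.pretrain_hartree_fock may differ):
def _pretrain_loss_sketch(net_orbitals, hf_orbitals):
  # Both arguments: tensors of matching shape, e.g.
  # [batch, determinants, nelectrons, nelectrons].
  return tf.reduce_mean(tf.squared_difference(net_orbitals, hf_orbitals))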
def test_ferminet_mask(self):
  """Check that FermiNet with a decaying mask on the output works."""
  atoms = [
      system.Atom(symbol='H', coords=(0, 0, -1.0)),
      system.Atom(symbol='H', coords=(0, 0, 1.0)),
  ]
  ne = (3, 2)
  hidden_units = [[8, 8], [8, 8]]
  after_det = [8, 8, 2]
  x = tf.random_normal([3, 3 * sum(ne)])
  ferminet = networks.FermiNet(
      atoms=atoms,
      nelectrons=ne,
      slater_dets=4,
      hidden_units=hidden_units,
      after_det=after_det,
      envelope=True)
  y = ferminet(x)
  with tf.train.MonitoredSession() as session:
    session.run(y)
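# The envelope exercised above multiplies the network output by a function
# that decays exponentially with electron-nucleus distance, enforcing the
# correct asymptotic behaviour of the wavefunction far from the molecule. A
# hedged sketch of the idea (the form and parameters are illustrative, not
# the implementation inside networks.FermiNet):
def _envelope_sketch(x, atom_coords, alpha=1.0):
  # x: [batch, 3 * nelectrons] flat walker coordinates; atom_coords: [natom, 3].
  pos = tf.reshape(x, [tf.shape(x)[0], -1, 3])               # [batch, ne, 3]
  r_en = tf.norm(pos[:, :, None, :] - atom_coords, axis=-1)  # [batch, ne, natom]
  # Decay with each electron's distance to its nearest nucleus.
  return tf.exp(-alpha * tf.reduce_sum(tf.reduce_min(r_en, axis=-1), axis=-1))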
def train(molecule: Sequence[system.Atom],
          spins: Tuple[int, int],
          batch_size: int,
          network_config: Optional[NetworkConfig] = None,
          pretrain_config: Optional[PretrainConfig] = None,
          optim_config: Optional[OptimConfig] = None,
          kfac_config: Optional[KfacConfig] = None,
          mcmc_config: Optional[MCMCConfig] = None,
          logging_config: Optional[LoggingConfig] = None,
          multi_gpu: bool = False,
          double_precision: bool = False,
          graph_path: Optional[str] = None):
  """Configures and runs training loop.

  Args:
    molecule: molecule description.
    spins: pair of ints specifying number of spin-up and spin-down electrons
      respectively.
    batch_size: batch size. Also referred to as the number of Markov Chain
      Monte Carlo configurations/walkers.
    network_config: network configuration. Default settings in NetworkConfig
      are used if not specified.
    pretrain_config: pretraining configuration. Default settings in
      PretrainConfig are used if not specified.
    optim_config: optimization configuration. Default settings in OptimConfig
      are used if not specified.
    kfac_config: K-FAC configuration. Default settings in KfacConfig are used
      if not specified.
    mcmc_config: Markov Chain Monte Carlo configuration. Default settings in
      MCMCConfig are used if not specified.
    logging_config: logging and checkpoint configuration. Default settings in
      LoggingConfig are used if not specified.
    multi_gpu: Use all available GPUs. Default: use only a single GPU.
    double_precision: use tf.float64 instead of tf.float32 for all operations.
      Warning - double precision is not currently functional with K-FAC.
    graph_path: directory to save a representation of the TF graph to. Not
      saved if None.

  Raises:
    RuntimeError: if mcmc_config.init_means is supplied but is of the
      incorrect length.
  """
  if not mcmc_config:
    mcmc_config = MCMCConfig()
  if not logging_config:
    logging_config = LoggingConfig()
  if not pretrain_config:
    pretrain_config = PretrainConfig()
  if not optim_config:
    optim_config = OptimConfig()
  if not kfac_config:
    kfac_config = KfacConfig()
  if not network_config:
    network_config = NetworkConfig()

  nelectrons = sum(spins)
  precision = tf.float64 if double_precision else tf.float32

  if multi_gpu:
    strategy = tf.distribute.MirroredStrategy()
  else:
    # Get the default (single-device) strategy.
    strategy = tf.distribute.get_strategy()
  if multi_gpu:
    batch_size = batch_size // strategy.num_replicas_in_sync
    logging.info('Setting per-GPU batch size to %s.', batch_size)
    logging_config.replicas = strategy.num_replicas_in_sync
  logging.info('Running on %s replicas.', strategy.num_replicas_in_sync)

  # Create a re-entrant variable scope for network.
  with tf.variable_scope('model') as model:
    pass

  with strategy.scope():
    with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
      with tf.name_scope(model1.original_name_scope):
        fermi_net = networks.FermiNet(
            atoms=molecule,
            nelectrons=spins,
            slater_dets=network_config.determinants,
            hidden_units=network_config.hidden_units,
            after_det=network_config.after_det,
            architecture=network_config.architecture,
            r12_ee_features=network_config.r12_ee_features,
            r12_en_features=network_config.r12_en_features,
            pos_ee_features=network_config.pos_ee_features,
            build_backflow=network_config.build_backflow,
            use_backflow=network_config.backflow,
            jastrow_en=network_config.jastrow_en,
            jastrow_ee=network_config.jastrow_ee,
            jastrow_een=network_config.jastrow_een,
            logdet=True,
            envelope=network_config.use_envelope,
            residual=network_config.residual,
            pretrain_iterations=pretrain_config.iterations)

    scf_approx = scf.Scf(
        molecule,
        nelectrons=spins,
        restricted=False,
        basis=pretrain_config.basis)
    if pretrain_config.iterations > 0:
      scf_approx.run()

    hamiltonian_ops = hamiltonian.operators(molecule, nelectrons)
    if mcmc_config.init_means:
      if len(mcmc_config.init_means) != 3 * nelectrons:
        raise RuntimeError('Initial electron positions of incorrect shape. '
                           '({} not {})'.format(
                               len(mcmc_config.init_means), 3 * nelectrons))
      init_means = [float(x) for x in mcmc_config.init_means]
    else:
      init_means = assign_electrons(molecule, spins)

    # Build the MCMC state inside the same variable scope as the network.
    with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
      with tf.name_scope(model1.original_name_scope):
        data_gen = mcmc.MCMC(
            fermi_net,
            batch_size,
            init_mu=init_means,
            init_sigma=mcmc_config.init_width,
            move_sigma=mcmc_config.move_width,
            dtype=precision)
    with tf.variable_scope('HF_data_gen'):
      hf_data_gen = mcmc.MCMC(
          scf_approx.tf_eval_slog_hartree_product,
          batch_size,
          init_mu=init_means,
          init_sigma=mcmc_config.init_width,
          move_sigma=mcmc_config.move_width,
          dtype=precision)

    with tf.name_scope('learning_rate_schedule'):
      global_step = tf.train.get_or_create_global_step()
      lr = optim_config.learning_rate * tf.pow(
          (1.0 / (1.0 + (tf.cast(global_step, tf.float32) /
                         optim_config.learning_rate_delay))),
          optim_config.learning_rate_decay)

    if optim_config.learning_rate < 1.e-10:
      logging.warning('Learning rate less than 10^-10. '
                      'Not using an optimiser.')
      optim_fn = lambda _: None
      update_cached_data = None
    elif optim_config.use_kfac:
      cached_data = tf.get_variable(
          'MCMC_cache',
          initializer=tf.zeros(shape=data_gen.walkers.shape, dtype=precision),
          use_resource=True,
          trainable=False,
          dtype=precision,
      )
      if kfac_config.adapt_damping:
        update_cached_data = tf.assign(cached_data, data_gen.walkers)
      else:
        update_cached_data = None
      optim_fn = lambda layer_collection: mean_corrected_kfac_opt.MeanCorrectedKfacOpt(  # pylint: disable=g-long-lambda
          invert_every=kfac_config.invert_every,
          cov_update_every=kfac_config.cov_update_every,
          learning_rate=lr,
          norm_constraint=kfac_config.norm_constraint,
          damping=kfac_config.damping,
          cov_ema_decay=kfac_config.cov_ema_decay,
          momentum=kfac_config.momentum,
          momentum_type=kfac_config.momentum_type,
          loss_fn=lambda x: tf.nn.l2_loss(fermi_net(x)[0]),
          train_batch=data_gen.walkers,
          prev_train_batch=cached_data,
          layer_collection=layer_collection,
          batch_size=batch_size,
          adapt_damping=kfac_config.adapt_damping,
          is_chief=True,
          damping_adaptation_decay=kfac_config.damping_adaptation_decay,
          damping_adaptation_interval=kfac_config.damping_adaptation_interval,
          min_damping=kfac_config.min_damping,
          use_passed_loss=False,
          estimation_mode='exact',
      )
    else:
      adam = tf.train.AdamOptimizer(lr)
      optim_fn = lambda _: adam
      update_cached_data = None

    qmc_net = qmc.QMC(
        hamiltonian_ops,
        fermi_net,
        data_gen,
        hf_data_gen,
        clip_el=optim_config.clip_el,
        check_loss=optim_config.check_loss,
    )

  qmc_net.train(
      optim_fn,
      optim_config.iterations,
      logging_config,
      using_kfac=optim_config.use_kfac,
      strategy=strategy,
      scf_approx=scf_approx,
      global_step=global_step,
      determinism_mode=optim_config.deterministic,
      cached_data_op=update_cached_data,
      write_graph=os.path.abspath(graph_path) if graph_path else None,
      burn_in=mcmc_config.burn_in,
      mcmc_steps=mcmc_config.steps,
  )
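# A minimal usage sketch: train an H2 wavefunction with all default configs.
# The geometry mirrors the tests above; the batch size is illustrative and the
# helper name is hypothetical:
def _train_h2_example():
  h2 = [
      system.Atom(symbol='H', coords=(0, 0, -1.0)),
      system.Atom(symbol='H', coords=(0, 0, 1.0)),
  ]
  train(molecule=h2, spins=(1, 1), batch_size=256)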
def main(fixed_coords_up, fixed_coords_down, mcmc_step_std, init_sigma,
         batch_size, iterations):
  checkpoint_path = "Neon/checkpoints"
  latest = tf.train.latest_checkpoint(checkpoint_path)
  ferminet = networks.FermiNet(
      atoms=molecule,
      nelectrons=spins,
      slater_dets=16,
      hidden_units=((256, 32),) * 4,
      after_det=(1,),
      pretrain_iterations=0,
      logdet=True,
      envelope=True,
      residual=True,
      name="model/det_net")
  init_means = train.assign_electrons(molecule, spins)
  # Initialise walkers as Gaussian perturbations around the assigned means.
  batch = np.concatenate([
      np.random.normal(size=(batch_size, 1), loc=mu, scale=init_sigma)
      for mu in init_means
  ], axis=-1)
  batch, mask = fix_coords(batch, fixed_coords_up, fixed_coords_down, spins)
  print("mask: ", mask)
  psi = ferminet(batch)[0]

  # Build the MCMC update ops once and feed walkers through placeholders;
  # calling mcmc_step inside the sampling loop would add new ops to the graph
  # on every iteration. (See _mcmc_step_sketch below for the expected form of
  # the update.)
  batch_ph = tf.placeholder(tf.float32, shape=batch.shape)
  psi_ph = tf.placeholder(tf.float32, shape=psi.shape)
  mcmc_ops = mcmc_step(batch_ph, psi_ph, ferminet, mask)

  saver = tf.train.Saver(max_to_keep=10, save_relative_paths=True)
  with tf.Session() as sess:
    saver.restore(sess, latest)
    # Evaluate the initial state once; thereafter all state is held as numpy
    # arrays fed through the placeholders.
    if isinstance(batch, tf.Tensor):
      batch = sess.run(batch)
    psi_val = sess.run(psi)

    print("Burn in")
    for itr in range(20):
      batch, psi_val, move_acc = sess.run(
          mcmc_ops, feed_dict={batch_ph: batch, psi_ph: psi_val})
      print(itr, " pacc: ", move_acc)
    print("Burn in done")

    h5_schema = {"walkers": list(batch.shape)}
    with H5Writer(name="neon_data_z.h5", schema=h5_schema,
                  directory="./") as h5_writer:
      for itr in range(iterations):
        # Take two MCMC steps between records to decorrelate the samples.
        for _ in range(2):
          batch, psi_val, move_acc = sess.run(
              mcmc_ops, feed_dict={batch_ph: batch, psi_ph: psi_val})
          print(itr, " pacc: ", move_acc)
        h5_writer.write(itr, {"walkers": batch})
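# For reference, the update mcmc_step is expected to perform is a standard
# Metropolis rule in log space: since the network returns log|psi|, a proposed
# move is accepted when log(u) < 2 * (log|psi'| - log|psi|), i.e. with
# probability |psi'/psi|^2, and masked coordinates are held fixed. A hedged
# sketch only; the real mcmc_step and its proposal width may differ:
def _mcmc_step_sketch(batch, psi, ferminet, mask, move_std=0.02):
  # Propose Gaussian moves; the mask zeroes out moves for fixed coordinates.
  new_batch = batch + tf.random_normal(tf.shape(batch), stddev=move_std) * mask
  new_psi = ferminet(new_batch)[0]
  log_accept = tf.squeeze(2 * (new_psi - psi))
  u = tf.log(tf.random_uniform(tf.shape(log_accept)))
  accept = tf.less(u, log_accept)
  new_batch = tf.where(accept, new_batch, batch)
  new_psi = tf.where(accept, new_psi, psi)
  move_acc = tf.reduce_mean(tf.cast(accept, tf.float32))
  return new_batch, new_psi, move_acc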