Example No. 1
  def test_ferminet_size(self, size):
    atoms = [
        system.Atom(symbol='H', coords=(0, 0, -1.0)),
        system.Atom(symbol='H', coords=(0, 0, 1.0)),
    ]
    ne = (size, size)
    nout = 1
    batch_size = 3
    ndet = 4
    hidden_units = [[8, 8], [8, 8]]
    after_det = [8, 8, nout]
    x = tf.random_normal([batch_size, 3 * sum(ne)])

    ferminet = networks.FermiNet(
        atoms=atoms,
        nelectrons=ne,
        slater_dets=ndet,
        hidden_units=hidden_units,
        after_det=after_det)
    y = ferminet(x)
    for i in range(len(ne)):
      self.assertEqual(ferminet._dets[i].shape.as_list(), [batch_size, ndet])
      self.assertEqual(ferminet._orbitals[i].shape.as_list(),
                       [batch_size, ndet, size, size])
    self.assertEqual(y.shape.as_list(), [batch_size, nout])
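The bare `size` argument indicates this method is driven by a parameterized-test decorator that the listing has stripped. A minimal sketch of the assumed harness (the class name and parameter values are illustrative, not taken from the original test file):

# Hypothetical harness for the snippet above; decorator values and
# class name are assumptions, not from the original file.
import tensorflow.compat.v1 as tf
from absl.testing import parameterized


class FermiNetTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.parameters(2, 3)  # run once per per-spin electron count
  def test_ferminet_size(self, size):
    ...  # body as in Example No. 1


if __name__ == '__main__':
  tf.test.main()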
Example No. 2
  def test_ferminet(self, hidden_units, after_det):
    """Check that FermiNet is actually antisymmetric."""
    atoms = [
        system.Atom(symbol='H', coords=(0, 0, -1.0)),
        system.Atom(symbol='H', coords=(0, 0, 1.0)),
    ]
    ne = (3, 2)
    x1 = tf.random_normal([3, 3 * sum(ne)])
    xs = tf.split(x1, sum(ne), axis=1)

    # swap indices to test antisymmetry
    x2 = tf.concat([xs[1], xs[0]] + xs[2:], axis=1)
    x3 = tf.concat(
        xs[:ne[0]] + [xs[ne[0] + 1], xs[ne[0]]] + xs[ne[0] + 2:], axis=1)

    ferminet = networks.FermiNet(
        atoms=atoms,
        nelectrons=ne,
        slater_dets=4,
        hidden_units=hidden_units,
        after_det=after_det)
    y1 = ferminet(x1)
    y2 = ferminet(x2)
    y3 = ferminet(x3)
    with tf.train.MonitoredSession() as session:
      out1, out2, out3 = session.run([y1, y2, y3])
      np.testing.assert_allclose(out1, -out2, rtol=4.e-5, atol=1.e-6)
      np.testing.assert_allclose(out1, -out3, rtol=4.e-5, atol=1.e-6)
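Here `x2` swaps the first two spin-up electrons and `x3` swaps the two spin-down electrons: with `ne = (3, 2)`, entries 0-2 of the split are spin-up coordinates and entries 3-4 spin-down. A fermionic wavefunction must change sign when two same-spin electrons are exchanged, psi(..., r_i, ..., r_j, ...) = -psi(..., r_j, ..., r_i, ...), which is exactly what the two `assert_allclose` checks verify.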
Example No. 3
  def test_ferminet_pretrain(self, hidden_units, after_det):
    """Check that FermiNet pretraining runs."""
    atoms = [
        system.Atom(symbol='Li', coords=(0, 0, -1.0)),
        system.Atom(symbol='Li', coords=(0, 0, 1.0)),
    ]
    ne = (4, 2)
    x = tf.random_normal([1, 10, 3 * sum(ne)])

    strategy = tf.distribute.get_strategy()

    with strategy.scope():
      ferminet = networks.FermiNet(
          atoms=atoms,
          nelectrons=ne,
          slater_dets=4,
          hidden_units=hidden_units,
          after_det=after_det,
          pretrain_iterations=10)

    # Test Hartree-Fock pretraining - no change of position.
    hf_approx = scf.Scf(atoms, nelectrons=ne)
    hf_approx.run()
    pretrain_op_hf = networks.pretrain_hartree_fock(ferminet, x, strategy,
                                                    hf_approx)
    self.assertEqual(ferminet.pretrain_iterations, 10)
    with tf.train.MonitoredSession() as session:
      for _ in range(ferminet.pretrain_iterations):
        session.run(pretrain_op_hf)
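Pretraining fits the network towards the Hartree-Fock solution computed by `scf.Scf` before the main variational optimisation starts; `pretrain_hartree_fock` builds the op that performs one fitting step, and the loop runs it for the configured number of iterations.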
Example No. 4
  def test_ferminet_mask(self):
    """Check that FermiNet with a decaying mask on the output works."""
    atoms = [
        system.Atom(symbol='H', coords=(0, 0, -1.0)),
        system.Atom(symbol='H', coords=(0, 0, 1.0)),
    ]
    ne = (3, 2)
    hidden_units = [[8, 8], [8, 8]]
    after_det = [8, 8, 2]
    x = tf.random_normal([3, 3 * sum(ne)])

    ferminet = networks.FermiNet(
        atoms=atoms,
        nelectrons=ne,
        slater_dets=4,
        hidden_units=hidden_units,
        after_det=after_det,
        envelope=True)
    y = ferminet(x)
    with tf.train.MonitoredSession() as session:
      session.run(y)
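With `envelope=True` the output is multiplied by a decaying mask so the wavefunction vanishes far from the nuclei; the test only checks that the graph builds and evaluates without error.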
Example No. 5
def train(molecule: Sequence[system.Atom],
          spins: Tuple[int, int],
          batch_size: int,
          network_config: Optional[NetworkConfig] = None,
          pretrain_config: Optional[PretrainConfig] = None,
          optim_config: Optional[OptimConfig] = None,
          kfac_config: Optional[KfacConfig] = None,
          mcmc_config: Optional[MCMCConfig] = None,
          logging_config: Optional[LoggingConfig] = None,
          multi_gpu: bool = False,
          double_precision: bool = False,
          graph_path: Optional[str] = None):
  """Configures and runs training loop.

  Args:
    molecule: molecule description.
    spins: pair of ints specifying number of spin-up and spin-down electrons
      respectively.
    batch_size: batch size. Also referred to as the number of Markov Chain Monte
      Carlo configurations/walkers.
    network_config: network configuration. Default settings in NetworkConfig are
      used if not specified.
    pretrain_config: pretraining configuration. Default settings in
      PretrainConfig are used if not specified.
    optim_config: optimization configuration. Default settings in OptimConfig
      are used if not specified.
    kfac_config: K-FAC configuration. Default settings in KfacConfig are used if
      not specified.
    mcmc_config: Markov Chain Monte Carlo configuration. Default settings in
      MCMCConfig are used if not specified.
    logging_config: logging and checkpoint configuration. Default settings in
      LoggingConfig are used if not specified.
    multi_gpu: Use all available GPUs. Default: use only a single GPU.
    double_precision: use tf.float64 instead of tf.float32 for all operations.
      Warning - double precision is not currently functional with K-FAC.
    graph_path: directory to save a representation of the TF graph to. Not
      saved if None.

  Raises:
    RuntimeError: if mcmc_config.init_means is supplied but is of the incorrect
      length.
  """

  if not mcmc_config:
    mcmc_config = MCMCConfig()
  if not logging_config:
    logging_config = LoggingConfig()
  if not pretrain_config:
    pretrain_config = PretrainConfig()
  if not optim_config:
    optim_config = OptimConfig()
  if not kfac_config:
    kfac_config = KfacConfig()
  if not network_config:
    network_config = NetworkConfig()

  nelectrons = sum(spins)
  precision = tf.float64 if double_precision else tf.float32

  if multi_gpu:
    strategy = tf.distribute.MirroredStrategy()
  else:
    # Get the default (single-device) strategy.
    strategy = tf.distribute.get_strategy()
  if multi_gpu:
    batch_size = batch_size // strategy.num_replicas_in_sync
    logging.info('Setting per-GPU batch size to %s.', batch_size)
    logging_config.replicas = strategy.num_replicas_in_sync
  logging.info('Running on %s replicas.', strategy.num_replicas_in_sync)

  # Create a re-entrant variable scope for network.
  with tf.variable_scope('model') as model:
    pass

  with strategy.scope():
    with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
      with tf.name_scope(model1.original_name_scope):
        fermi_net = networks.FermiNet(
            atoms=molecule,
            nelectrons=spins,
            slater_dets=network_config.determinants,
            hidden_units=network_config.hidden_units,
            after_det=network_config.after_det,
            architecture=network_config.architecture,
            r12_ee_features=network_config.r12_ee_features,
            r12_en_features=network_config.r12_en_features,
            pos_ee_features=network_config.pos_ee_features,
            build_backflow=network_config.build_backflow,
            use_backflow=network_config.backflow,
            jastrow_en=network_config.jastrow_en,
            jastrow_ee=network_config.jastrow_ee,
            jastrow_een=network_config.jastrow_een,
            logdet=True,
            envelope=network_config.use_envelope,
            residual=network_config.residual,
            pretrain_iterations=pretrain_config.iterations)

    scf_approx = scf.Scf(
        molecule,
        nelectrons=spins,
        restricted=False,
        basis=pretrain_config.basis)
    if pretrain_config.iterations > 0:
      scf_approx.run()

    hamiltonian_ops = hamiltonian.operators(molecule, nelectrons)
    if mcmc_config.init_means:
      if len(mcmc_config.init_means) != 3 * nelectrons:
        raise RuntimeError('Initial electron positions of incorrect shape. '
                           '({} not {})'.format(
                               len(mcmc_config.init_means), 3 * nelectrons))
      init_means = [float(x) for x in mcmc_config.init_means]
    else:
      init_means = assign_electrons(molecule, spins)

    # Build the MCMC state inside the same variable scope as the network.
    with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
      with tf.name_scope(model1.original_name_scope):
        data_gen = mcmc.MCMC(
            fermi_net,
            batch_size,
            init_mu=init_means,
            init_sigma=mcmc_config.init_width,
            move_sigma=mcmc_config.move_width,
            dtype=precision)
    with tf.variable_scope('HF_data_gen'):
      hf_data_gen = mcmc.MCMC(
          scf_approx.tf_eval_slog_hartree_product,
          batch_size,
          init_mu=init_means,
          init_sigma=mcmc_config.init_width,
          move_sigma=mcmc_config.move_width,
          dtype=precision)

    with tf.name_scope('learning_rate_schedule'):
      global_step = tf.train.get_or_create_global_step()
      lr = optim_config.learning_rate * tf.pow(
          (1.0 / (1.0 + (tf.cast(global_step, tf.float32) /
                         optim_config.learning_rate_delay))),
          optim_config.learning_rate_decay)

    if optim_config.learning_rate < 1.e-10:
      logging.warning('Learning rate less than 10^-10. Not using an optimiser.')
      optim_fn = lambda _: None
      update_cached_data = None
    elif optim_config.use_kfac:
      cached_data = tf.get_variable(
          'MCMC_cache',
          initializer=tf.zeros(shape=data_gen.walkers.shape, dtype=precision),
          use_resource=True,
          trainable=False,
          dtype=precision,
      )
      if kfac_config.adapt_damping:
        update_cached_data = tf.assign(cached_data, data_gen.walkers)
      else:
        update_cached_data = None
      optim_fn = lambda layer_collection: mean_corrected_kfac_opt.MeanCorrectedKfacOpt(  # pylint: disable=g-long-lambda
          invert_every=kfac_config.invert_every,
          cov_update_every=kfac_config.cov_update_every,
          learning_rate=lr,
          norm_constraint=kfac_config.norm_constraint,
          damping=kfac_config.damping,
          cov_ema_decay=kfac_config.cov_ema_decay,
          momentum=kfac_config.momentum,
          momentum_type=kfac_config.momentum_type,
          loss_fn=lambda x: tf.nn.l2_loss(fermi_net(x)[0]),
          train_batch=data_gen.walkers,
          prev_train_batch=cached_data,
          layer_collection=layer_collection,
          batch_size=batch_size,
          adapt_damping=kfac_config.adapt_damping,
          is_chief=True,
          damping_adaptation_decay=kfac_config.damping_adaptation_decay,
          damping_adaptation_interval=kfac_config.damping_adaptation_interval,
          min_damping=kfac_config.min_damping,
          use_passed_loss=False,
          estimation_mode='exact',
      )
    else:
      adam = tf.train.AdamOptimizer(lr)
      optim_fn = lambda _: adam
      update_cached_data = None

    qmc_net = qmc.QMC(
        hamiltonian_ops,
        fermi_net,
        data_gen,
        hf_data_gen,
        clip_el=optim_config.clip_el,
        check_loss=optim_config.check_loss,
    )

  qmc_net.train(
      optim_fn,
      optim_config.iterations,
      logging_config,
      using_kfac=optim_config.use_kfac,
      strategy=strategy,
      scf_approx=scf_approx,
      global_step=global_step,
      determinism_mode=optim_config.deterministic,
      cached_data_op=update_cached_data,
      write_graph=os.path.abspath(graph_path) if graph_path else None,
      burn_in=mcmc_config.burn_in,
      mcmc_steps=mcmc_config.steps,
  )
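A hedged usage sketch for `train`, assuming the standard FermiNet package layout (the import paths are assumptions and may differ in your checkout); all config objects are left at their defaults:

# Hedged usage sketch; import paths are assumed, not from the original.
from ferminet import train
from ferminet.utils import system

# H2 with the same geometry as the tests above.
h2 = [
    system.Atom(symbol='H', coords=(0, 0, -1.0)),
    system.Atom(symbol='H', coords=(0, 0, 1.0)),
]
train.train(
    molecule=h2,
    spins=(1, 1),    # one spin-up and one spin-down electron
    batch_size=256)  # number of MCMC walkers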
Example No. 6
def main(fixed_coords_up, fixed_coords_down, mcmc_step_std, init_sigma,
         batch_size, iterations):
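    # `molecule`, `spins`, `fix_coords`, `mcmc_step` and `H5Writer` are
    # assumed to be defined elsewhere in the source file.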
    checkpoint_path = "Neon/checkpoints"

    latest = tf.train.latest_checkpoint(checkpoint_path)

    ferminet = networks.FermiNet(atoms=molecule,
                                 nelectrons=spins,
                                 slater_dets=16,
                                 hidden_units=((256, 32), ) * 4,
                                 after_det=(1, ),
                                 pretrain_iterations=0,
                                 logdet=True,
                                 envelope=True,
                                 residual=True,
                                 name="model/det_net")

    init_means = train.assign_electrons(molecule, spins)
    batch = np.concatenate(
        [
            np.random.normal(size=(batch_size, 1), loc=mu, scale=init_sigma)
            for mu in init_means
        ],
        axis=-1)

    batch, mask = fix_coords(batch, fixed_coords_up, fixed_coords_down, spins)

    print("mask: ", mask)
    psi = ferminet(batch)[0]

    saver = tf.train.Saver(max_to_keep=10, save_relative_paths=True)

    with tf.Session() as sess:
        saver.restore(sess, latest)
        psi = sess.run(psi)

        print("Burn in")
        for itr in range(20):
            # Note: calling mcmc_step here adds new ops to the graph on every
            # iteration; see the sketch after this example for a fix.
            new_batch, new_psi, move_acc = sess.run(
                mcmc_step(batch, psi, ferminet, mask))
            batch = tf.constant(new_batch, tf.float32)
            psi = new_psi
            print(itr, " pacc: ", move_acc)
        print("Burn in done")
        h5_schema = {"walkers": batch.shape.as_list()}
        with H5Writer(name="neon_data_z.h5", schema=h5_schema,
                      directory="./") as h5_writer:
            for itr in range(iterations):
                for _ in range(2):
                    new_batch, new_psi, move_acc = sess.run(
                        mcmc_step(batch, psi, ferminet, mask))
                    batch = tf.constant(new_batch, tf.float32)
                    psi = new_psi
                print(itr, " pacc: ", move_acc)
                out = {"walkers": new_batch}
                h5_writer.write(itr, out)
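Both loops above call `mcmc_step` inside `sess.run`, which rebuilds the step's ops and grows the graph without bound in TF1. A sketch of a variant that builds the ops once and feeds numpy walkers through placeholders (placeholder shapes and the `initial_*` names are assumptions, not from the original):

# Sketch only: build the MCMC ops once, then feed walkers each iteration.
batch_ph = tf.placeholder(tf.float32, shape=[batch_size, 3 * sum(spins)])
psi_ph = tf.placeholder(tf.float32, shape=[batch_size, 1])  # shape assumed
step_op = mcmc_step(batch_ph, psi_ph, ferminet, mask)

walkers, log_psi = initial_batch, initial_psi  # numpy arrays
for itr in range(iterations):
    walkers, log_psi, move_acc = sess.run(
        step_op, feed_dict={batch_ph: walkers, psi_ph: log_psi})
    print(itr, " pacc: ", move_acc)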