def test_tf_eval_slog_wavefuncs(self):

        # Check TensorFlow evaluation runs and gives correct shapes.
        molecule = [system.Atom('O', (0, 0, 0))]
        nelectrons = (5, 3)
        total_electrons = sum(nelectrons)
        num_spatial_dim = 3
        hf = scf.Scf(molecule=molecule,
                     nelectrons=nelectrons,
                     restricted=False)
        hf.run()

        batch = 100
        rng = np.random.RandomState(1)
        flat_positions_np = rng.randn(batch, total_electrons * num_spatial_dim)

        flat_positions_tf = tf.constant(flat_positions_np)
        for method in [
                hf.tf_eval_slog_slater_determinant,
                hf.tf_eval_slog_hartree_product
        ]:
            slog_wavefunc, signs = method(flat_positions_tf)
            with tf.train.MonitoredSession() as session:
                slog_wavefunc_, signs_ = session.run([slog_wavefunc, signs])
            self.assertEqual(slog_wavefunc_.shape, (batch, 1))
            self.assertEqual(signs_.shape, (batch, 1))
        # slog_wavefunc_ and signs_ now hold the values from the final loop
        # iteration, i.e. the Hartree product, so recombining them below must
        # reproduce the direct evaluation.
        hartree_product = hf.tf_eval_hartree_product(flat_positions_tf)
        with tf.train.MonitoredSession() as session:
            hartree_product_ = session.run(hartree_product)
        np.testing.assert_allclose(hartree_product_,
                                   np.exp(slog_wavefunc_) * signs_)
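
# Note on the (sign, log|psi|) convention checked above: it matches
# np.linalg.slogdet, and working in the log domain avoids overflow and
# underflow when many orbital contributions are multiplied together. A
# minimal NumPy sketch (`mat` is a hypothetical batch of square matrices,
# not part of the test):
#
#   sign, logdet = np.linalg.slogdet(mat)
#   det = sign * np.exp(logdet)  # recovers np.linalg.det(mat)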
  def test_ferminet_pretrain(self, hidden_units, after_det):
    """Check that FermiNet pretraining runs."""
    atoms = [
        system.Atom(symbol='Li', coords=(0, 0, -1.0)),
        system.Atom(symbol='Li', coords=(0, 0, 1.0)),
    ]
    ne = (4, 2)
    x = tf.random_normal([1, 10, 3 * sum(ne)])

    strategy = tf.distribute.get_strategy()

    with strategy.scope():
      ferminet = networks.FermiNet(
          atoms=atoms,
          nelectrons=ne,
          slater_dets=4,
          hidden_units=hidden_units,
          after_det=after_det,
          pretrain_iterations=10)

    # Test Hartree-Fock pretraining - electron positions are kept fixed.
    hf_approx = scf.Scf(atoms, nelectrons=ne)
    hf_approx.run()
    pretrain_op_hf = networks.pretrain_hartree_fock(ferminet, x, strategy,
                                                    hf_approx)
    self.assertEqual(ferminet.pretrain_iterations, 10)
    with tf.train.MonitoredSession() as session:
      for _ in range(ferminet.pretrain_iterations):
        session.run(pretrain_op_hf)
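
# Multi-GPU pretraining would follow the same pattern under a mirrored
# strategy. A sketch, assuming several suitable devices are available (not
# part of the test above):
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#       ferminet = networks.FermiNet(
#           atoms=atoms, nelectrons=ne, slater_dets=4,
#           hidden_units=hidden_units, after_det=after_det,
#           pretrain_iterations=10)
#   pretrain_op = networks.pretrain_hartree_fock(ferminet, x, strategy,
#                                                hf_approx)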

def test_tf_eval_hf(self):

        # Check we get consistent answers between multiple calls to eval_mos and
        # a single tf_eval_hf call.
        molecule = [system.Atom('O', (0, 0, 0))]
        nelectrons = (5, 3)
        hf = scf.Scf(molecule=molecule,
                     nelectrons=nelectrons,
                     restricted=False)
        hf.run()

        batch = 100
        xs = [np.random.randn(batch, 3) for _ in range(sum(nelectrons))]
        mos = []
        for i, x in enumerate(xs):
            ispin = 0 if i < nelectrons[0] else 1
            orbitals = hf.eval_mos(x)[ispin]
            # Select occupied orbitals via Aufbau.
            mos.append(orbitals[:, :nelectrons[ispin]])
        np_mos = (np.stack(mos[:nelectrons[0]], axis=1),
                  np.stack(mos[nelectrons[0]:], axis=1))

        tf_xs = tf.constant(np.stack(xs, axis=1))
        tf_mos = hf.tf_eval_hf(tf_xs, deriv=True)
        with tf.train.MonitoredSession() as session:
            tf_mos_ = session.run(tf_mos)

        for i, (np_mos_mat, tf_mos_mat) in enumerate(zip(np_mos, tf_mos_)):
            self.assertEqual(np_mos_mat.shape, tf_mos_mat.shape)
            self.assertEqual(np_mos_mat.shape,
                             (batch, nelectrons[i], nelectrons[i]))
            np.testing.assert_allclose(np_mos_mat, tf_mos_mat)
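
# tf_eval_hf thus returns one orbital matrix per spin channel, each of shape
# (batch, n_spin, n_spin). The product of their determinants gives the
# Slater-determinant part of the Hartree-Fock wavefunction (up to
# normalization). A NumPy sketch using the arrays computed above:
#
#   psi = np.linalg.det(np_mos[0]) * np.linalg.det(np_mos[1])  # shape (batch,)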

def create_o2_hf(bond_length):
    """Returns an unrestricted Hartree-Fock solver for molecular oxygen."""
    molecule = [
        system.Atom('O', (0, 0, 0)),
        system.Atom('O', (0, 0, bond_length))
    ]
    spin = 2
    oxygen_atomic_number = 8
    nelectrons = [oxygen_atomic_number + spin, oxygen_atomic_number - spin]
    hf = scf.Scf(molecule=molecule, nelectrons=nelectrons, restricted=False)
    return hf
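
# Usage sketch (the bond length is illustrative, in bohr, assuming the
# atomic units used elsewhere in this code):
#
#   hf = create_o2_hf(2.3)
#   hf.run()  # solve the SCF equations before evaluating orbitals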

def train(molecule: Sequence[system.Atom],
          spins: Tuple[int, int],
          batch_size: int,
          network_config: Optional[NetworkConfig] = None,
          pretrain_config: Optional[PretrainConfig] = None,
          optim_config: Optional[OptimConfig] = None,
          kfac_config: Optional[KfacConfig] = None,
          mcmc_config: Optional[MCMCConfig] = None,
          logging_config: Optional[LoggingConfig] = None,
          multi_gpu: bool = False,
          double_precision: bool = False,
          graph_path: Optional[str] = None):
    """Configures and runs training loop.

  Args:
    molecule: molecule description.
    spins: pair of ints specifying number of spin-up and spin-down electrons
      respectively.
    batch_size: batch size. Also referred to as the number of Markov Chain Monte
      Carlo configurations/walkers.
    network_config: network configuration. Default settings in NetworkConfig are
      used if not specified.
    pretrain_config: pretraining configuration. Default settings in
      PretrainConfig are used if not specified.
    optim_config: optimization configuration. Default settings in OptimConfig
      are used if not specified.
    kfac_config: K-FAC configuration. Default settings in KfacConfig are used if
      not specified.
    mcmc_config: Markov Chain Monte Carlo configuration. Default settings in
      MCMCConfig are used if not specified.
    logging_config: logging and checkpoint configuration. Default settings in
      LoggingConfig are used if not specified.
    multi_gpu: Use all available GPUs. Default: use only a single GPU.
    double_precision: use tf.float64 instead of tf.float32 for all operations.
      Warning - double precision is not currently functional with K-FAC.
    graph_path: directory to save a representation of the TF graph to. Not saved

  Raises:
    RuntimeError: if mcmc_config.init_means is supplied but is of the incorrect
    length.
  """

    if not mcmc_config:
        mcmc_config = MCMCConfig()
    if not logging_config:
        logging_config = LoggingConfig()
    if not pretrain_config:
        pretrain_config = PretrainConfig()
    if not optim_config:
        optim_config = OptimConfig()
    if not kfac_config:
        kfac_config = KfacConfig()
    if not network_config:
        network_config = NetworkConfig()

    nelectrons = sum(spins)
    precision = tf.float64 if double_precision else tf.float32

    if multi_gpu:
        strategy = tf.distribute.MirroredStrategy()
        batch_size = batch_size // strategy.num_replicas_in_sync
        logging.info('Setting per-GPU batch size to %s.', batch_size)
        logging_config.replicas = strategy.num_replicas_in_sync
    else:
        # Get the default (single-device) strategy.
        strategy = tf.distribute.get_strategy()
    logging.info('Running on %s replicas.', strategy.num_replicas_in_sync)

    # Create a re-entrant variable scope for network.
    with tf.variable_scope('model') as model:
        pass
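    # (Re-entering this scope below with auxiliary_name_scope=False plus the
    # scope's original_name_scope keeps the network and MCMC variables under
    # the same 'model' prefix, without TF appending '_1'-style suffixes.)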

    with strategy.scope():
        with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
            with tf.name_scope(model1.original_name_scope):
                fermi_net = networks.FermiNet(
                    atoms=molecule,
                    nelectrons=spins,
                    slater_dets=network_config.determinants,
                    hidden_units=network_config.hidden_units,
                    after_det=network_config.after_det,
                    architecture=network_config.architecture,
                    r12_ee_features=network_config.r12_ee_features,
                    r12_en_features=network_config.r12_en_features,
                    pos_ee_features=network_config.pos_ee_features,
                    build_backflow=network_config.build_backflow,
                    use_backflow=network_config.backflow,
                    jastrow_en=network_config.jastrow_en,
                    jastrow_ee=network_config.jastrow_ee,
                    jastrow_een=network_config.jastrow_een,
                    logdet=True,
                    envelope=network_config.use_envelope,
                    residual=network_config.residual,
                    pretrain_iterations=pretrain_config.iterations)

        scf_approx = scf.Scf(molecule,
                             nelectrons=spins,
                             restricted=False,
                             basis=pretrain_config.basis)
        if pretrain_config.iterations > 0:
            scf_approx.run()

        hamiltonian_ops = hamiltonian.operators(molecule, nelectrons)
        if mcmc_config.init_means:
            if len(mcmc_config.init_means) != 3 * nelectrons:
                raise RuntimeError(
                    'Initial electron positions of incorrect shape. '
                    '({} not {})'.format(len(mcmc_config.init_means),
                                         3 * nelectrons))
            init_means = [float(x) for x in mcmc_config.init_means]
        else:
            init_means = assign_electrons(molecule, spins)

        # Build the MCMC state inside the same variable scope as the network.
        with tf.variable_scope(model, auxiliary_name_scope=False) as model1:
            with tf.name_scope(model1.original_name_scope):
                data_gen = mcmc.MCMC(fermi_net,
                                     batch_size,
                                     init_mu=init_means,
                                     init_sigma=mcmc_config.init_width,
                                     move_sigma=mcmc_config.move_width,
                                     dtype=precision)
        with tf.variable_scope('HF_data_gen'):
            hf_data_gen = mcmc.MCMC(scf_approx.tf_eval_slog_hartree_product,
                                    batch_size,
                                    init_mu=init_means,
                                    init_sigma=mcmc_config.init_width,
                                    move_sigma=mcmc_config.move_width,
                                    dtype=precision)

        with tf.name_scope('learning_rate_schedule'):
            global_step = tf.train.get_or_create_global_step()
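            # lr(t) = learning_rate * (1 + t / delay)^(-decay), an
            # inverse-time (power-law) decay in the global step t.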
            lr = optim_config.learning_rate * tf.pow(
                (1.0 / (1.0 + (tf.cast(global_step, tf.float32) /
                               optim_config.learning_rate_delay))),
                optim_config.learning_rate_decay)

        if optim_config.learning_rate < 1.e-10:
            logging.warning(
                'Learning rate less than 10^-10. Not using an optimiser.')
            optim_fn = lambda _: None
            update_cached_data = None
        elif optim_config.use_kfac:
            cached_data = tf.get_variable(
                'MCMC_cache',
                initializer=tf.zeros(shape=data_gen.walkers.shape,
                                     dtype=precision),
                use_resource=True,
                trainable=False,
                dtype=precision,
            )
            if kfac_config.adapt_damping:
                update_cached_data = tf.assign(cached_data, data_gen.walkers)
            else:
                update_cached_data = None
            optim_fn = lambda layer_collection: mean_corrected_kfac_opt.MeanCorrectedKfacOpt(  # pylint: disable=g-long-lambda
                invert_every=kfac_config.invert_every,
                cov_update_every=kfac_config.cov_update_every,
                learning_rate=lr,
                norm_constraint=kfac_config.norm_constraint,
                damping=kfac_config.damping,
                cov_ema_decay=kfac_config.cov_ema_decay,
                momentum=kfac_config.momentum,
                momentum_type=kfac_config.momentum_type,
                loss_fn=lambda x: tf.nn.l2_loss(fermi_net(x)[0]),
                train_batch=data_gen.walkers,
                prev_train_batch=cached_data,
                layer_collection=layer_collection,
                batch_size=batch_size,
                adapt_damping=kfac_config.adapt_damping,
                is_chief=True,
                damping_adaptation_decay=kfac_config.damping_adaptation_decay,
                damping_adaptation_interval=kfac_config.damping_adaptation_interval,
                min_damping=kfac_config.min_damping,
                use_passed_loss=False,
                estimation_mode='exact',
            )
        else:
            adam = tf.train.AdamOptimizer(lr)
            optim_fn = lambda _: adam
            update_cached_data = None

        qmc_net = qmc.QMC(
            hamiltonian_ops,
            fermi_net,
            data_gen,
            hf_data_gen,
            clip_el=optim_config.clip_el,
            check_loss=optim_config.check_loss,
        )

    qmc_net.train(
        optim_fn,
        optim_config.iterations,
        logging_config,
        using_kfac=optim_config.use_kfac,
        strategy=strategy,
        scf_approx=scf_approx,
        global_step=global_step,
        determinism_mode=optim_config.deterministic,
        cached_data_op=update_cached_data,
        write_graph=os.path.abspath(graph_path) if graph_path else None,
        burn_in=mcmc_config.burn_in,
        mcmc_steps=mcmc_config.steps,
    )
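
# Usage sketch for train(), assuming default configurations; the molecule,
# spins and batch size below are illustrative, not from the source
# (coordinates in bohr):
#
#   train(
#       molecule=[system.Atom('Li', (0, 0, 0)),
#                 system.Atom('H', (0, 0, 3.0))],
#       spins=(2, 2),
#       batch_size=256)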