Exemplo n.º 1
0
    def build_eval_ops(
        self,
        wavefunction: wavefunctions.Wavefunction,
        operator: operators.Operator,
        hparams: tf.contrib.training.HParams,
        shared_resources: Dict[graph_builders.ResourceName, Any],
    ) -> EvalOps:
        """Builds the graph ops that estimate the value of an operator.

        Configurations are obtained from `shared_resources`, the sampling ops
        returned by `graph_builders.get_monte_carlo_sampling` are attached to
        them, and the estimate is the mean of `operator.local_value` over
        those configurations.

        Args:
          wavefunction: Wavefunction ansatz to evaluate.
          operator: Operator corresponding to the value we want to evaluate.
          hparams: Hyperparameters of the evaluation procedure; `batch_size`
            and `num_sites` are read from it.
          shared_resources: System resources shared among different modules.

        Returns:
          EvalOps with `value`, `mc_step` and `acceptance_rate` populated;
          `placeholder_input` and `wavefunction_value` are left as None.
        """
        num_samples = hparams.batch_size
        num_sites = hparams.num_sites

        sample_configs = graph_builders.get_configs(
            shared_resources, num_samples, num_sites)
        sampling_step, acceptance = graph_builders.get_monte_carlo_sampling(
            shared_resources, sample_configs, wavefunction)

        # Collapse the per-configuration local values into a single estimate.
        per_config_values = operator.local_value(wavefunction, sample_configs)
        estimate = tf.reduce_mean(per_config_values)

        return EvalOps(
            value=estimate,
            mc_step=sampling_step,
            acceptance_rate=acceptance,
            placeholder_input=None,
            wavefunction_value=None,
        )
Exemplo n.º 2
0
  def build_opt_ops(
      self,
      wavefunction: wavefunctions.Wavefunction,
      hamiltonian: operators.Operator,
      hparams: tf.contrib.training.HParams,
      shared_resources: Dict[graph_builders.ResourceName, tf.Tensor],
  ) -> NamedTuple:
    """Builds the graph ops for energy minimization with accumulated gradients.

    Gradients of the wavefunction amplitudes, with and without local-energy
    weighting, are accumulated through streaming `tf.metrics` means. The
    gradient handed to the optimizer for each variable is
    `mean(E * grad) - mean(E) * mean(grad)`, built from those running means.

    Args:
      wavefunction: Wavefunction ansatz to optimize.
      hamiltonian: Hamiltonian whose ground state we are solving for.
      hparams: Hyperparameters of the optimization procedure; `batch_size`
        and `num_sites` are read here and the whole object is passed to
        `create_sgd_optimizer`.
      shared_resources: Resources sharable among different modules.

    Returns:
      TrainOpsTraditional holding the ops needed to run a training epoch:
      accumulating, applying and resetting gradients, the sampling step and
      its acceptance rate, the running mean energy (as `metrics`), the epoch
      counter increment and the wavefunction norm update.
    """
    num_samples = hparams.batch_size
    num_sites = hparams.num_sites

    sample_configs = graph_builders.get_configs(
        shared_resources, num_samples, num_sites)
    sampling_step, acceptance = graph_builders.get_monte_carlo_sampling(
        shared_resources, sample_configs, wavefunction)
    trainable_vars = wavefunction.get_trainable_variables()

    amplitudes = wavefunction(sample_configs)
    frozen_amplitudes = tf.stop_gradient(amplitudes)
    norm_update = wavefunction.update_norm(amplitudes)
    # The local energy only weights the gradients below, so no gradient is
    # allowed to flow through it.
    e_local = tf.stop_gradient(
        hamiltonian.local_value(wavefunction, sample_configs, amplitudes))

    # psi / stop_gradient(psi) evaluates to one, but differentiating it gives
    # grad(psi) / psi, i.e. the gradient of log(psi).
    log_derivs = []
    for raw_grad in tf.gradients(
        amplitudes / frozen_amplitudes, trainable_vars):
      log_derivs.append(tf.convert_to_tensor(raw_grad))

    # Same quantity with every sample weighted by its local energy.
    weighted_derivs = []
    for raw_grad in tf.gradients(
        amplitudes / frozen_amplitudes * e_local, trainable_vars):
      weighted_derivs.append(tf.convert_to_tensor(raw_grad))

    # Each streaming metric is a (running_mean, update_op) pair.
    deriv_metrics = [tf.metrics.mean_tensor(deriv) for deriv in log_derivs]
    weighted_metrics = [
        tf.metrics.mean_tensor(deriv) for deriv in weighted_derivs
    ]

    energy_mean, energy_update = tf.metrics.mean(e_local)
    # Transpose the lists of pairs into (all means, all update ops).
    deriv_means, deriv_updates = map(list, zip(*deriv_metrics))
    weighted_means, weighted_updates = map(list, zip(*weighted_metrics))

    energy_gradients = []
    for deriv_mean, weighted_mean in zip(deriv_means, weighted_means):
      energy_gradients.append(weighted_mean - energy_mean * deriv_mean)

    sgd = create_sgd_optimizer(hparams)
    apply_op = sgd.apply_gradients(list(zip(energy_gradients, trainable_vars)))
    # Re-initializes every local variable registered in the graph at this
    # point, which includes the accumulators of the metrics created above.
    reset_op = tf.variables_initializer(tf.local_variables())

    accumulate_op = tf.group(
        [energy_update, *deriv_updates, *weighted_updates])

    epoch_counter = graph_builders.get_or_create_num_epochs()
    bump_epoch = tf.assign_add(epoch_counter, 1)

    return TrainOpsTraditional(
        accumulate_gradients=accumulate_op,
        apply_gradients=apply_op,
        reset_gradients=reset_op,
        mc_step=sampling_step,
        acc_rate=acceptance,
        metrics=energy_mean,
        epoch_increment=bump_epoch,
        update_wf_norm=norm_update,
    )