def build_opt_ops(
    self,
    wavefunction: wavefunctions.Wavefunction,
    hamiltonian: operators.Operator,
    hparams: tf.contrib.training.HParams,
    shared_resources: Dict[graph_builders.ResourceName, tf.Tensor],
) -> NamedTuple:
  """Adds wavefunction optimization ops to the graph.

  Args:
    wavefunction: Wavefunction ansatz to optimize.
    hamiltonian: Hamiltonian whose ground state we are solving for.
    hparams: Hyperparameters of the optimization procedure.
    shared_resources: Resources sharable among different modules.

  Returns:
    NamedTuple holding tensors needed to run a training epoch.
  """
  batch_size = hparams.batch_size
  n_sites = hparams.num_sites
  configs = graph_builders.get_configs(shared_resources, batch_size, n_sites)
  mc_step, acc_rate = graph_builders.get_monte_carlo_sampling(
      shared_resources, configs, wavefunction)
  opt_v = wavefunction.get_trainable_variables()
  wf_omega = copy.deepcopy(wavefunction)  # Builds the supervisor wavefunction.

  # Imaginary-time-evolution (ITE) target: (1 - beta * H) |psi_omega>.
  beta = tf.constant(hparams.time_evolution_beta, dtype=tf.float32)
  beta2 = tf.constant(hparams.time_evolution_beta ** 2, dtype=tf.float32)
  psi_omega = wf_omega(configs)
  h_psi_omega = hamiltonian.apply_in_place(wf_omega, configs, psi_omega)
  h_psi_omega_beta = h_psi_omega * beta
  ite_psi_omega = psi_omega - h_psi_omega_beta

  # Norm of the ITE target estimated from local energies:
  # sqrt(1 - 2 * beta * <E_loc> + beta^2 * <E_loc^2>).
  local_energy = h_psi_omega / psi_omega
  energy_expectation = tf.reduce_mean(local_energy)
  squared_energy_expectation = tf.reduce_mean(tf.square(local_energy))
  ite_normalization = tf.sqrt(
      1. - 2. * beta * energy_expectation + squared_energy_expectation * beta2)
  ite_normalization_var = tf.get_variable(
      name='ite_normalization',
      initializer=tf.ones(shape=[]),
      dtype=tf.float32,
      trainable=False)

  # Exponential moving averages smooth the normalization and energy
  # estimates across epochs.
  num_epochs = graph_builders.get_or_create_num_epochs()
  exp_moving_average = tf.train.ExponentialMovingAverage(0.999, num_epochs)
  accumulate_norm = exp_moving_average.apply([ite_normalization])
  normalization_value = exp_moving_average.average(ite_normalization)
  accumulate_energy = exp_moving_average.apply([energy_expectation])
  energy_value = exp_moving_average.average(energy_expectation)
  update_normalization = tf.assign(ite_normalization_var, normalization_value)

  # Relative L2 distance between the ansatz and the normalized ITE target.
  psi = wavefunction(configs)
  loss = tf.reduce_mean(
      tf.squared_difference(psi, ite_psi_omega / ite_normalization_var)
      / tf.square(tf.stop_gradient(psi)))

  optimizer = create_sgd_optimizer(hparams)
  train_step = optimizer.minimize(loss, var_list=opt_v)
  train_step = tf.group([train_step, accumulate_norm, accumulate_energy])
  update_supervisor = wavefunctions.module_transfer_ops(
      wavefunction, wf_omega)
  epoch_increment = tf.assign_add(num_epochs, 1)
  update_wf_norm = wavefunction.update_norm(psi)

  train_ops = TrainOpsSWO(
      train_step=train_step,
      accumulate_gradients=None,
      apply_gradients=None,
      reset_gradients=None,
      mc_step=mc_step,
      acc_rate=acc_rate,
      metrics=loss,
      energy=energy_value,
      update_supervisor=update_supervisor,
      update_normalization=update_normalization,
      epoch_increment=epoch_increment,
      update_wf_norm=update_wf_norm,
  )
  return train_ops
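

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the graph-building code above): a
# minimal NumPy mock-up of the imaginary-time-evolution target and its norm
# estimate, assuming real amplitudes sampled from |psi_omega|^2. The function
# name and arguments below are hypothetical; they simply restate the math
# used in `build_opt_ops`: the target is (1 - beta * H)|psi_omega>, and its
# norm is estimated as sqrt(1 - 2 * beta * <E_loc> + beta^2 * <E_loc^2>).
def _ite_target_sketch(psi_omega, h_psi_omega, beta):
  """Returns the unnormalized ITE target and its estimated norm."""
  import numpy as np
  ite_psi_omega = psi_omega - beta * h_psi_omega  # (1 - beta * H)|psi_omega>
  local_energy = h_psi_omega / psi_omega          # E_loc per configuration.
  norm = np.sqrt(
      1. - 2. * beta * np.mean(local_energy)
      + beta ** 2 * np.mean(local_energy ** 2))
  return ite_psi_omega, norm
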
def build_opt_ops(
    self,
    wavefunction: wavefunctions.Wavefunction,
    hamiltonian: operators.Operator,
    hparams: tf.contrib.training.HParams,
    shared_resources: Dict[graph_builders.ResourceName, tf.Tensor],
) -> NamedTuple:
  """Adds wavefunction optimization ops to the graph.

  Args:
    wavefunction: Wavefunction ansatz to optimize.
    hamiltonian: Hamiltonian whose ground state we are solving for.
    hparams: Hyperparameters of the optimization procedure.
    shared_resources: Resources sharable among different modules.

  Returns:
    NamedTuple holding tensors needed to run a training epoch.
  """
  batch_size = hparams.batch_size
  n_sites = hparams.num_sites
  configs = graph_builders.get_configs(shared_resources, batch_size, n_sites)
  mc_step, acc_rate = graph_builders.get_monte_carlo_sampling(
      shared_resources, configs, wavefunction)
  opt_v = wavefunction.get_trainable_variables()
  wf_omega = copy.deepcopy(wavefunction)  # Builds the supervisor wavefunction.

  # Imaginary-time-evolution (ITE) target: (1 - beta * H) |psi_omega>.
  psi = wavefunction(configs)
  psi_omega = wf_omega(configs)
  beta = tf.constant(hparams.time_evolution_beta, dtype=tf.float32)
  h_psi_omega = hamiltonian.apply_in_place(wf_omega, configs, psi_omega)
  h_psi_omega_beta = beta * h_psi_omega
  ite_psi_omega = psi_omega - h_psi_omega_beta
  local_energy = h_psi_omega / psi_omega
  update_wf_norm = wavefunction.update_norm(psi)

  # Differentiating psi / stop_gradient(psi) yields d(log psi)/d(theta)
  # summed over the batch; the ratio-weighted version supplies the second
  # term of the overlap gradient.
  psi_no_grad = tf.stop_gradient(psi)
  ratio = tf.stop_gradient(ite_psi_omega / psi)
  log_psi_raw_grads = tf.gradients(psi / psi_no_grad, opt_v)
  log_psi_grads = [tf.convert_to_tensor(grad) for grad in log_psi_raw_grads]
  ratio_log_psi_raw_grads = tf.gradients(ratio * psi / psi_no_grad, opt_v)
  ratio_log_psi_grads = [
      tf.convert_to_tensor(grad) for grad in ratio_log_psi_raw_grads
  ]

  # tf.metrics.* return (running_mean, update_op) pairs, which accumulate
  # statistics over many Monte Carlo batches before gradients are applied.
  log_grads = [
      tf.metrics.mean_tensor(log_psi_grad) for log_psi_grad in log_psi_grads
  ]
  ratio_grads = [
      tf.metrics.mean_tensor(grad) for grad in ratio_log_psi_grads
  ]
  mean_energy, accumulate_energy = tf.metrics.mean(local_energy)
  mean_ratio, accumulate_ratio = tf.metrics.mean(ratio)
  mean_log_grads, accumulate_log_grads = list(map(list, zip(*log_grads)))
  mean_ratio_grads, accumulate_ratio_grads = list(
      map(list, zip(*ratio_grads)))

  # Gradient of the negative log overlap with the ITE target:
  # <d log psi> - <ratio * d log psi> / <ratio>.
  grad_pairs = zip(mean_log_grads, mean_ratio_grads)
  overlap_gradients = [
      grad - scaled_grad / mean_ratio for grad, scaled_grad in grad_pairs
  ]
  grads_and_vars = list(zip(overlap_gradients, opt_v))
  optimizer = create_sgd_optimizer(hparams)
  apply_gradients = optimizer.apply_gradients(grads_and_vars)

  all_updates = [accumulate_ratio, accumulate_energy]
  all_updates += accumulate_log_grads + accumulate_ratio_grads
  accumulate_gradients = tf.group(all_updates)
  update_network = wavefunctions.module_transfer_ops(wavefunction, wf_omega)
  # tf.metrics accumulators live in local variables; re-initializing them
  # resets the running means between optimization steps.
  clear_gradients = tf.variables_initializer(tf.local_variables())

  num_epochs = graph_builders.get_or_create_num_epochs()
  epoch_increment = tf.assign_add(num_epochs, 1)

  train_ops = TrainOpsSWO(
      train_step=None,
      accumulate_gradients=accumulate_gradients,
      apply_gradients=apply_gradients,
      reset_gradients=clear_gradients,
      mc_step=mc_step,
      acc_rate=acc_rate,
      metrics=None,
      energy=mean_energy,
      update_supervisor=update_network,
      update_normalization=None,
      epoch_increment=epoch_increment,
      update_wf_norm=update_wf_norm,
  )
  return train_ops
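

# ---------------------------------------------------------------------------
# Illustrative sketch only (not used by the graph-building code above): how
# the accumulated means combine into the overlap gradient. With
# R = ite_psi_omega / psi and g = d(log psi)/d(theta), the estimator applied
# in `build_opt_ops` is <g> - <R * g> / <R>, the negative gradient of the
# log overlap between psi and the ITE target (up to an overall batch-size
# factor absorbed by the learning rate). The function name and arguments
# below are hypothetical.
def _overlap_gradient_sketch(log_psi_grads, ratios):
  """Per-parameter overlap gradient from per-configuration samples.

  Args:
    log_psi_grads: float array of shape [batch, num_params] holding
      d(log psi)/d(theta) on sampled configurations.
    ratios: float array of shape [batch] holding ite_psi_omega / psi.

  Returns:
    Float array of shape [num_params] with the overlap gradient estimate.
  """
  import numpy as np
  mean_log_grad = np.mean(log_psi_grads, axis=0)
  mean_ratio_grad = np.mean(ratios[:, None] * log_psi_grads, axis=0)
  mean_ratio = np.mean(ratios)
  return mean_log_grad - mean_ratio_grad / mean_ratio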