# Example #1 (fragment): train `model` with variational Monte Carlo on a 2D
# Ising operator using an autoregressive sampler, then save the final weights.
# NOTE(review): `model`, `conditional_log_probs_model`, `hilbert_state_shape`,
# `batch_size`, `steps_per_epoch` and `loss_for_energy_minimization` are
# defined outside this snippet.
optimizer = Adam(lr=0.001, beta_1=0.9, beta_2=0.999)
model.compile(optimizer=optimizer, loss=loss_for_energy_minimization)
model.summary()
# Ising operator with field parameter h=3.0 and open boundaries (pbc=False).
operator = Ising(h=3.0, hilbert_state_shape=hilbert_state_shape, pbc=False)
sampler = AutoregressiveSampler(conditional_log_probs_model, batch_size)
monte_carlo_generator = VariationalMonteCarlo(model, operator, sampler)

# Separate generator with a 16x larger batch, handed to the stats callbacks as
# the validation generator.
validation_sampler = AutoregressiveSampler(conditional_log_probs_model,
                                           batch_size * 16)
validation_generator = VariationalMonteCarlo(model, operator,
                                             validation_sampler)

# TensorBoard logger; the log directory name embeds the batch size.
tensorboard = TensorBoardWithGeneratorValidationData(
    log_dir='tensorboard_logs/invariant_example_2d_monte_carlo_batch_%s_run_3'
    % batch_size,
    generator=monte_carlo_generator,
    update_freq=1,
    histogram_freq=1,
    batch_size=batch_size,
    write_output=False)
# Stats callbacks compare against this reference ground-state energy.
callbacks = default_wave_function_stats_callbacks_factory(
    monte_carlo_generator,
    validation_generator=validation_generator,
    true_ground_state_energy=-50.18662388277671) + [tensorboard]
# NOTE(review): the VMC object itself is passed to fit_generator here, while
# other snippets pass `.to_generator()` — confirm which form the installed
# FlowKet version expects.
# max_queue_size=0 / workers=0: presumably keeps sampling on the main thread so
# samples are drawn from the current weights — confirm.
model.fit_generator(monte_carlo_generator,
                    steps_per_epoch=steps_per_epoch,
                    epochs=2,
                    callbacks=callbacks,
                    max_queue_size=0,
                    workers=0)
model.save_weights('final_2d_ising_fcnn.h5')
# Example #2
# 0
              loss=loss_for_energy_minimization,
              metrics=optimizer.metrics)
model.summary()
# Example #2 (continued): 1D Heisenberg operator with periodic boundaries,
# sampled by Metropolis-Hastings; `model`, `hilbert_state_shape`, `batch_size`
# and `steps_per_epoch` are defined outside this snippet.
operator = Heisenberg(hilbert_state_shape=hilbert_state_shape, pbc=True)
# 20 Markov chains. `unused_sampels` is set to the number of sites
# (numpy.prod(hilbert_state_shape)); presumably the number of samples each
# chain discards before its samples are used — confirm.
# NOTE(review): the keyword is spelled `unused_sampels`; it must match the
# sampler's own parameter name, so do not "fix" the spelling here alone.
sampler = MetropolisHastingsHamiltonian(
    model,
    batch_size,
    operator,
    num_of_chains=20,
    unused_sampels=numpy.prod(hilbert_state_shape))
variational_monte_carlo = VariationalMonteCarlo(model, operator, sampler)

# TensorBoard logger; the "rbm_with_sr" directory name suggests an RBM trained
# with stochastic reconfiguration set up before this snippet — confirm.
tensorboard = TensorBoardWithGeneratorValidationData(
    log_dir='tensorboard_logs/rbm_with_sr_run_6',
    generator=variational_monte_carlo,
    update_freq=1,
    histogram_freq=1,
    batch_size=batch_size,
    write_output=False)
# Default wave-function stats (compared against the reference energy below),
# plus MCMC-specific stats and TensorBoard logging.
callbacks = default_wave_function_stats_callbacks_factory(
    variational_monte_carlo, true_ground_state_energy=-35.6175461195) + [
        MCMCStats(variational_monte_carlo), tensorboard
    ]
# max_queue_size=0 / workers=0: presumably keeps sampling on the main thread so
# samples come from the current weights — confirm.
model.fit_generator(variational_monte_carlo.to_generator(),
                    steps_per_epoch=steps_per_epoch,
                    epochs=1,
                    callbacks=callbacks,
                    max_queue_size=0,
                    workers=0)
model.save_weights('final_1d_heisenberg.h5')
# Example #3
# 0
# File: train.py, project: xusky69/FlowKet
def train(operator, config, true_ground_state_energy=None):
    """Train a model for ``operator`` in stages, then evaluate a symmetrized copy.

    Training stages come from zipping ``config.batch_size``,
    ``config.num_epoch`` and ``config.learning_rate``; each stage trains up to
    the cumulative epoch count. Weights are restored from
    ``<config.output_path>/model.h5`` when present, and stages already covered
    by the restored epoch count are skipped. Rank 0 writes the config,
    TensorBoard logs, periodic checkpoints and a ``stage_<k>.h5`` snapshot
    after each stage.

    After training, the model is wrapped with 2D open-boundary and up/down
    invariance layers and evaluated on roughly 2**15 samples (per worker when
    Horovod is used); rank 0 prints the results.

    Args:
        operator: operator passed to the Monte Carlo generators.
        config: run configuration (reads use_horovod, output_path,
            learning_rate, batch_size, num_epoch, mini_batch_size,
            steps_per_epoch, validation_period, hilbert_state_shape).
        true_ground_state_energy: optional reference energy forwarded to the
            stats callbacks.

    Raises:
        ValueError: if a stage's (per-worker) batch size is neither smaller
            than nor a multiple of ``config.mini_batch_size``.
    """
    if config.use_horovod:
        init_horovod()
    to_valid_stages_config(config)
    # Only rank 0 (or the single process) writes files and logs progress.
    is_rank_0 = (not config.use_horovod) or hvd.rank() == 0
    if is_rank_0:
        save_config(config)
    model, sampler = build_model(operator, config)
    optimizer = compile_model(model, config.learning_rate[0],
                              config.use_horovod)
    checkpoint_path = os.path.join(config.output_path, 'model.h5')
    initial_epoch = load_weights_if_exist(model, checkpoint_path)

    total_epochs = 0

    for idx, (batch_size, num_epoch, learning_rate) in enumerate(
            zip(config.batch_size, config.num_epoch, config.learning_rate)):
        total_epochs += num_epoch
        # Skip stages that the restored checkpoint has already completed.
        if total_epochs <= initial_epoch:
            continue
        vmc_cls = VariationalMonteCarlo
        if config.use_horovod:
            # Split the stage's batch across workers.
            batch_size = int(math.ceil(batch_size / hvd.size()))
            vmc_cls = HorovodVariationalMonteCarlo

        # Validation uses a larger batch (8x, capped at 2**15 samples).
        validation_sampler = sampler.copy_with_new_batch_size(
            min(batch_size * 8, 2**15), mini_batch_size=config.mini_batch_size)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not (batch_size < config.mini_batch_size
                or batch_size % config.mini_batch_size == 0):
            raise ValueError(
                'batch_size (%s) must be smaller than or a multiple of '
                'mini_batch_size (%s)' % (batch_size, config.mini_batch_size))
        sampler = sampler.copy_with_new_batch_size(batch_size,
                                                   config.mini_batch_size)
        variational_monte_carlo = vmc_cls(
            model, operator, sampler, mini_batch_size=config.mini_batch_size)
        validation_generator = vmc_cls(
            model,
            operator,
            validation_sampler,
            wave_function_evaluation_batch_size=config.mini_batch_size)
        # Keep the optimizer's update cadence in sync with the generator.
        optimizer.set_update_params_frequency(
            variational_monte_carlo.update_params_frequency)
        # Apply this stage's learning rate.
        K.set_value(optimizer.lr, learning_rate)

        callbacks = default_wave_function_stats_callbacks_factory(
            variational_monte_carlo,
            validation_generator=validation_generator,
            log_in_batch_or_epoch=False,
            true_ground_state_energy=true_ground_state_energy,
            validation_period=config.validation_period)

        if config.use_horovod:
            # Broadcast rank 0's weights first; average metrics after the
            # stats callbacks have produced them.
            callbacks = [
                hvd.callbacks.BroadcastGlobalVariablesCallback(0),
            ] + callbacks + [hvd.callbacks.MetricAverageCallback()]
        if is_rank_0:
            tensorboard = TensorBoardWithGeneratorValidationData(
                log_dir=config.output_path,
                generator=variational_monte_carlo,
                update_freq='epoch',
                histogram_freq=0,
                batch_size=batch_size,
                write_output=False,
                write_graph=False)
            callbacks += [
                tensorboard,
                CheckpointByTime(checkpoint_path, save_weights_only=True)
            ]
            verbose = 1
        else:
            verbose = 0
        # One optimizer step spans `update_params_frequency` generator steps,
        # so scale steps_per_epoch accordingly.
        model.fit_generator(variational_monte_carlo.to_generator(),
                            steps_per_epoch=config.steps_per_epoch *
                            variational_monte_carlo.update_params_frequency,
                            epochs=total_epochs,
                            callbacks=callbacks,
                            max_queue_size=0,
                            workers=0,
                            initial_epoch=initial_epoch,
                            verbose=verbose)
        if is_rank_0:
            model.save_weights(
                os.path.join(config.output_path, 'stage_%s.h5' % (idx + 1)))
        initial_epoch = total_epochs

    # Symmetrize the trained model: 2D open-boundary invariants first, then
    # up/down invariance on top.
    evaluation_inputs = Input(shape=config.hilbert_state_shape, dtype='int8')
    obc_input = Input(shape=config.hilbert_state_shape,
                      dtype=evaluation_inputs.dtype)
    invariant_model = make_2d_obc_invariants(obc_input, model)
    invariant_model = make_up_down_invariant(evaluation_inputs,
                                             invariant_model)
    # The invariant model presumably evaluates several transformed copies of
    # each configuration, so shrink the batch by 16. Clamp to 1 so a small
    # mini-batch cannot yield a zero batch / division by zero.
    mini_batch_size = max(1, config.mini_batch_size // 16)

    # Use the reduced size for sampling and evaluation too (previously only the
    # step count used it, giving ~16x the intended 2**15 samples and 16x the
    # per-batch memory).
    sampler = sampler.copy_with_new_batch_size(mini_batch_size)

    vmc_cls = VariationalMonteCarlo
    if config.use_horovod:
        vmc_cls = HorovodVariationalMonteCarlo

    variational_monte_carlo = vmc_cls(invariant_model,
                                      operator,
                                      sampler,
                                      mini_batch_size=mini_batch_size)
    callbacks = default_wave_function_stats_callbacks_factory(
        variational_monte_carlo,
        log_in_batch_or_epoch=False,
        true_ground_state_energy=true_ground_state_energy)
    # Keep only the first four stats callbacks. Append the Horovod averaging
    # callback *after* slicing so the slice cannot drop it.
    callbacks = callbacks[:4]
    if config.use_horovod:
        callbacks = callbacks + [hvd.callbacks.MetricAverageCallback()]
    results = evaluate(variational_monte_carlo,
                       steps=(2**15) // mini_batch_size,
                       callbacks=callbacks,
                       keys_to_progress_bar_mapping={
                           'energy/energy': 'energy',
                           'energy/relative_error': 'relative_error',
                           'energy/local_energy_variance': 'variance'
                       },
                       verbose=is_rank_0)
    if is_rank_0:
        print(results)
# Example #4
# 0
def run(params, batch_size_list, epochs_list):
    """Train a model on a 12x12 open-boundary Ising operator in stages, then evaluate.

    ``batch_size_list`` and ``epochs_list`` are zipped into training stages;
    each stage trains up to the cumulative epoch count. Weights are restored
    from ``<run_name>.h5`` when present, and stages already covered by the
    restored epoch count are skipped. After each stage a
    ``<run_name>_stage_<idx>.h5`` snapshot is written (``idx`` starts at 0).

    After training, the model is wrapped with 2D open-boundary and up/down
    invariance layers and evaluated on roughly 2**15 samples; the results are
    printed.

    Args:
        params: run parameters (reads depth, width, no_weights_normalization,
            learning_rate, gamma, run_index).
        batch_size_list: per-stage batch sizes.
        epochs_list: per-stage epoch counts.

    Raises:
        ValueError: if a stage's batch size is neither smaller than nor a
            multiple of the mini-batch size chosen for ``params.depth``.
    """
    # NOTE(review): the 'weights_normalization_%s' field is filled with
    # `no_weights_normalization`, so it reads True when normalization is
    # *disabled*. Left unchanged because checkpoint_path is derived from
    # run_name, and renaming would orphan existing checkpoints.
    run_name = 'depth_%s_width_%s_weights_normalization_%s_adam_lr_%s_gamma_%s_run_%s' % (
        params.depth, params.width, params.no_weights_normalization,
        params.learning_rate, params.gamma, params.run_index)
    hilbert_state_shape = (12, 12)
    model, sampler = build_model(hilbert_state_shape, params.depth,
                                 params.width,
                                 not params.no_weights_normalization,
                                 params.learning_rate)
    operator = Ising(hilbert_state_shape=hilbert_state_shape,
                     pbc=False,
                     h=params.gamma)
    checkpoint_path = '%s.h5' % run_name
    initial_epoch = restore_run_state(model, checkpoint_path)
    total_epochs = 0
    mini_batch_size = depth_to_max_mini_batch(params.depth)
    true_ground_state_energy = true_ground_state_energy_mapping[params.gamma]

    for idx, (batch_size,
              epochs) in enumerate(zip(batch_size_list, epochs_list)):
        total_epochs += epochs
        # Skip stages that the restored checkpoint has already completed.
        if total_epochs <= initial_epoch:
            continue
        # Validation uses a larger batch (8x, capped at 2**15 samples).
        validation_sampler = sampler.copy_with_new_batch_size(
            min(batch_size * 8, 2**15), mini_batch_size=mini_batch_size)
        # Explicit check instead of `assert`, which is stripped under -O.
        if not (batch_size < mini_batch_size
                or batch_size % mini_batch_size == 0):
            raise ValueError(
                'batch_size (%s) must be smaller than or a multiple of '
                'mini_batch_size (%s)' % (batch_size, mini_batch_size))
        sampler = sampler.copy_with_new_batch_size(batch_size, mini_batch_size)
        variational_monte_carlo = VariationalMonteCarlo(
            model, operator, sampler, mini_batch_size=mini_batch_size)
        validation_generator = VariationalMonteCarlo(
            model,
            operator,
            validation_sampler,
            wave_function_evaluation_batch_size=mini_batch_size)
        # Keep the optimizer's update cadence in sync with the generator.
        model.optimizer.set_update_params_frequency(
            variational_monte_carlo.update_params_frequency)
        tensorboard = TensorBoardWithGeneratorValidationData(
            log_dir='tensorboard_logs/%s' % run_name,
            generator=variational_monte_carlo,
            update_freq='epoch',
            histogram_freq=0,
            batch_size=batch_size,
            write_output=False)
        callbacks = default_wave_function_stats_callbacks_factory(
            variational_monte_carlo,
            validation_generator=validation_generator,
            log_in_batch_or_epoch=False,
            true_ground_state_energy=true_ground_state_energy,
            validation_period=3)
        callbacks += [
            tensorboard,
            CheckpointByTime(checkpoint_path, save_weights_only=True)
        ]
        # One optimizer step spans `update_params_frequency` generator steps,
        # so scale the 100 steps per epoch accordingly.
        model.fit_generator(variational_monte_carlo.to_generator(),
                            steps_per_epoch=100 *
                            variational_monte_carlo.update_params_frequency,
                            epochs=total_epochs,
                            callbacks=callbacks,
                            max_queue_size=0,
                            workers=0,
                            initial_epoch=initial_epoch)
        model.save_weights('%s_stage_%s.h5' % (run_name, idx))
        initial_epoch = total_epochs

    # Symmetrize the trained model: 2D open-boundary invariants first, then
    # up/down invariance on top.
    evaluation_inputs = Input(shape=hilbert_state_shape, dtype='int8')
    obc_input = Input(shape=hilbert_state_shape, dtype=evaluation_inputs.dtype)
    invariant_model = make_2d_obc_invariants(obc_input, model)
    invariant_model = make_up_down_invariant(evaluation_inputs,
                                             invariant_model)
    # The invariant model presumably evaluates several transformed copies of
    # each configuration, so shrink the batch by 16. Clamp to 1 so a small
    # mini-batch cannot yield a zero batch / division by zero.
    mini_batch_size = max(1, mini_batch_size // 16)

    sampler = sampler.copy_with_new_batch_size(mini_batch_size)
    variational_monte_carlo = VariationalMonteCarlo(
        invariant_model, operator, sampler, mini_batch_size=mini_batch_size)
    callbacks = default_wave_function_stats_callbacks_factory(
        variational_monte_carlo,
        log_in_batch_or_epoch=False,
        true_ground_state_energy=true_ground_state_energy)

    # About 2**15 samples in total (steps * mini_batch_size), using only the
    # first four stats callbacks.
    results = evaluate(variational_monte_carlo,
                       steps=(2**15) // mini_batch_size,
                       callbacks=callbacks[:4],
                       keys_to_progress_bar_mapping={
                           'energy/energy': 'energy',
                           'energy/relative_error': 'relative_error',
                           'energy/local_energy_variance': 'variance'
                       })
    print(results)
# Example #5
# 0
# Example #5 (fragment): VMC on an Ising model with a shared wave-function
# cache; `model`, `operator`, `sampler`, `batch_size` and
# `wave_function_cache` are defined outside this snippet.
# Alternative samplers, left disabled:
# sampler = FastAutoregressiveSampler(fast_sampling, buffer_size=5000)
# sampler = FastAutoregressiveSampler(conditional_log_probs)
generator = VariationalMonteCarlo(model,
                                  operator,
                                  sampler,
                                  cache=wave_function_cache)
#### Exact Grads ####
# Alternative (disabled): exact gradients instead of Monte Carlo estimates.
# generator = ExactVariational(model, operator, batch_size, cache=wave_function_cache)
# Save weights to 'ising_fcnn.h5' only when the monitored 'energy' improves
# (Keras ModelCheckpoint with save_best_only=True).
checkpoint = ModelCheckpoint('ising_fcnn.h5',
                             monitor='energy',
                             save_best_only=True,
                             save_weights_only=True)
# tensorboard = TensorBoard(update_freq=1)
# NOTE(review): this passes `monte_carlo_iterator=` where other call sites pass
# `generator=` — confirm the keyword matches the installed version.
tensorboard = TensorBoardWithGeneratorValidationData(
    monte_carlo_iterator=generator,
    update_freq=1,
    histogram_freq=1,
    batch_size=batch_size)
# NOTE(review): the monitored names 'energy' (above) and
# 'relative_energy_error' must match the keys the stats callbacks actually log;
# confirm, otherwise Keras cannot find the metric.
early_stopping = EarlyStopping(monitor='relative_energy_error', min_delta=1e-5)
# The cache itself is also registered as a callback (presumably so it can be
# reset as weights change — confirm). "Energy Again" is a second energy
# estimate computed with the same sampler, operator and cache.
callbacks = [
    LocalEnergyStats(generator, true_ground_state_energy=-457.0416241),
    wave_function_cache,
    LocalStats("Energy Again",
               sampler=sampler,
               operator=operator,
               cache=wave_function_cache),
    SigmaZStats(monte_carlo_generator=generator), checkpoint, tensorboard,
    early_stopping,
    TerminateOnNaN()
]
model.fit_generator(generator,