PPO.train_eval(
     root_dir=sim_dir,
     random_seed=seed,
     num_epochs=4000,
     # Params for train
     normalize_observations=True,
     normalize_rewards=False,
     discount_factor=1.0,
     lr=1e-4,
     lr_schedule=lambda x: 1e-3 if x < 500 else 1e-4,
     num_policy_updates=20,
     initial_adaptive_kl_beta=0.0,
     kl_cutoff_factor=0,
     importance_ratio_clipping=0.1,
     value_pred_loss_coef=0.005,
     # Params for log, eval, save
     eval_interval=50,
     save_interval=50,
     checkpoint_interval=10000,
     summary_interval=10000,
     # Params for data collection
     train_batch_size=train_batch_size,
     eval_batch_size=eval_batch_size,
     collect_driver=collect_driver,
     eval_driver=eval_driver,
     replay_buffer_capacity=7000,
     # Policy and value networks
     ActorNet=actor_distribution_network_gkp.ActorDistributionNetworkGKP,
     actor_fc_layers=(),
     value_fc_layers=(),
     use_rnn=True,
     actor_lstm_size=(12, ),
     value_lstm_size=(12, ))
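
The lr_schedule argument in example #1 is an ordinary Python callable; judging from the num_epochs setting it presumably receives the training epoch index and returns the learning rate to use at that point. A minimal sketch of the same step schedule (the epoch-index interpretation is an assumption, not stated in the examples):

# Step schedule: 1e-3 for the first 500 epochs, then drop to 1e-4.
lr_schedule = lambda epoch: 1e-3 if epoch < 500 else 1e-4

assert lr_schedule(0) == 1e-3      # initial phase
assert lr_schedule(499) == 1e-3    # last epoch before the drop
assert lr_schedule(500) == 1e-4    # reduced rate for the rest of training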
Code example #2
PPO.train_eval(
    root_dir=root_dir,
    random_seed=0,
    num_epochs=300,
    # Params for train
    normalize_observations=True,
    normalize_rewards=False,
    discount_factor=1.0,
    lr=1e-3,
    lr_schedule=None,
    num_policy_updates=20,
    initial_adaptive_kl_beta=0.0,
    kl_cutoff_factor=0,
    importance_ratio_clipping=0.1,
    value_pred_loss_coef=0.005,
    gradient_clipping=1.0,
    entropy_regularization=0,
    # Params for log, eval, save
    eval_interval=10,
    save_interval=10,
    checkpoint_interval=10000,
    summary_interval=10,
    # Params for data collection
    train_batch_size=train_batch_size,
    eval_batch_size=eval_batch_size,
    collect_driver=collect_driver,
    eval_driver=eval_driver,
    replay_buffer_capacity=15000,
    # Policy and value networks
    ActorNet=actor_distribution_network.ActorDistributionNetwork,
    zero_means_kernel_initializer=True,
    actor_fc_layers=(),
    value_fc_layers=(),
    use_rnn=False,
    actor_lstm_size=(12, ),
    value_lstm_size=(12, ))
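
In all three examples, importance_ratio_clipping=0.1 is the clipping parameter ε of PPO's clipped surrogate objective. A small NumPy sketch of that objective, purely for illustration (this helper is not part of the PPO module used above):

import numpy as np

def clipped_surrogate(ratio, advantage, epsilon=0.1):
    # PPO objective: mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A),
    # where r is the new/old policy probability ratio and A the advantage.
    clipped = np.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)
    return np.mean(np.minimum(ratio * advantage, clipped * advantage))

# With epsilon=0.1, ratios outside [0.9, 1.1] stop contributing extra objective.
print(clipped_surrogate(np.array([0.8, 1.0, 1.3]), np.array([1.0, 1.0, 1.0])))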
Code example #3
PPO.train_eval(
    root_dir=root_dir,
    random_seed=0,
    # Params for collect
    num_iterations=100000,
    train_batch_size=100,
    replay_buffer_capacity=15000,
    # Params for train
    normalize_observations=True,
    normalize_rewards=False,
    discount_factor=1.0,
    lr=3e-4,
    lr_schedule=None,
    num_policy_epochs=20,
    initial_adaptive_kl_beta=0.0,
    kl_cutoff_factor=0,
    importance_ratio_clipping=0.1,
    value_pred_loss_coef=0.005,
    # Params for log, eval, save
    eval_batch_size=1000,
    eval_interval=100,
    save_interval=500,
    checkpoint_interval=5000,
    summary_interval=100,
    # Params for environment
    simulate='Alec_universal_gate_set',
    horizon=1,
    clock_period=6,
    attention_step=1,
    train_episode_length=lambda x: 6,
    eval_episode_length=6,
    init_state='vac',
    reward_kwargs=reward_kwargs,
    encoding='square',
    action_script='Alec_universal_gate_set_6round',
    to_learn={
        'alpha': True,
        'beta': True,
        'phi': True
    },
    # Policy and value networks
    ActorNet=actor_distribution_network_gkp.ActorDistributionNetworkGKP,
    actor_fc_layers=(),
    value_fc_layers=(),
    use_rnn=True,
    actor_lstm_size=(12, ),
    value_lstm_size=(12, ),
    **kwargs)
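
Example #3 also shows two conventions worth noting: train_episode_length is a callable (here constant), while eval_episode_length is a plain integer, and to_learn is an ordinary dict that appears to flag which action parameters the agent optimizes. A sketch of building such values before the call; the curriculum variant and the epoch-index interpretation of the callable's argument are assumptions, not taken from the examples:

# Constant episode length, as in example #3.
train_episode_length = lambda epoch: 6

# Hypothetical alternative: lengthen episodes as training progresses,
# assuming the callable receives the epoch/iteration index.
curriculum_episode_length = lambda epoch: min(6, 1 + epoch // 1000)

# Flags selecting which parameters are learned (the rest presumably
# come from the fixed action_script).
to_learn = {'alpha': True, 'beta': True, 'phi': True}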