import os
import tensorflow as tf
# Assumed locations of env_init and the ScriptedPolicy wrapper inside the
# same rl_tools package that provides the action scripts imported below:
from rl_tools.tf_env import env_init
from rl_tools.tf_env import policy as plc

def avg_reward(action_script):
    """Run one episode of the global `env` under a scripted policy and
    return the final-step reward averaged over the batch."""
    policy = plc.ScriptedPolicy(env.time_step_spec(), action_script)
    time_step = env.reset()
    policy_state = policy.get_initial_state(env.batch_size)
    while not time_step.is_last()[0]:
        action_step = policy.action(time_step, policy_state)
        policy_state = action_step.state
        time_step = env.step(action_step.action)
    return time_step.reward.numpy().mean()
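
# Usage sketch for avg_reward: build an env first (e.g. with one of the
# env_init blocks below), pick an action script, and score it; the script
# name here is illustrative:
if 0:
    from rl_tools.action_script import gkp_qec_autonomous_sBs_2round as action_script
    print('average reward: %.4f' % avg_reward(action_script))
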
### load a trained policy saved as a TF SavedModel
if 0:
    # root_dir and policy_dir are defined elsewhere in the original script
    policy = tf.compat.v2.saved_model.load(os.path.join(root_dir, policy_dir))
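
# Sketch: a policy restored with saved_model.load exposes the same
# action / get_initial_state interface as ScriptedPolicy, so the rollout
# loop from avg_reward carries over unchanged (assumes env exists):
if 0:
    time_step = env.reset()
    policy_state = policy.get_initial_state(env.batch_size)
    while not time_step.is_last()[0]:
        action_step = policy.action(time_step, policy_state)
        policy_state = action_step.state
        time_step = env.step(action_step.action)
    print('final reward:', time_step.reward.numpy().mean())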

### simulate GKP stabilization with SBS
if 0:
    env = env_init(control_circuit='gkp_qec_autonomous_sBs_osc_qb',
                   reward_kwargs={'reward_mode': 'zero'},
                   init='vac',
                   H=1,
                   T=2,
                   attn_step=1,
                   batch_size=1,
                   episode_length=10,
                   encoding='square')

    from rl_tools.action_script import gkp_qec_autonomous_sBs_2round as action_script
    policy = plc.ScriptedPolicy(env.time_step_spec(), action_script)
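
    # With reward_mode 'zero' the reward is uninformative, so this call just
    # simulates the episode of sBs stabilization via the helper defined above:
    avg_reward(action_script)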

### simulate GKP state preparation with ECDC
if 0:
    env = env_init(control_circuit='ECD_control',
                   reward_kwargs=dict(reward_mode='zero'),
                   init='vac',
                   T=11,
                   batch_size=1,
                   N=100,
                   episode_length=11)

    from rl_tools.action_script import ECD_control_residuals_GKP_plusX_hex as action_script
    policy = plc.ScriptedPolicy(env.time_step_spec(), action_script)
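
    # Sketch: peek at the first scripted action to sanity-check shapes
    # (batch_size=1, so each action tensor carries a leading batch dim of 1):
    time_step = env.reset()
    policy_state = policy.get_initial_state(env.batch_size)
    action_step = policy.action(time_step, policy_state)
    print(tf.nest.map_structure(lambda a: a.shape, action_step.action))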

# N=40