def test_discrete_sac(observation_shape, action_size, q_func_factory, scalers):
    """Exercise DiscreteSAC through the shared algorithm test helpers.

    ``scalers`` is unpacked as an (observation scaler, reward scaler) pair;
    both are passed to the DiscreteSAC constructor. The generic tester is
    asked to also check policy and Q-function copying.
    """
    obs_scaler, rew_scaler = scalers
    algo = DiscreteSAC(
        q_func_factory=q_func_factory,
        scaler=obs_scaler,
        reward_scaler=rew_scaler,
    )
    algo_tester(
        algo,
        observation_shape,
        test_policy_copy=True,
        test_q_function_copy=True,
    )
    algo_update_tester(algo, observation_shape, action_size, discrete=True)
def test_discrete_sac_performance(q_func_factory):
    """Run DiscreteSAC through ``algo_cartpole_tester`` with 3 trials.

    The IQN and FQF Q-function factories are skipped as too expensive.
    """
    # Both "iqn" and "fqf" are skipped, so the reason names both of them.
    if q_func_factory in ("iqn", "fqf"):
        pytest.skip("IQN/FQF is computationally expensive")
    sac = DiscreteSAC(q_func_factory=q_func_factory)
    algo_cartpole_tester(sac, n_trials=3)
def test_discrete_sac(observation_shape, action_size, q_func_factory, scaler):
    """Exercise DiscreteSAC, configured with one observation scaler.

    NOTE(review): this ``def`` reuses the name ``test_discrete_sac`` that is
    defined earlier in this module. The later definition rebinds the name,
    so pytest collects only this version. The earlier variant, which also
    passes a reward scaler and checks policy/Q-function copying, never runs.
    Confirm which one is intended and remove or rename the other.
    """
    algo = DiscreteSAC(q_func_factory=q_func_factory, scaler=scaler)
    algo_tester(algo, observation_shape)
    algo_update_tester(
        algo,
        observation_shape,
        action_size,
        discrete=True,
    )
def test_discrete_sac_performance(q_func_type):
    """Run DiscreteSAC through ``algo_cartpole_tester`` with 3 trials.

    The parameter keeps the name ``q_func_type`` so that whatever supplies it
    (presumably a parametrize decorator not visible here) still binds. The
    value is forwarded to DiscreteSAC as ``q_func_factory``, the keyword used
    by every other DiscreteSAC construction in this module.

    NOTE(review): this ``def`` reuses the name
    ``test_discrete_sac_performance`` defined earlier in this module. Only
    this later definition is collected by pytest, so confirm which one is
    intended.
    """
    # Both "iqn" and "fqf" are skipped, so the reason names both of them.
    if q_func_type in ("iqn", "fqf"):
        pytest.skip("IQN/FQF is computationally expensive")
    sac = DiscreteSAC(q_func_factory=q_func_type)
    algo_cartpole_tester(sac, n_trials=3)