def algo_tester(algo, observation_shape, imitator=False, action_size=2, state_value=False):
    """Check that ``algo`` delegates its public calls to its impl object.

    A ``DummyImpl`` is installed as ``algo._impl`` and its methods are
    replaced with mocks, so each check verifies both the value returned by
    the algorithm and the arguments passed down to the impl.

    Args:
        algo: algorithm object under test.
        observation_shape: observation shape handed to ``DummyImpl`` and
            ``base_tester``.
        imitator: when true, the ``predict_value`` check is skipped.
        action_size: width of the randomly generated action batches.
        state_value: when true, the ``predict_value`` check is skipped.
    """
    # dummy impl object
    dummy = DummyImpl(observation_shape, action_size)

    base_tester(algo, dummy, observation_shape, action_size)

    algo._impl = dummy

    # save_policy must be forwarded with its arguments untouched
    dummy.save_policy = Mock()
    algo.save_policy("policy.pt", False)
    dummy.save_policy.assert_called_with("policy.pt", False)

    # predict must return whatever impl.predict_best_action returns
    observations = np.random.random((2, 3)).tolist()
    expected_actions = np.random.random((2, action_size)).tolist()
    dummy.predict_best_action = Mock(return_value=expected_actions)
    predicted = algo.predict(observations)
    assert predicted == expected_actions
    dummy.predict_best_action.assert_called_with(observations)

    # predict_value is only exercised for action-value based algorithms
    if not (imitator or state_value):
        actions = np.random.random((2, action_size)).tolist()
        expected_values = np.random.random((2, 3)).tolist()
        dummy.predict_value = Mock(return_value=expected_values)
        values = algo.predict_value(observations, actions)
        assert values == expected_values
        dummy.predict_value.assert_called_with(observations, actions, False)

    # sample_action is optional: algorithms may raise NotImplementedError
    dummy.sample_action = Mock(return_value=expected_actions)
    try:
        sampled = algo.sample_action(observations)
    except NotImplementedError:
        pass
    else:
        assert sampled == expected_actions
        dummy.sample_action.assert_called_with(observations)

    # detach the dummy impl again
    algo._impl = None
def ope_tester(ope, observation_shape, action_size=2):
    """Check that ``ope`` delegates its public calls to its impl object.

    A ``DummyImpl`` is attached to both ``ope`` and the wrapped ``ope._algo``
    and its methods are replaced with mocks, so each check verifies the
    returned value as well as the arguments passed down to the impl.

    Args:
        ope: off-policy evaluation object under test.
        observation_shape: observation shape handed to ``DummyImpl`` and
            ``base_tester``.
        action_size: width of the randomly generated action batches.
    """
    # dummy impl object
    dummy = DummyImpl(observation_shape, action_size)

    base_tester(ope, dummy, observation_shape, action_size)

    # NOTE(review): the other testers in this file assign ``_impl``; here the
    # ``impl`` attribute is assigned instead -- confirm this is intended.
    ope._algo.impl = dummy
    ope.impl = dummy

    # save_policy must be forwarded with its arguments untouched
    dummy.save_policy = Mock()
    ope.save_policy("policy.pt", False)
    dummy.save_policy.assert_called_with("policy.pt", False)

    # predict must return whatever impl.predict_best_action returns
    observations = np.random.random((2, 3)).tolist()
    expected_actions = np.random.random((2, action_size)).tolist()
    dummy.predict_best_action = Mock(return_value=expected_actions)
    predicted = ope.predict(observations)
    assert predicted == expected_actions
    dummy.predict_best_action.assert_called_with(observations)

    # predict_value must be forwarded with a trailing ``False`` argument
    actions = np.random.random((2, action_size)).tolist()
    expected_values = np.random.random((2, 3)).tolist()
    dummy.predict_value = Mock(return_value=expected_values)
    values = ope.predict_value(observations, actions)
    assert values == expected_values
    dummy.predict_value.assert_called_with(observations, actions, False)

    # sample_action is optional: NotImplementedError is tolerated
    dummy.sample_action = Mock(return_value=expected_actions)
    try:
        sampled = ope.sample_action(observations)
    except NotImplementedError:
        pass
    else:
        assert sampled == expected_actions
        dummy.sample_action.assert_called_with(observations)

    # detach the dummy impl again
    ope.impl = None
    ope._algo.impl = None
def dynamics_tester(dynamics, observation_shape, action_size=2):
    """Check that ``dynamics.predict`` delegates to a mocked impl.

    A ``DummyImpl`` is installed as ``dynamics._impl`` and its ``predict``
    is replaced by a mock returning an ``(observation, reward, variance)``
    triple. The test verifies that the variance is dropped by default and
    returned when ``with_variance=True``.

    Args:
        dynamics: dynamics model under test.
        observation_shape: observation shape handed to ``DummyImpl`` and
            ``base_tester``.
        action_size: action size handed to ``DummyImpl`` and ``base_tester``.
    """
    # dummy impl object
    dummy = DummyImpl(observation_shape, action_size)

    base_tester(dynamics, dummy, observation_shape, action_size)

    dynamics._impl = dummy

    # random inputs and the canned triple the mocked impl will return
    observations = np.random.random((2, 3)).tolist()
    actions = np.random.random((2, 3)).tolist()
    expected_next = np.random.random((2, 3)).tolist()
    expected_reward = np.random.random((2, 1)).tolist()
    expected_variance = np.random.random((2, 1)).tolist()
    canned_output = (expected_next, expected_reward, expected_variance)
    dummy.predict = Mock(return_value=canned_output)

    # default call: only observation and reward come back
    next_observations, rewards = dynamics.predict(observations, actions)
    assert next_observations == expected_next
    assert rewards == expected_reward
    dummy.predict.assert_called_with(observations, actions)

    # with_variance=True: the variance is returned as a third element
    _, _, variances = dynamics.predict(
        observations, actions, with_variance=True
    )
    assert variances == expected_variance
def dynamics_tester(
    dynamics, observation_shape, action_size=2, discrete_action=False
):
    """Check the output shapes of ``dynamics.predict`` on a real impl.

    ``base_tester`` is run against a ``DummyImpl`` first. The impl is then
    built through ``dynamics.create_impl`` and ``predict`` is called on a
    batch of two random transitions.

    Args:
        dynamics: dynamics model under test.
        observation_shape: shape of a single observation.
        action_size: width of continuous action vectors, or the exclusive
            upper bound of the sampled integer actions when
            ``discrete_action`` is true.
        discrete_action: when true, actions are sampled as integers.
    """
    # dummy impl object
    dummy = DummyImpl(observation_shape, action_size)

    base_tester(dynamics, dummy, observation_shape, action_size)

    dynamics.create_impl(observation_shape, action_size)

    # batch of two random transitions
    observations = np.random.random((2, *observation_shape))
    actions = (
        np.random.randint(action_size, size=2)
        if discrete_action
        else np.random.random((2, action_size))
    )

    # default call: predicted observation and reward
    next_observations, rewards = dynamics.predict(observations, actions)
    assert next_observations.shape == (2, *observation_shape)
    assert rewards.shape == (2, 1)

    # with_variance=True: a third output is returned
    _, _, variances = dynamics.predict(
        observations, actions, with_variance=True
    )
    assert variances.shape == (2, 1)
def algo_tester(
    algo,
    observation_shape,
    imitator=False,
    action_size=2,
    state_value=False,
    test_policy_copy=False,
    test_q_function_copy=False,
):
    """Check delegation to the impl and, optionally, parameter copying.

    The first part installs a ``DummyImpl`` as ``algo._impl`` and replaces
    its methods with mocks, verifying that ``save_policy``, ``predict``,
    ``predict_value`` and ``sample_action`` forward to the impl. The second
    part builds a real impl through ``algo.create_impl`` and, when asked,
    verifies that ``copy_policy_from`` / ``copy_q_function_from`` make
    ``algo`` produce the same outputs as the source algorithm.

    Args:
        algo: algorithm object under test.
        observation_shape: shape of a single observation.
        imitator: when true, the mocked ``predict_value`` check is skipped.
        action_size: width of continuous action vectors, or the exclusive
            upper bound of the sampled integer actions for non-continuous
            action spaces.
        state_value: when true, the mocked ``predict_value`` check is
            skipped.
        test_policy_copy: when true, check ``copy_policy_from``.
        test_q_function_copy: when true, check ``copy_q_function_from``.
    """
    # dummy impl object
    impl = DummyImpl(observation_shape, action_size)

    base_tester(algo, impl, observation_shape, action_size)

    algo._impl = impl

    # check save policy
    impl.save_policy = Mock()
    algo.save_policy("policy.pt", False)
    impl.save_policy.assert_called_with("policy.pt", False)

    # check predict
    x = np.random.random((2, 3)).tolist()
    ref_y = np.random.random((2, action_size)).tolist()
    impl.predict_best_action = Mock(return_value=ref_y)
    y = algo.predict(x)
    assert y == ref_y
    impl.predict_best_action.assert_called_with(x)

    # check predict_value
    if not imitator and not state_value:
        action = np.random.random((2, action_size)).tolist()
        ref_value = np.random.random((2, 3)).tolist()
        impl.predict_value = Mock(return_value=ref_value)
        value = algo.predict_value(x, action)
        assert value == ref_value
        impl.predict_value.assert_called_with(x, action, False)

    # check sample_action (optional: NotImplementedError is tolerated)
    impl.sample_action = Mock(return_value=ref_y)
    try:
        y = algo.sample_action(x)
        assert y == ref_y
        impl.sample_action.assert_called_with(x)
    except NotImplementedError:
        pass

    # replace the dummy with a real impl for the copy checks below
    algo.create_impl(observation_shape, action_size)

    if test_policy_copy:
        algo2 = algo.__class__(**algo.get_params())
        algo2.create_impl(observation_shape, action_size)
        algo.copy_policy_from(algo2)
        observations = np.random.random((100, *observation_shape))
        action1 = algo.predict(observations)
        # Compare against the source algorithm: predicting twice with
        # ``algo`` would pass even if nothing had been copied.
        action2 = algo2.predict(observations)
        assert np.all(action1 == action2)

    if test_q_function_copy:
        algo2 = algo.__class__(**algo.get_params())
        algo2.create_impl(observation_shape, action_size)
        algo.copy_q_function_from(algo2)
        observations = np.random.random((100, *observation_shape))
        if algo.get_action_type() == ActionSpace.CONTINUOUS:
            actions = np.random.random((100, action_size))
        else:
            actions = np.random.randint(action_size, size=100)
        value1 = algo.predict_value(observations, actions)
        value2 = algo2.predict_value(observations, actions)
        assert np.all(value1 == value2)

    # detach the impl again
    algo._impl = None