def test_synthetic_sample_reward_using_valid_inputs(context, action, description):
    """Check that ``sample_reward`` yields an ndarray shaped like ``action``.

    A ``SyntheticBanditDataset`` with 10 actions and a 3-dimensional context
    is built, and rewards are sampled for the given ``context``/``action``.
    ``description`` is not used in the body; presumably it is supplied by a
    parametrize decorator shared with the invalid-input test (not visible here).
    """
    dataset = SyntheticBanditDataset(n_actions=10, dim_context=3)
    sampled = dataset.sample_reward(context=context, action=action)

    failure_msg = "Invalid response of sample_reward"
    assert isinstance(sampled, np.ndarray), failure_msg
    # One reward per action entry: shapes must agree exactly.
    assert sampled.shape == action.shape, failure_msg
def test_synthetic_sample_reward_using_invalid_inputs(context, action, description):
    """Check that ``sample_reward`` raises ValueError for invalid inputs.

    ``description`` is treated as a literal fragment of the expected error
    message. ``pytest.raises(match=...)`` interprets its argument as a regular
    expression applied with ``re.search``. The description is therefore passed
    through ``re.escape`` so that characters such as ``(``, ``[``, ``.`` or
    ``+`` are matched literally. The old ``f"{description}*"`` pattern only
    made the final character optional (``*`` is a regex quantifier, not a
    glob) and left any metacharacters in ``description`` unescaped.
    """
    import re  # stdlib; kept local so this test stays self-contained

    n_actions = 10
    dataset = SyntheticBanditDataset(n_actions=n_actions)
    with pytest.raises(ValueError, match=re.escape(description)):
        _ = dataset.sample_reward(context=context, action=action)