예제 #1
0
 def policy_evaluation(self, policy, context, desired_action, delta, r,
                       epsilon):
     """Replay a LinThompSamp policy and return its cumulative error curve.

     A fresh ``linthompsamp.LinThompSamp`` (with in-memory history and model
     storage) is run for ``self.t`` steps.  At step ``t`` it picks an action
     for ``context[t]``; it is rewarded 0 when that action differs from
     ``desired_action[t][0]`` and 1 otherwise.

     Parameters
     ----------
     policy : str
         Algorithm name; only ``'LinThompSamp'`` is supported.
     context : sequence
         Per-step contexts, indexed by step.
     desired_action : sequence
         Per-step targets; ``desired_action[t][0]`` is the expected action.
     delta, r, epsilon :
         Hyper-parameters forwarded to ``linthompsamp.LinThompSamp``.

     Returns
     -------
     numpy.ndarray or None
         Shape ``(self.t, 1)``; entry ``t`` is the number of mismatched
         steps among ``0..t``.  ``None`` (after printing a message) for any
         other ``policy`` name.
     """
     if policy == 'LinThompSamp':
         bandit = linthompsamp.LinThompSamp(self.actions,
                                            history.MemoryHistoryStorage(),
                                            model.MemoryModelStorage(),
                                            self.d, delta, r, epsilon)
         curve = np.zeros(shape=(self.t, 1))
         mistakes = 0.0
         for step in range(self.t):
             record_id, chosen = bandit.get_action(context[step])
             missed = desired_action[step][0] != chosen
             bandit.reward(record_id, 0 if missed else 1)
             if missed:
                 mistakes += 1.0
             # Running total of mismatches so far (0.0 until the first miss).
             curve[step] = mistakes
         return curve
     print("We don't support other bandit algorithms now!")
예제 #2
0
def policy_generation(bandit, actions):
    """Build a bandit policy object from its algorithm name.

    Parameters
    ----------
    bandit : str
        Algorithm name: 'Exp4P', 'LinUCB', 'LinThompSamp', 'UCB1', 'Exp3'
        or 'random'.
    actions :
        Actions handed to the policy constructor.

    Returns
    -------
    The constructed policy (each backed by fresh in-memory history and model
    storage, with fixed hyper-parameters), or the sentinel ``0`` for
    ``'random'``, in which case no policy object is built.

    Raises
    ------
    ValueError
        If ``bandit`` is not a supported name (previously this surfaced as an
        ``UnboundLocalError`` on ``policy``).
    """
    if bandit == 'random':
        return 0

    # Each factory receives (history storage, model storage).
    factories = {
        'Exp4P': lambda hs, ms: exp4p.Exp4P(actions, hs, ms,
                                            delta=0.5,
                                            pmin=None),
        'LinUCB': lambda hs, ms: linucb.LinUCB(actions, hs, ms, 0.3, 20),
        'LinThompSamp': lambda hs, ms: linthompsamp.LinThompSamp(
            actions, hs, ms,
            d=20,
            delta=0.61,
            r=0.01,
            epsilon=0.71),
        'UCB1': lambda hs, ms: ucb1.UCB1(actions, hs, ms),
        'Exp3': lambda hs, ms: exp3.Exp3(actions, hs, ms, gamma=0.2),
    }
    if bandit not in factories:
        raise ValueError(
            "Unsupported bandit algorithm: {!r}".format(bandit))

    historystorage = history.MemoryHistoryStorage()
    modelstorage = model.MemoryModelStorage()
    return factories[bandit](historystorage, modelstorage)
예제 #3
0
 def test_update_reward(self):
     """Rewarding a chosen action stores that reward in its history entry."""
     bandit = linthompsamp.LinThompSamp(
         self.actions, self.historystorage, self.modelstorage,
         self.d, self.delta, self.R, self.epsilon)
     contexts = [[1, 1], [2, 2], [3, 3]]
     entry_id, _ = bandit.get_action(contexts)
     bandit.reward(entry_id, 1.0)
     stored = bandit._historystorage.get_history(entry_id)
     self.assertEqual(stored.reward, 1)
예제 #4
0
 def test_initialization(self):
     """The constructor keeps the actions and hyper-parameters it was given."""
     bandit = linthompsamp.LinThompSamp(
         self.actions, self.historystorage, self.modelstorage,
         self.d, self.delta, self.R, self.epsilon)
     # (attribute on the policy, value passed to the constructor); the
     # attribute is read lazily so each check runs in the original order.
     expectations = (
         ('_actions', self.actions),
         ('d', self.d),
         ('R', self.R),
         ('epsilon', self.epsilon),
     )
     for attr_name, expected in expectations:
         self.assertEqual(expected, getattr(bandit, attr_name))
예제 #5
0
 def test_model_storage(self):
     """After one rewarded step, the stored model entries have size 2."""
     policy = linthompsamp.LinThompSamp(self.actions, self.historystorage,
                                        self.modelstorage, self.d,
                                        self.delta, self.R, self.epsilon)
     history_id, _ = policy.get_action([[1, 1], [2, 2], [3, 3]])
     policy.reward(history_id, 1.0)
     stored_model = policy._modelstorage._model
     # assertEqual reports the actual value on failure, unlike the former
     # assertTrue((... == ...) == True).
     # NOTE(review): 2 presumably matches the 2-element contexts above /
     # self.d — confirm against setUp.
     self.assertEqual(stored_model['B'].shape, (2, 2))
     self.assertEqual(len(stored_model['muhat']), 2)
     self.assertEqual(len(stored_model['f']), 2)
예제 #6
0
 def test_get_first_action(self):
     """The first get_action call gets history id 0, returns one of the
     configured actions, and records the context it was given."""
     policy = linthompsamp.LinThompSamp(self.actions, self.historystorage,
                                        self.modelstorage, self.d,
                                        self.delta, self.R, self.epsilon)
     history_id, action = policy.get_action([[1, 1], [2, 2], [3, 3]])
     self.assertEqual(history_id, 0)
     self.assertIn(action, self.actions)
     # Compare against a fresh literal (not the object passed in) so an
     # in-place mutation by get_action would still be caught; assertEqual
     # also shows both values on failure, unlike assertTrue(a == b).
     self.assertEqual(
         policy._historystorage.get_history(history_id).context,
         [[1, 1], [2, 2], [3, 3]])
예제 #7
0
def policy_generation(bandit, actions):
    """Build a bandit policy object from its algorithm name.

    Parameters
    ----------
    bandit : str
        Bandit algorithm name: 'Exp4P', 'LinUCB', 'LinThompSamp', 'UCB1',
        'Exp3' or 'random'.
    actions :
        The actions, i.e. the movies to recommend; passed to the policy as
        its action storage.

    Returns
    -------
    policy
        The generated policy (backed by fresh in-memory history and model
        storage, with fixed hyper-parameters), or the sentinel ``0`` for
        ``'random'``, in which case no policy object is built.

    Raises
    ------
    ValueError
        If ``bandit`` is not a supported name (previously this surfaced as an
        ``UnboundLocalError`` on ``policy``).
    """
    if bandit == 'random':
        return 0

    # Each factory receives (history storage, model storage).
    factories = {
        'Exp4P': lambda hs, ms: exp4p.Exp4P(hs, ms, actions,
                                            delta=0.5,
                                            p_min=None),
        'LinUCB': lambda hs, ms: linucb.LinUCB(history_storage=hs,
                                               model_storage=ms,
                                               action_storage=actions,
                                               alpha=0.3,
                                               context_dimension=18),
        'LinThompSamp': lambda hs, ms: linthompsamp.LinThompSamp(
            hs,
            ms,
            actions,
            # Formerly passed as d=20; this argument is the context dimension.
            context_dimension=18,
            delta=0.61,
            R=0.01,
            epsilon=0.71),
        'UCB1': lambda hs, ms: ucb1.UCB1(hs, ms, actions),
        'Exp3': lambda hs, ms: exp3.Exp3(hs, ms, actions, gamma=0.2),
    }
    if bandit not in factories:
        raise ValueError(
            "Unsupported bandit algorithm: {!r}".format(bandit))

    historystorage = history.MemoryHistoryStorage()  # in-memory history store
    # In-memory model store; every policy gets one so the setup is uniform.
    modelstorage = model.MemoryModelStorage()
    return factories[bandit](historystorage, modelstorage)
예제 #8
0
 def test_reward_order_descending(self):
     """Rewarding the later of two pending actions leaves the earlier one
     unrewarded and keeps each history entry's context intact."""
     policy = linthompsamp.LinThompSamp(self.actions, self.historystorage,
                                        self.modelstorage, self.d,
                                        self.delta, self.R, self.epsilon)
     history_id, _ = policy.get_action([[1, 1], [2, 2], [3, 3]])
     history_id_2, _ = policy.get_action([[0, 1], [2, 3], [7, 5]])
     # Reward only the second request (out of request order).
     policy.reward(history_id_2, 1)
     self.assertEqual(
         policy._historystorage.get_history(history_id).context,
         [[1, 1], [2, 2], [3, 3]])
     self.assertEqual(
         policy._historystorage.get_history(history_id_2).context,
         [[0, 1], [2, 3], [7, 5]])
     self.assertIsNone(
         policy._historystorage.get_history(history_id).reward)
     self.assertEqual(
         policy._historystorage.get_history(history_id_2).reward, 1)
예제 #9
0
def policy_generation(bandit, actions):
    """Build a bandit policy object from its algorithm name.

    Parameters
    ----------
    bandit : str
        Algorithm name: 'Exp4P', 'LinUCB', 'LinThompSamp', 'UCB1', 'Exp3'
        or 'random'.
    actions :
        Actions handed to the policy as its action storage.

    Returns
    -------
    The constructed policy (backed by fresh in-memory history and model
    storage, with fixed hyper-parameters), or the sentinel ``0`` for
    ``'random'``, in which case no policy object is built.

    Raises
    ------
    ValueError
        If ``bandit`` is not a supported name (previously this surfaced as an
        ``UnboundLocalError`` on ``policy``).
    """
    if bandit == 'random':
        return 0

    # Each factory receives (history storage, model storage).
    factories = {
        'Exp4P': lambda hs, ms: exp4p.Exp4P(hs, ms, actions,
                                            delta=0.5,
                                            p_min=None),
        'LinUCB': lambda hs, ms: linucb.LinUCB(history_storage=hs,
                                               model_storage=ms,
                                               action_storage=actions,
                                               alpha=0.3,
                                               context_dimension=18),
        'LinThompSamp': lambda hs, ms: linthompsamp.LinThompSamp(
            hs,
            ms,
            actions,
            # Formerly passed as d=20; this argument is the context dimension.
            context_dimension=18,
            delta=0.61,
            R=0.01,
            epsilon=0.71),
        'UCB1': lambda hs, ms: ucb1.UCB1(hs, ms, actions),
        'Exp3': lambda hs, ms: exp3.Exp3(hs, ms, actions, gamma=0.2),
    }
    if bandit not in factories:
        raise ValueError(
            "Unsupported bandit algorithm: {!r}".format(bandit))

    historystorage = history.MemoryHistoryStorage()
    modelstorage = model.MemoryModelStorage()
    return factories[bandit](historystorage, modelstorage)