Пример #1
0
    def test_sum_one_info_action_match_predict(self):
        """predict returns [1/2, 1/2] and info 1 when the fixed probabilities
        sum to one and there is one probability per action."""
        safe_learner = SafeLearner(UnsafeFixedLearner([1 / 2, 1 / 2], 1))
        two_actions = [1, 2]

        prediction = safe_learner.predict(None, two_actions)
        expected_probs = [1 / 2, 1 / 2]
        expected_info = 1

        self.assertEqual(expected_probs, prediction[0])
        self.assertEqual(expected_info, prediction[1])
Пример #2
0
    def test_params_fullname_2_params(self):
        """full_name combines the "family" param with the other two params,
        rendered here as "B(a=A,b=B)"."""
        params = {'a': "A", "b": "B", "family": "B"}
        wrapped = ParamsLearner(params)
        learner = SafeLearner(wrapped)

        expected_name = "B(a=A,b=B)"
        self.assertEqual(expected_name, learner.full_name)
Пример #3
0
    def process(
            self, learner: Learner, interactions: Iterable[LoggedInteraction]
    ) -> Iterable[Dict[Any, Any]]:
        """Evaluate ``learner`` online against logged interactions.

        For each interaction the learner predicts over the logged actions (when
        "actions" is present in the interaction's kwargs) and then always learns
        from the logged action. One result dict is yielded per interaction.

        Args:
            learner: Learner to evaluate; wrapped in ``SafeLearner`` unless it
                already is one.
            interactions: Logged interactions whose ``kwargs`` may contain
                "actions", "probability", "reward" and "reveal".

        Yields:
            A dict made of: an estimated "reward" (only when "actions",
            "probability" and "reward" are all logged); every other kwarg
            except "probability", "actions" and "reward"; the contents of
            ``InteractionContext.learner_info`` after learning; and
            "predict_time"/"learn_time" when ``self._time`` is truthy.
        """
        if not isinstance(learner, SafeLearner): learner = SafeLearner(learner)
        if not interactions: return

        excluded_keys = ("probability", "actions", "reward")

        for interaction in interactions:

            InteractionContext.learner_info.clear()

            kwargs = interaction.kwargs
            interaction_info = {}
            info = None
            # Reset every iteration: an interaction without "actions" makes no
            # prediction and must not report the previous interaction's time.
            predict_time = 0

            if "actions" in kwargs:
                actions = list(kwargs["actions"])
                start_time = time.time()
                probs, info = learner.predict(interaction.context, actions)
                predict_time = time.time() - start_time

                if "probability" in kwargs and "reward" in kwargs:
                    # Inverse propensity score estimate: the logged reward
                    # scaled by (learner's prob of the logged action) /
                    # (logged probability of that action).
                    learner_prob = probs[actions.index(interaction.action)]
                    ratio = learner_prob / kwargs["probability"]
                    interaction_info['reward'] = ratio * kwargs['reward']

            for k, v in kwargs.items():
                if k not in excluded_keys:
                    interaction_info[k] = v

            # "reveal" takes precedence over "reward" as the feedback to learn.
            reveal = kwargs.get("reveal", kwargs.get("reward"))
            prob = kwargs.get("probability")

            start_time = time.time()
            learner.learn(interaction.context, interaction.action, reveal,
                          prob, info)
            learn_time = time.time() - start_time

            learner_info = InteractionContext.learner_info
            time_info = {
                "predict_time": predict_time,
                "learn_time": learn_time
            } if self._time else {}

            yield {**interaction_info, **learner_info, **time_info}
Пример #4
0
    def process(
        self, learner: Learner, interactions: Iterable[SimulatedInteraction]
    ) -> Iterable[Dict[Any, Any]]:
        """Evaluate ``learner`` online against simulated interactions.

        For each interaction the learner predicts over ``interaction.actions``,
        an action is drawn with ``CobaRandom(self._seed).choice(actions, probs)``,
        and the learner learns from that action's entry in the "reveals"
        kwarg (or "rewards" when "reveals" is absent).

        Args:
            learner: Learner to evaluate; wrapped in ``SafeLearner`` unless it
                already is one.
            interactions: Simulated interactions with ``context``, ``actions``
                and per-action ``kwargs``.

        Yields:
            A dict of the interaction's kwargs (non-string sequence values are
            reduced to the chosen action's entry), the contents of
            ``InteractionContext.learner_info`` after learning, and
            "predict_time"/"learn_time" when ``self._time`` is truthy.
        """
        rng = CobaRandom(self._seed)

        if not isinstance(learner, SafeLearner): learner = SafeLearner(learner)
        if not interactions: return

        for interaction in interactions:

            InteractionContext.learner_info.clear()

            context = interaction.context
            actions = interaction.actions

            start_time = time.time()
            probs, info = learner.predict(context, actions)
            predict_time = time.time() - start_time

            action = rng.choice(actions, probs)
            # Locate the chosen action once; the same position indexes the
            # probabilities, the reveals and every per-action kwarg below.
            action_index = actions.index(action)

            reveals = interaction.kwargs.get(
                "reveals", interaction.kwargs.get("rewards"))
            reveal = reveals[action_index]
            prob = probs[action_index]

            start_time = time.time()
            learner.learn(context, action, reveal, prob, info)
            learn_time = time.time() - start_time

            learner_info = InteractionContext.learner_info
            interaction_info = {}

            for k, v in interaction.kwargs.items():
                # Sequence kwargs are assumed to hold one value per action;
                # strings are sequences too but are kept whole.
                if isinstance(v, collections.abc.Sequence) and not isinstance(
                        v, str):
                    interaction_info[k] = v[action_index]
                else:
                    interaction_info[k] = v

            time_info = {
                "predict_time": predict_time,
                "learn_time": learn_time
            } if self._time else {}

            yield {**interaction_info, **learner_info, **time_info}
Пример #5
0
    def process(
            self, learner: Learner,
            interactions: Iterable[Interaction]) -> Iterable[Dict[Any, Any]]:
        """Evaluate ``learner`` on the logged prefix, then on the rest.

        The leading run of ``LoggedInteraction`` items is handed to
        ``OnlineOffPolicyEvalTask``; whatever remains of the same iterator is
        then handed to ``OnlineOnPolicyEvalTask``. Rows from both are yielded
        in that order.
        """
        if not isinstance(learner, SafeLearner):
            learner = SafeLearner(learner)
        if not interactions:
            return

        # A single shared iterator: the second pass resumes where the first
        # one stopped pulling items.
        stream = iter(self._repeat_first_simulated_interaction(interactions))

        def is_logged(candidate):
            return isinstance(candidate, LoggedInteraction)

        # NOTE(review): takewhile consumes and discards the first item that
        # fails the predicate; presumably _repeat_first_simulated_interaction
        # emits the first simulated interaction twice to compensate -- confirm.
        logged_prefix = takewhile(is_logged, stream)

        yield from OnlineOffPolicyEvalTask(self._time).process(
            learner, logged_prefix)

        yield from OnlineOnPolicyEvalTask(self._time, self._seed).process(
            learner, stream)
Пример #6
0
 def process(self, item: Learner) -> Dict[Any, Any]:
     """Describe ``item`` as a flat dict.

     ``item`` is wrapped in ``SafeLearner``. Its ``full_name`` is stored under
     "full_name", followed by the entries of its ``params``; a params key of
     "full_name" overrides the first entry.
     """
     safe = SafeLearner(item)
     description = {"full_name": safe.full_name}
     description.update(safe.params)
     return description
Пример #7
0
    def test_sum_one_info_action_mismatch_predict(self):
        """predict raises AssertionError when two probabilities are returned
        for three actions."""
        fixed = UnsafeFixedLearner([1 / 2, 1 / 2], 1)
        safe_learner = SafeLearner(fixed)
        three_actions = [1, 2, 3]

        with self.assertRaises(AssertionError):
            safe_learner.predict(None, three_actions)
Пример #8
0
    def test_no_sum_one_no_info_action_match_predict(self):
        """predict raises AssertionError when the probabilities (1/3 + 1/2)
        do not sum to one, even though their count matches the actions."""
        fixed = UnsafeFixedLearner([1 / 3, 1 / 2], None)
        safe_learner = SafeLearner(fixed)
        two_actions = [1, 2]

        with self.assertRaises(AssertionError):
            safe_learner.predict(None, two_actions)
Пример #9
0
    def test_params_fullname_0_params(self):
        """With only a "family" param, full_name is just that family."""
        wrapped = ParamsLearner({"family": "B"})
        self.assertEqual("B", SafeLearner(wrapped).full_name)
Пример #10
0
 def test_params_family(self):
     """params expose both the "family" entry and the learner's own params."""
     expected_params = {"family": "B", "a": "A"}
     wrapped = ParamsLearner({'a': "A", "family": "B"})
     self.assertDictEqual(expected_params, SafeLearner(wrapped).params)
Пример #11
0
 def test_no_params(self):
     """A NoParamsLearner reports "NoParamsLearner" as its family
     (presumably derived from the class name)."""
     reported = SafeLearner(NoParamsLearner()).params
     self.assertEqual("NoParamsLearner", reported["family"])