def test_insert_sym_when_already_exists(self):
        """Appending a symbol that is already present in a normalized list of
        symbols must leave the list unchanged, whatever probability is passed
        along with the duplicate symbol."""

        syms = [('A', 0.25), ('B', 0.25), ('C', 0.25), ('D', 0.25)]
        total = sum(prob for _, prob in syms)
        self.assertEqual(1.0, total)

        # Duplicate symbol with an identical probability: expect no change.
        self.assertEqual(syms, sym_appended(syms, ('D', 0.25)),
                         "Value already present")

        # Duplicate symbol with a different probability: still no change.
        self.assertEqual(syms, sym_appended(syms, ('D', 0.2)),
                         msg="Changing the probability does not matter")
Example #2
0
    def initialize_epoch(self):
        """If a decision is made, initialize the next epoch.

        Resets the conjugator history and builds a language-model prior over
        ``self.alp``. When no language model is configured, a mocked prior
        from ``equally_probable`` is used, with the backspace symbol
        overridden to ``self.backspace_prob``. The prior is fused as 'LM'
        evidence, and the decision maker is then asked for a decision and
        the next stimuli.

        Returns:
            tuple: ``(is_accepted, sti)`` as returned by
            ``self.decision_maker.decide``.

        Raises:
            ValueError: if the language-model prior has no probability for
                one or more symbols of ``self.alp``.
            Exception: any error raised while building or fusing the prior,
                or while deciding, is printed and re-raised unchanged.
        """

        try:
            # First, reset the history for this new epoch
            self.conjugator.reset_history()

            # If there is no language model specified, mock the LM prior
            # TODO: is the probability domain correct? ERP evidence is in
            # the log domain; LM by default returns negative log domain.
            if not self.lmodel:
                # mock probabilities to be equally likely for all letters.
                overrides = {BACKSPACE_CHAR: self.backspace_prob}
                prior = equally_probable(self.alp, overrides)

            # Else, let's query the lmodel for priors
            else:
                # Get the displayed state
                # TODO: for oclm this should be a list of (sym, prob)
                update = self.decision_maker.displayed_state

                # update the lmodel and get back the priors
                lm_prior = self.lmodel.state_update(update)

                # normalize to probability domain if needed
                if getattr(self.lmodel, 'normalized', False):
                    lm_letter_prior = lm_prior['letter']
                else:
                    lm_letter_prior = norm_domain(lm_prior['letter'])

                if BACKSPACE_CHAR in self.alp:
                    # Append backspace if missing.
                    sym = (BACKSPACE_CHAR, self.backspace_prob)
                    lm_letter_prior = sym_appended(lm_letter_prior, sym)

                # convert to format needed for evidence fusion;
                # probability value only in alphabet order.
                # TODO: ensure that probabilities still add to 1.0
                prob_by_sym = dict(lm_letter_prior)
                missing = [letter for letter in self.alp
                           if letter not in prob_by_sym]
                if missing:
                    # Skipping a symbol here would shift every following
                    # probability onto the wrong alphabet index, so fail
                    # loudly instead of fusing misaligned evidence.
                    raise ValueError(
                        "Language model prior is missing symbols: %s"
                        % missing)
                prior = [prob_by_sym[letter] for letter in self.alp]

            # Try fusing the lmodel evidence
            try:
                prob_dist = self.conjugator.update_and_fuse(
                    {'LM': np.array(prior)})
            except Exception:
                print("Error updating language model!")
                raise

            # Get decision maker to give us back some decisions and stimuli
            is_accepted, sti = self.decision_maker.decide(prob_dist)

        except Exception as init_exception:
            print("Error in initialize_epoch: %s" % (init_exception))
            raise

        return is_accepted, sti
    def test_insert_sym_with_non_zero_prob(self):
        """Test insertion of an additional symbol to a normalized list of
        symbols with a non-zero probability.

        The test expects the inserted symbol to keep its probability exactly.
        It expects each existing symbol to drop to 0.2 (lower than its
        original 0.25), so that the extended list still sums to 1.0.
        """

        syms = [('A', 0.25), ('B', 0.25), ('C', 0.25), ('D', 0.25)]

        self.assertEqual(1.0, sum(prob for _, prob in syms))

        new_sym = ('<', 0.2)

        new_list = sym_appended(syms, new_sym)
        self.assertEqual(len(syms) + 1, len(new_list))
        self.assertAlmostEqual(1.0, sum(prob for _, prob in new_list))

        new_list_dict = dict(new_list)
        prev_list_dict = dict(syms)
        for sym, _ in syms:
            self.assertIn(sym, new_list_dict)
            self.assertLess(new_list_dict[sym], prev_list_dict[sym])
            # Rescaled probabilities come out of float arithmetic; compare
            # with a tolerance instead of bit-for-bit equality.
            self.assertAlmostEqual(0.2, new_list_dict[sym])
        self.assertIn(new_sym[0], new_list_dict)
        # The inserted probability is expected to be stored verbatim.
        self.assertEqual(new_sym[1], new_list_dict[new_sym[0]])
    def test_insert_sym_with_zero_prob(self):
        """Appending a symbol whose probability is 0.0 to a normalized list
        must add the symbol, keep every existing (symbol, probability) pair
        exactly as it was, and keep the total at exactly 1.0."""

        syms = [
            ('S', 0.21999999999999953), ('U', 0.03), ('O', 0.03), ('M', 0.03),
            ('W', 0.03), ('T', 0.03), ('P', 0.03), ('R', 0.03), ('L', 0.03),
            ('N', 0.03), ('_', 0.03), ('C', 0.03), ('E', 0.03), ('B', 0.03),
            ('A', 0.03), ('D', 0.03), ('H', 0.03), ('G', 0.03), ('F', 0.03),
            ('V', 0.03), ('K', 0.03), ('I', 0.03), ('Y', 0.03), ('J', 0.03),
            ('X', 0.03), ('Z', 0.03), ('Q', 0.03)
        ]

        # The 'S' value above is tuned so this left-to-right sum is exact.
        self.assertEqual(1.0, sum(prob for _, prob in syms))

        appended = sym_appended(syms, ('<', 0.0))
        self.assertEqual(len(syms) + 1, len(appended))
        self.assertEqual(1.0, sum(prob for _, prob in appended))

        # Existing entries must survive untouched, probability included.
        for entry in syms:
            self.assertIn(entry, appended)
        self.assertIn('<', dict(appended))
    def test_small_probs(self):
        """Inserting a symbol into an LM distribution containing very small
        probabilities must keep the total near 1.0 and every probability
        non-negative."""
        probs = [('_', 0.8137718053286306), ('R', 0.04917114015944412),
                 ('Y', 0.04375449276342169), ('I', 0.03125895356629575),
                 ('M', 0.023673042636520744), ('S', 0.018415576386909806),
                 ('N', 0.014673750822550981), ('O', 0.003311888694636908),
                 ('A', 0.0015325727808248953), ('E', 0.00020663161460758318),
                 ('F', 0.0001271103705188304), ('L', 7.17785373200501e-05),
                 ('T', 1.9445808941289728e-05), ('V', 8.947029414950125e-06),
                 ('D', 1.3287314209822164e-06), ('W', 5.781802939202195e-07),
                 ('C', 4.956713702874677e-07), ('U', 1.3950738615826266e-07),
                 ('J', 6.925441151532328e-08), ('H', 6.067034614011934e-08),
                 ('B', 4.83364892878174e-08), ('K', 4.424005264637171e-08),
                 ('P', 3.319195566216423e-08), ('Z', 2.9048107858874575e-08),
                 ('X', 1.904356496190087e-08), ('Q', 1.0477572781302604e-08),
                 ('G', 7.146978265955833e-09)]

        extended = sym_appended(probs, ('<', 0.05))
        weights = [entry[1] for entry in extended]
        self.assertAlmostEqual(1.0, sum(weights))
        for weight in weights:
            self.assertGreaterEqual(weight, 0)