Example #1
0
    def test_overlap(self):
        """Check the pairs extracted when two systems contradict each other.

        System "B" reverses the relative order of M2 and M3 compared to
        system "A". With ``allow_overlap=True`` 17 pairs are expected, with
        ``allow_overlap=False`` only 9.
        """
        cretention = retention_cls()

        # ----------------------------------------------
        d_target = OrderedDict([(("M1", "A"), 1), (("M2", "A"), 2),
                                (("M3", "A"), 3), (("M4", "A"), 4),
                                (("M2", "B"), 2), (("M3", "B"), 1)])
        keys = list(d_target.keys())

        cretention.load_data_from_target(d_target)
        cretention.make_digraph()
        cretention.dmolecules_inv = cretention.invert_dictionary(
            cretention.dmolecules)
        cretention.dcollections_inv = cretention.invert_dictionary(
            cretention.dcollections)

        # Expected pairs, keyed by the value passed as ``allow_overlap``.
        expected_by_overlap = {
            True: [(0, 1), (0, 2), (0, 3), (0, 5), (0, 4),
                   (1, 2), (1, 3), (1, 5),
                   (2, 3), (2, 4), (2, 1),
                   (5, 4), (5, 1), (5, 3),
                   (4, 2), (4, 5), (4, 3)],
            False: [(0, 1), (0, 2), (0, 3), (0, 5), (0, 4),
                    (1, 3), (2, 3), (5, 3), (4, 3)],
        }

        for allow_overlap, expected in expected_by_overlap.items():
            pairs = get_pairs_from_order_graph(cretention,
                                               keys,
                                               allow_overlap=allow_overlap,
                                               d_lower=0,
                                               d_upper=np.inf)

            self.assertEqual(len(pairs), len(expected))
            for pair in expected:
                self.assertIn(pair, pairs)
Example #2
0
    def test_ireversed(self):
        """With ``ireverse=0`` only pairs within a single system are produced.

        The expected pairs are the same whether or not overlap is allowed.
        """
        cretention = retention_cls()

        # ----------------------------------------------
        d_target = OrderedDict([(("M1", "A"), 1), (("M2", "A"), 2),
                                (("M3", "A"), 3), (("M4", "A"), 4),
                                (("M2", "B"), 1), (("M3", "B"), 2),
                                (("M7", "B"), 3)])
        keys = list(d_target.keys())

        cretention.load_data_from_target(d_target)
        cretention.make_digraph(ireverse=0)
        cretention.dmolecules_inv = cretention.invert_dictionary(
            cretention.dmolecules)
        cretention.dcollections_inv = cretention.invert_dictionary(
            cretention.dcollections)

        # Pairs inside system "A" (key indices 0-3) ...
        within_a = [(0, 1), (1, 2), (2, 3), (0, 2), (1, 3), (0, 3)]
        # ... and pairs inside system "B" (key indices 4-6).
        within_b = [(4, 5), (5, 6), (4, 6)]
        expected = within_a + within_b
        # Pairs linking a system "A" entry with a system "B" entry.
        forbidden = [(1, 6), (1, 5), (0, 5), (0, 4)]

        for allow_overlap in (True, False):
            pairs = get_pairs_from_order_graph(cretention,
                                               keys,
                                               allow_overlap=allow_overlap,
                                               d_lower=0,
                                               d_upper=np.inf)

            self.assertEqual(len(pairs), len(expected))
            for pair in expected:
                self.assertIn(pair, pairs)
            for pair in forbidden:
                self.assertNotIn(pair, pairs)
Example #3
0
    def test_simplecases(self):
        """Check the order-graph pairs for two simple two-system targets.

        For both targets the extracted pairs must be the same with and
        without ``allow_overlap``.
        """
        cretention = retention_cls()

        # Each scenario: (items of the target OrderedDict, expected pairs).
        scenarios = [
            # M1 appears in both systems; M5 only in system "B".
            ([(("M1", "A"), 1), (("M2", "A"), 2),
              (("M3", "A"), 3), (("M4", "A"), 4),
              (("M1", "B"), 1), (("M5", "B"), 2)],
             [(0, 1), (0, 2), (0, 3), (0, 5),
              (1, 2), (1, 3),
              (2, 3),
              (4, 1), (4, 2), (4, 3), (4, 5)]),
            # M2 and M3 appear in both systems; M7 only in system "B".
            ([(("M1", "A"), 1), (("M2", "A"), 2),
              (("M3", "A"), 3), (("M4", "A"), 4),
              (("M2", "B"), 1), (("M3", "B"), 2),
              (("M7", "B"), 3)],
             [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
              (1, 2), (1, 3), (1, 5), (1, 6),
              (2, 3), (2, 6),
              (4, 2), (4, 3), (4, 5), (4, 6),
              (5, 3), (5, 6)]),
        ]

        for target_items, pairs_ref in scenarios:
            # ----------------------------------------------
            d_target = OrderedDict(target_items)
            keys = list(d_target.keys())

            cretention.load_data_from_target(d_target)
            cretention.make_digraph()
            cretention.dmolecules_inv = cretention.invert_dictionary(
                cretention.dmolecules)
            cretention.dcollections_inv = cretention.invert_dictionary(
                cretention.dcollections)

            for allow_overlap in (True, False):
                pairs = get_pairs_from_order_graph(cretention,
                                                   keys,
                                                   allow_overlap=allow_overlap,
                                                   d_lower=0,
                                                   d_upper=np.inf)
                self.assertEqual(len(pairs), len(pairs_ref))
                for pair in pairs_ref:
                    self.assertIn(pair, pairs)
Example #4
0
    def test_equal_to_simple_function_in_multiple_system_case(self):
        """Compare order-graph pair extraction with the simple reference.

        For two retention systems, ``get_pairs_from_order_graph`` and
        ``get_pairs_multiple_systems`` must return the same pairs.
        The first target varies ``d_upper`` and the second target varies
        ``d_lower``. The order graph is built with ``ireverse`` set to 0/False.
        """
        cretention = retention_cls()

        def make_m_target(target):
            # Build the two-column input of get_pairs_multiple_systems:
            # column 0 holds the target values, column 1 an integer system id.
            # Ids are 1, 2, ... in order of first appearance of the system
            # name in the keys. Deriving them from the keys keeps the id
            # column in sync with ``target`` instead of hard-coding it.
            system_ids = {}
            for _, system in target:
                system_ids.setdefault(system, len(system_ids) + 1)
            return np.array([list(target.values()),
                             [system_ids[system] for _, system in target]]).T

        # ----------------------------------------------
        d_target = OrderedDict([(("M1", "A"), 5), (("M2", "A"), 2),
                                (("M3", "A"), 3), (("M4", "A"), 4),
                                (("M5", "A"), 1), (("M5", "B"), 1),
                                (("M9", "B"), 10), (("M7", "B"), 5),
                                (("M1", "B"), 12)])
        keys = list(d_target.keys())

        cretention.load_data_from_target(d_target)
        cretention.make_digraph(ireverse=0)
        cretention.dmolecules_inv = cretention.invert_dictionary(
            cretention.dmolecules)
        cretention.dcollections_inv = cretention.invert_dictionary(
            cretention.dcollections)

        # Depends only on d_target, so build it once rather than once per d.
        m_target = make_m_target(d_target)

        d_pairs_ref = {
            0: [],
            1: [(4, 1), (1, 2), (2, 3), (3, 0), (5, 7), (7, 6), (6, 8)],
            2: [(4, 1), (1, 2), (2, 3), (3, 0), (5, 7), (7, 6), (6, 8), (4, 2),
                (1, 3), (2, 0), (5, 6), (7, 8)],
            3: [(4, 1), (1, 2), (2, 3), (3, 0), (5, 7), (7, 6), (6, 8), (4, 2),
                (1, 3), (2, 0), (5, 6), (7, 8), (4, 3), (1, 0), (5, 8)],
            4: [(4, 1), (1, 2), (2, 3), (3, 0), (5, 7), (7, 6), (6, 8), (4, 2),
                (1, 3), (2, 0), (5, 6), (7, 8), (4, 3), (1, 0), (5, 8), (4, 0)]
        }
        for d, pairs_ref in d_pairs_ref.items():
            pairs_og = get_pairs_from_order_graph(cretention,
                                                  keys,
                                                  allow_overlap=True,
                                                  d_lower=0,
                                                  d_upper=d)
            pairs = get_pairs_multiple_systems(m_target, d_lower=0, d_upper=d)

            self.assertEqual(len(pairs_og), len(pairs_ref))
            self.assertEqual(len(pairs), len(pairs_ref))

            for pair in pairs_ref:
                self.assertIn(pair, pairs_og)
                self.assertIn(pair, pairs)

        # ----------------------------------------------
        d_target = OrderedDict([(("M1", "A"), 1), (("M2", "A"), 2),
                                (("M3", "A"), 3), (("M4", "A"), 4),
                                (("M2", "B"), 1), (("M3", "B"), 2),
                                (("M7", "B"), 1.5)])
        keys = list(d_target.keys())

        cretention.load_data_from_target(d_target)
        cretention.make_digraph(ireverse=False)
        cretention.dmolecules_inv = cretention.invert_dictionary(
            cretention.dmolecules)
        cretention.dcollections_inv = cretention.invert_dictionary(
            cretention.dcollections)

        d_pairs_ref = {
            4: [],
            3: [(0, 3)],
            2: [(0, 3), (0, 2), (1, 3), (4, 5)],
            1: [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (4, 5), (4, 6),
                (6, 5)],
            0: [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3), (4, 5), (4, 6),
                (6, 5)]
        }

        # get_pairs_multiple_systems takes no allow_overlap argument, so its
        # result is computed and checked once per bound instead of once per
        # (allow_overlap, bound) combination.
        m_target = make_m_target(d_target)
        for d, pairs_ref in d_pairs_ref.items():
            pairs_sf = get_pairs_multiple_systems(m_target,
                                                  d_lower=d,
                                                  d_upper=np.inf)

            self.assertEqual(len(pairs_sf), len(pairs_ref))
            for pair in pairs_ref:
                self.assertIn(pair, pairs_sf)

        for allow_overlap in [True, False]:
            for d, pairs_ref in d_pairs_ref.items():
                pairs = get_pairs_from_order_graph(cretention,
                                                   keys,
                                                   allow_overlap=allow_overlap,
                                                   d_lower=d,
                                                   d_upper=np.inf)

                self.assertEqual(len(pairs), len(pairs_ref))
                for pair in pairs_ref:
                    self.assertIn(pair, pairs)
Example #5
0
    def test_equal_to_simple_function_in_single_system_case(self):
        """Compare order-graph pair extraction with the simple reference.

        On a single retention system, ``get_pairs_from_order_graph`` (with
        overlap allowed) and ``get_pairs_single_system`` must return the same
        pairs. The first pass varies ``d_upper`` (``d_lower`` fixed at 0); the
        second pass varies ``d_lower`` (``d_upper`` fixed at ``np.inf``).
        """
        cretention = retention_cls()

        target_items = [(("M1", "A"), 10), (("M2", "A"), 4),
                        (("M3", "A"), 6), (("M4", "A"), 8),
                        (("M5", "A"), 2)]

        # Expected pairs when varying d_upper.
        upper_refs = {
            0: [],
            1: [(4, 1), (1, 2), (2, 3), (3, 0)],
            2: [(4, 1), (1, 2), (2, 3), (3, 0), (4, 2), (1, 3), (2, 0)],
            3: [(4, 1), (1, 2), (2, 3), (3, 0), (4, 2), (1, 3), (2, 0), (4, 3),
                (1, 0)],
            4: [(4, 1), (1, 2), (2, 3), (3, 0), (4, 2), (1, 3), (2, 0), (4, 3),
                (1, 0), (4, 0)]
        }
        # Expected pairs when varying d_lower.
        lower_refs = {
            5: [],
            4: [(4, 0)],
            3: [(4, 0), (4, 3), (1, 0)],
            2: [(4, 0), (4, 3), (1, 0), (4, 2), (1, 3), (2, 0)],
            1: [(4, 0), (4, 3), (1, 0), (4, 2), (1, 3), (2, 0), (4, 1), (1, 2),
                (2, 3), (3, 0)],
            0: [(4, 0), (4, 3), (1, 0), (4, 2), (1, 3), (2, 0), (4, 1), (1, 2),
                (2, 3), (3, 0)]
        }

        for vary_upper, d_pairs_ref in ((True, upper_refs),
                                        (False, lower_refs)):
            # ----------------------------------------------
            # Each pass rebuilds the order graph from a fresh target.
            d_target = OrderedDict(target_items)
            keys = list(d_target.keys())

            cretention.load_data_from_target(d_target)
            cretention.make_digraph()
            cretention.dmolecules_inv = cretention.invert_dictionary(
                cretention.dmolecules)
            cretention.dcollections_inv = cretention.invert_dictionary(
                cretention.dcollections)

            for d in d_pairs_ref:
                if vary_upper:
                    d_lower, d_upper = 0, d
                else:
                    d_lower, d_upper = d, np.inf

                pairs_og = get_pairs_from_order_graph(cretention,
                                                      keys,
                                                      allow_overlap=True,
                                                      d_lower=d_lower,
                                                      d_upper=d_upper)
                pairs = get_pairs_single_system(list(d_target.values()),
                                                d_lower=d_lower,
                                                d_upper=d_upper)

                expected = d_pairs_ref[d]
                self.assertEqual(len(pairs_og), len(expected))
                self.assertEqual(len(pairs), len(expected))

                for pair in expected:
                    self.assertIn(pair, pairs_og)
                    self.assertIn(pair, pairs)
Example #6
0
    def test_d(self):
        """Check how the distance bounds restrict the extracted pairs.

        Each scenario rebuilds the order graph for its target. It then checks
        the pairs returned by ``get_pairs_from_order_graph`` while varying
        either ``d_upper`` (``d_lower`` fixed at 0) or ``d_lower`` (``d_upper``
        fixed at ``np.inf``), for the given ``allow_overlap`` values.
        """
        cretention = retention_cls()

        def check_scenario(target_items, overlap_flags, vary_upper,
                           d_pairs_ref):
            # Load the target and rebuild the order graph from scratch.
            d_target = OrderedDict(target_items)
            keys = list(d_target.keys())

            cretention.load_data_from_target(d_target)
            cretention.make_digraph()
            cretention.dmolecules_inv = cretention.invert_dictionary(
                cretention.dmolecules)
            cretention.dcollections_inv = cretention.invert_dictionary(
                cretention.dcollections)

            for allow_overlap in overlap_flags:
                for d, pairs_ref in d_pairs_ref.items():
                    if vary_upper:
                        d_lower, d_upper = 0, d
                    else:
                        d_lower, d_upper = d, np.inf

                    pairs = get_pairs_from_order_graph(
                        cretention,
                        keys,
                        allow_overlap=allow_overlap,
                        d_lower=d_lower,
                        d_upper=d_upper)

                    self.assertEqual(len(pairs), len(pairs_ref))
                    for pair in pairs_ref:
                        self.assertIn(pair, pairs)

        # "A" and "B" agree on the order of M2 and M3; M7 only in "B".
        consistent_items = [(("M1", "A"), 1), (("M2", "A"), 2),
                            (("M3", "A"), 3), (("M4", "A"), 4),
                            (("M2", "B"), 1), (("M3", "B"), 2),
                            (("M7", "B"), 3)]
        # "B" reverses the relative order of M2 and M3 compared to "A".
        swapped_items = [(("M1", "A"), 1), (("M2", "A"), 2),
                         (("M3", "A"), 3), (("M4", "A"), 4),
                         (("M2", "B"), 2), (("M3", "B"), 1)]
        # As consistent_items, but M7's value lies between M2 and M3 in "B".
        interleaved_items = [(("M1", "A"), 1), (("M2", "A"), 2),
                             (("M3", "A"), 3), (("M4", "A"), 4),
                             (("M2", "B"), 1), (("M3", "B"), 2),
                             (("M7", "B"), 1.5)]

        # ----------------------------------------------
        up_1 = [(0, 1), (0, 4), (1, 2), (1, 5), (2, 3), (2, 6), (4, 5),
                (4, 2), (5, 3), (5, 6)]
        up_2 = up_1 + [(0, 2), (0, 5), (1, 6), (1, 3), (4, 6), (4, 3)]
        up_3 = up_2 + [(0, 3), (0, 6)]
        check_scenario(consistent_items,
                       overlap_flags=(True, False),
                       vary_upper=True,
                       d_pairs_ref={0: [], 1: up_1, 2: up_2, 3: up_3})

        # ----------------------------------------------
        ov_1 = [(0, 1), (0, 4), (1, 2), (1, 5), (2, 3), (2, 4), (2, 1),
                (5, 4), (5, 3), (5, 1), (4, 2), (4, 5)]
        ov_2 = ov_1 + [(0, 2), (0, 5), (1, 3), (4, 3)]
        ov_3 = ov_2 + [(0, 3)]
        check_scenario(swapped_items,
                       overlap_flags=(True,),
                       vary_upper=True,
                       d_pairs_ref={0: [], 1: ov_1, 2: ov_2, 3: ov_3})

        # ----------------------------------------------
        no_1 = [(0, 1), (0, 4), (2, 3), (5, 3)]
        no_2 = no_1 + [(0, 2), (0, 5), (1, 3), (4, 3)]
        no_3 = no_2 + [(0, 3)]
        check_scenario(swapped_items,
                       overlap_flags=(False,),
                       vary_upper=True,
                       d_pairs_ref={0: [], 1: no_1, 2: no_2, 3: no_3})

        # ----------------------------------------------
        no_all = [(0, 1), (0, 4), (2, 3), (5, 3), (0, 2), (0, 5), (1, 3),
                  (4, 3), (0, 3)]
        check_scenario(swapped_items,
                       overlap_flags=(False,),
                       vary_upper=False,
                       d_pairs_ref={
                           4: [],
                           3: [(0, 3)],
                           2: [(0, 2), (0, 3), (0, 5), (1, 3), (4, 3)],
                           1: no_all,
                           0: list(no_all),
                       })

        # ----------------------------------------------
        il_all = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (1, 2),
                  (1, 3), (1, 5), (1, 6), (2, 3), (4, 2), (4, 3), (4, 5),
                  (4, 6), (5, 3), (6, 2), (6, 3), (6, 5)]
        check_scenario(interleaved_items,
                       overlap_flags=(True, False),
                       vary_upper=False,
                       d_pairs_ref={
                           4: [],
                           3: [(0, 3)],
                           2: [(0, 3), (0, 5), (0, 2), (0, 6), (1, 3),
                               (4, 3), (6, 3)],
                           1: il_all,
                           0: list(il_all),
                       })
Example #7
0
    def test_bordercases(self):
        """Check the number of extracted pairs on degenerate inputs and bounds.

        Each case specifies a target, the ``(d_lower, d_upper)`` bounds, and
        the expected number of pairs for each ``allow_overlap`` value.
        """
        cretention = retention_cls()

        # "B" reverses the relative order of M2 and M3 compared to "A".
        swapped_items = [(("M1", "A"), 1), (("M2", "A"), 2),
                         (("M3", "A"), 3), (("M4", "A"), 4),
                         (("M2", "B"), 2), (("M3", "B"), 1)]

        cases = [
            # A single molecule measured in three systems.
            ([(("M1", "A"), 1), (("M1", "B"), 2), (("M1", "C"), 3)],
             0, np.inf, {True: 0, False: 0}),
            # Two molecules whose order is reversed between "A" and "B".
            ([(("M2", "A"), 2), (("M3", "A"), 3),
              (("M3", "B"), 2), (("M2", "B"), 3)],
             0, np.inf, {True: 8, False: 0}),
            # Upper distance bound of zero.
            (swapped_items, 0, 0, {True: 0, False: 0}),
            # Lower distance bound of infinity.
            (swapped_items, np.inf, np.inf, {True: 0, False: 0}),
        ]

        for target_items, d_lower, d_upper, n_expected in cases:
            # ----------------------------------------------
            d_target = OrderedDict(target_items)
            keys = list(d_target.keys())

            cretention.load_data_from_target(d_target)
            cretention.make_digraph()
            cretention.dmolecules_inv = cretention.invert_dictionary(
                cretention.dmolecules)
            cretention.dcollections_inv = cretention.invert_dictionary(
                cretention.dcollections)

            for allow_overlap, n_pairs in n_expected.items():
                pairs = get_pairs_from_order_graph(cretention,
                                                   keys,
                                                   allow_overlap=allow_overlap,
                                                   d_lower=d_lower,
                                                   d_upper=d_upper)
                self.assertEqual(len(pairs), n_pairs)
Example #8
0
def find_hparan_ranksvm(estimator,
                        X,
                        y,
                        param_grid,
                        cv,
                        pair_params,
                        scaler=None,
                        n_jobs=1,
                        fold_score_aggregation="weighted_average",
                        all_pairs_as_test=True):
    """
    Task: find the hyper-parameters from a set of parameters (param_grid)
          that perform best in a cross-validation setting for the given
          estimator, then fit the estimator on all data using them.

    :param estimator: Estimator object, e.g. KernelRankSVC

    :param X: dictionary, (mol-id, system)-tuples as keys and molecular
              features as values:

              Example:
                {("M1", "S1"): feat_11, ...}

    :param y: dictionary, (mol-id, system)-tuples as keys and retention
              times as values

              Example:
                {("M1", "S1"): rt_11, ...}

    :param param_grid: dictionary, defining the grid-search space
        "C": Trade-off parameter for the SVM
        "gamma": width of the rbf/gaussian kernel
        ... etc. ...

        Example:
            {"C": [0.1, 1, 10], "gamma": [0.1, 0.25, 0.5, 1]}

    :param cv: cross-validation generator, see sklearn package, must be
               either a GroupKFold or GroupShuffleSplit object. The molecule
               ids (first element of each key) are used as groups.

    :param pair_params: dictionary, specifying parameters for the order graph:
        "ireverse": scalar, Should cross-system elution transitivity be included
            0: no, 1: yes
        "d_lower": scalar, minimum distance of two molecules in the elution order graph
                   to be considered as a pair.
        "d_upper": scalar, maximum distance of two molecules in the elution order graph
                   to be considered as a pair.
        "allow_overlap": scalar, Should overlap between the upper and lower sets
                         be allowed. Those overlaps originate from retention order
                         contradictions between the different systems.

    :param scaler: scaler object, per feature scaler, e.g. MinMaxScaler. It is
        passed to every cross-validation fit, and the final feature matrix is
        transformed with 'scaler.transform' (it is not fitted in this function).

    :param n_jobs: integer, number of jobs run in parallel. Parallelization is performed
        over the cv-folds. (default = 1)

    :param fold_score_aggregation: string, how the per-fold scores (number of
        correctly classified test pairs) are combined into one score per
        parameter combination (default = "weighted_average"):
            "average": mean over the folds of (fold score / number of test
                       pairs in the fold),
            "weighted_average": sum of all fold scores divided by the total
                                number of test pairs over all folds.

    :param all_pairs_as_test: boolean, should all possible pairs (d_lower = 0, d_upper = np.inf)
        be used during the test. If 'False' then the corresponding values are taken from the
        'pair_params' dictionary. (default = True)

    :return: tuple with six elements:

             best_params: dictionary, combination of best parameters
                 Example:
                     {"C": 1, "gamma": 0.25}

             l_params: list of dictionaries, all parameter combinations with
                 their corresponding (aggregated) cross-validation scores
                 Example:
                    [{"C": 1, "gamma": 0.25, "score": 0.98},
                     {"C": 1, "gamma": 0.50, "score": 0.94},
                     ...]
                 If 'param_grid' defines only a single combination, no
                 cross-validation is run and its "score" is 0.

             n_pairs_train: scalar, number of pairs used to train the final model

             best_estimator: estimator object, fitted using the best parameters

             FX: numpy array, feature matrix (rows ordered like X.keys()) the
                 final model was fitted on, transformed by 'scaler' if given

             None: always None (NOTE(review): presumably a placeholder kept for
                 a compatible return signature -- confirm with callers)

    :raises ValueError: if 'cv' is not a grouped splitter, if the key-sets of
        'X' and 'y' differ, or if 'fold_score_aggregation' is invalid while a
        cross-validation is required.
    :raises RuntimeError: if a training and a test fold share a molecule id.
    """
    if not isinstance(cv, (GroupKFold, GroupShuffleSplit)):
        raise ValueError("Cross-validation generator must be either of "
                         "class 'GroupKFold' or 'GroupShuffleSplit'. "
                         "Provided class is '%s'." % cv.__class__.__name__)

    if len(X) != len(y) or len(X.keys() - y.keys()) or len(y.keys() -
                                                           X.keys()):
        raise ValueError("Keys-set for features and retentions times must "
                         "be equal.")

    # Make a list of all combinations of parameters
    l_params = list(ParameterGrid(param_grid))
    param_scores = np.zeros((len(l_params), ))

    # Cross-validation is only needed if there is something to choose from.
    run_cv = len(l_params) > 1

    # Fail fast: validate the aggregation scheme before the (expensive)
    # per-fold pair computation instead of after the first parallel fit.
    if run_cv and fold_score_aggregation not in ("average",
                                                 "weighted_average"):
        raise ValueError("Invalid fold-scoring aggregation: %s." %
                         fold_score_aggregation)

    # Get all (mol-id, system)-tuples used for the parameter search
    keys = list(X.keys())

    if run_cv:
        # Molecule ids (first tuple element) define the cv-groups, so the
        # same structure never ends up in training and test at once.
        mol_ids = [key[0] for key in keys]
        cv_splits = cv.split(range(len(keys)), groups=mol_ids)

        # Precompute the training / test pairs to save computation time as
        # we do not need to repeat this for several parameter settings.
        pairs_train_sets, pairs_test_sets = [], []
        X_train_sets, X_test_sets = [], []
        n_pairs_test_sets = []

        # Distance range of the pairs used for testing.
        if all_pairs_as_test:
            d_lower_test, d_upper_test = 0, np.inf
        else:
            d_lower_test = pair_params["d_lower"]
            d_upper_test = pair_params["d_upper"]

        print("Get pairs for hparam estimation: ", end="", flush=True)
        for k_cv, (train_set, test_set) in enumerate(cv_splits):
            print("%d " % k_cv, end="", flush=True)

            # 0) Get keys (mol-id, system)-tuples, corresponding to the
            #    training and test sets.
            keys_train = [keys[idx] for idx in train_set]
            keys_test = [keys[idx] for idx in test_set]

            # Molecular ids, e.g. InChIs, must not be shared between training
            # and test, e.g. if a structure appears in different systems.
            # 'cv' is guaranteed to be a grouped splitter (checked above), so
            # any overlap is an error.
            shared_mol_ids = ({mol_ids[idx] for idx in train_set} &
                              {mol_ids[idx] for idx in test_set})
            if shared_mol_ids:
                raise RuntimeError(
                    "As grouped cross-validation is used the training "
                    "and test molecules, i.e. mol_ids, are not allowed "
                    "to overlap. This can happen if molecular structures "
                    "are appearing in different systems. During the "
                    "learning of hyper-parameter the training set should "
                    "not contain any structure also in the test set.",
                    shared_mol_ids)

            # 1) Extract the target values from y (train and test) using the keys
            y_train = OrderedDict((key, y[key]) for key in keys_train)
            y_test = OrderedDict((key, y[key]) for key in keys_test)

            # 2) Build the retention order graphs and extract the pairs
            cretention_train = _build_retention_order_graph(
                y_train, pair_params["ireverse"])
            cretention_test = _build_retention_order_graph(
                y_test, pair_params["ireverse"])

            pairs_train_sets.append(
                get_pairs_from_order_graph(
                    cretention_train,
                    keys_train,
                    allow_overlap=pair_params["allow_overlap"],
                    n_jobs=n_jobs,
                    d_lower=pair_params["d_lower"],
                    d_upper=pair_params["d_upper"]))

            pairs_test = get_pairs_from_order_graph(
                cretention_test,
                keys_test,
                allow_overlap=pair_params["allow_overlap"],
                n_jobs=n_jobs,
                d_lower=d_lower_test,
                d_upper=d_upper_test)
            pairs_test_sets.append(pairs_test)
            n_pairs_test_sets.append(len(pairs_test))

            # 3) Extract the features from X (train and test) using the keys
            X_train_sets.append(np.array([X[key] for key in keys_train]))
            X_test_sets.append(np.array([X[key] for key in keys_test]))

        print("")

        # Use the folds actually produced rather than re-asking 'cv'.
        n_folds = len(pairs_train_sets)
        n_pairs_test = np.array(n_pairs_test_sets)

        for k_param, param in enumerate(l_params):
            # Calculate the absolute number of correctly classified pairs
            # for each fold.
            fold_scores = np.asarray(
                Parallel(n_jobs=n_jobs, verbose=False)(
                    delayed(_fit_and_score_ranksvm)(
                        param.copy(), clone(estimator), X_train_sets[k_cv],
                        X_test_sets[k_cv], pairs_train_sets[k_cv],
                        pairs_test_sets[k_cv], scaler)
                    for k_cv in range(n_folds)))

            if fold_score_aggregation == "average":
                param_scores[k_param] = np.mean(fold_scores / n_pairs_test)
            else:  # "weighted_average" (validated above)
                param_scores[k_param] = np.sum(fold_scores) / np.sum(
                    n_pairs_test)

    ## Fit model using the best parameters
    # Find the best params (copied, so the "score" added below is not in it)
    best_params = l_params[np.argmax(param_scores)].copy()

    best_estimator = clone(estimator)
    best_estimator.set_params(**_filter_params(best_params, best_estimator))

    # Build retention order graph on all data and extract the training pairs
    cretention = _build_retention_order_graph(y, pair_params["ireverse"])
    pairs = get_pairs_from_order_graph(
        cretention,
        keys,
        allow_overlap=pair_params["allow_overlap"],
        n_jobs=n_jobs,
        d_lower=pair_params["d_lower"],
        d_upper=pair_params["d_upper"])
    n_pairs_train = len(pairs)

    FX = np.array([X[key] for key in keys])
    if scaler is not None:
        # NOTE(review): the scaler is only applied here, never fitted; it is
        # presumably fitted by the caller or inside _fit_and_score_ranksvm --
        # confirm, especially for n_jobs > 1 or a single parameter set.
        FX = scaler.transform(FX)

    best_estimator.fit(None, y=None, fit_params={"FX": FX, "pairs": pairs})

    # Combine the mean fold scores with the list of parameter sets
    for param, score in zip(l_params, param_scores):
        param["score"] = score

    return best_params, l_params, n_pairs_train, best_estimator, FX, None


def _build_retention_order_graph(targets, ireverse):
    """
    Build a retention order graph for the given retention times.

    Loads 'targets' into a fresh 'retention_cls' object, builds its digraph
    and sets the inverted molecule / collection dictionaries
    ('dmolecules_inv', 'dcollections_inv'), as is done before every call to
    'get_pairs_from_order_graph' in 'find_hparan_ranksvm'.

    :param targets: dictionary, (mol-id, system)-tuples as keys and retention
        times as values.

    :param ireverse: scalar, passed to 'make_digraph' (should cross-system
        elution transitivity be included, 0: no, 1: yes).

    :return: retention_cls object
    """
    cretention = retention_cls()
    cretention.load_data_from_target(targets)
    cretention.make_digraph(ireverse=ireverse)
    cretention.dmolecules_inv = cretention.invert_dictionary(
        cretention.dmolecules)
    cretention.dcollections_inv = cretention.invert_dictionary(
        cretention.dcollections)
    return cretention