def main():
    """Create the train/valid/test splits and persist them.

    Collects the datetimes with ``find_datetimes()``, partitions them with
    ``split.create_split`` and writes the three resulting sets with
    ``split.persist_split`` (using that function's default destination).
    Progress is reported through ``logger``.
    """
    logger.info("Creating splits")
    datetimes = find_datetimes()
    train_set, valid_set, test_set = split.create_split(datetimes)
    logger.info("Splits created")
    split.persist_split(train_set, valid_set, test_set)
    logger.info("Splits persisted")
# Example No. 2
# 0
    def test_whenCreateSplit_shouldCreate2015AsTestData(self):
        # Two of the four input dates fall in 2015; the test split is
        # expected to hold exactly two entries.
        years = (2013, 2014, 2015, 2015)
        datetimes = [datetime(year, 2, 1) for year in years]

        _train, _valid, test_set = split.create_split(datetimes)

        expected_test_size = 2
        self.assertEqual(len(test_set), expected_test_size)
# Example No. 3
# 0
    def test_whenCreateSplit_shouldCreate2014AsValidData(self):
        # Two of the four input dates fall in 2014; the validation split is
        # expected to hold exactly two entries.
        years = (2013, 2014, 2014, 2015)
        datetimes = [datetime(year, 2, 1) for year in years]

        _train, valid_set, _test = split.create_split(datetimes)

        expected_valid_size = 2
        self.assertEqual(len(valid_set), expected_valid_size)
# Example No. 4
# 0
    def test_canSaveAndLoadSplits(self):
        # Round trip: persisting the splits and loading them back from the
        # same directory must give back equal train/valid/test sets.
        import tempfile

        datetimes = self._create_datetimes()
        train_set, valid_set, test_set = split.create_split(datetimes)

        # Use a private, automatically removed directory instead of the
        # shared "/tmp": no files are left behind, parallel runs cannot
        # clobber each other, and the test does not rely on a POSIX-only path.
        with tempfile.TemporaryDirectory() as dir_path:
            split.persist_split(train_set, valid_set, test_set, dir_path=dir_path)
            loaded_train_set, loaded_valid_set, loaded_test_set = split.load(
                dir_path=dir_path
            )

            # Compared while the directory still exists, in case the loaded
            # objects read from disk lazily.
            self.assertEqual(train_set, loaded_train_set)
            self.assertEqual(valid_set, loaded_valid_set)
            self.assertEqual(test_set, loaded_test_set)
# Example No. 5
# 0
    def test_whenCreateSplit_eachSetIsUnique(self):
        # The three splits must not share any value.
        datetimes = self._create_datetimes()
        train_set, valid_set, test_set = split.create_split(datetimes)

        # Compare the splits as sets. An expression like ``list_a not in
        # list_b`` only asks whether list_a appears as a single *element* of
        # list_b, which is effectively always true for lists of scalar values.
        # Set intersection also makes the ordering (shuffle) irrelevant.
        train_values = set(train_set)
        valid_values = set(valid_set)
        test_values = set(test_set)

        # assertEqual against an empty set reports the overlapping values
        # on failure.
        self.assertEqual(train_values & valid_values, set())
        self.assertEqual(train_values & test_values, set())
        self.assertEqual(valid_values & test_values, set())
# Example No. 6
# 0
    def test_whenCreateSplit_shouldUseAllValues(self):
        # The sizes of the three splits must add up to the number of input
        # values.
        datetimes = self._create_datetimes()

        train_set, valid_set, test_set = split.create_split(datetimes)

        # assertEqual, not assertTrue: assertTrue(a, b) treats b as the
        # failure message and passes for any non-zero a.
        self.assertEqual(
            len(datetimes), len(train_set) + len(valid_set) + len(test_set)
        )