예제 #1
0
    def test_split_into_groups(self):
        """Check how split_into_groups partitions a list into groups.

        The cases cover an uneven split, requesting more groups than there
        are items, a single-item input, and an even split.
        """
        items = [1, 2, 3, 4, 5, 6]
        cases = [
            # Five items into three groups: the last group is the short one.
            (items[:5], 3, [[1, 2], [3, 4], [5]]),
            # More groups than items: one item per group, no empty groups.
            (items, 7, [[1], [2], [3], [4], [5], [6]]),
            # A single item stays in a single group.
            (items[:1], 7, [[1]]),
            # Six items into three groups divide evenly.
            (items, 3, [[1, 2], [3, 4], [5, 6]]),
        ]
        for group_input, num_groups, expected in cases:
            actual = split_into_groups(group_input, num_groups)
            self.assertEqual(actual, expected)
예제 #2
0
    def get_split_config(self, split_ind, num_splits):
        """Return a copy of this config that keeps only one split of its scenes.

        Args:
            split_ind: index of the split (group) to keep.
            num_splits: number of groups to divide each scene list into.

        Each scene list is divided with ``split_into_groups``. That function
        can produce fewer groups than requested, so when group ``split_ind``
        does not exist the copy gets an empty list. ``test_scenes`` is only
        re-split when it is non-empty; otherwise the value copied by
        ``self.copy()`` is kept.
        """
        def select_group(scenes):
            # Pick this split's share of ``scenes``, or [] if it has none.
            groups = split_into_groups(scenes, num_splits)
            if split_ind < len(groups):
                return groups[split_ind]
            return []

        split_cfg = self.copy()
        split_cfg.train_scenes = select_group(self.train_scenes)
        split_cfg.validation_scenes = select_group(self.validation_scenes)
        if self.test_scenes:
            split_cfg.test_scenes = select_group(self.test_scenes)
        return split_cfg
예제 #3
0
    def save_messages(self, split_ind=0, num_splits=1):
        """Build a message for each name in this job's split and save it.

        Args:
            split_ind: index of the parallel job whose share of the work to do.
            num_splits: total number of parallel jobs sharing the work.

        The (name, message_uri) pairs are divided with ``split_into_groups``.
        That function can return fewer groups than ``num_splits`` (for
        example when there are fewer names than jobs). A job whose
        ``split_ind`` has no group therefore has nothing to save, and it
        returns without writing anything instead of raising IndexError.
        """
        message_maker = self.config.message_maker.build()

        # zip() stops at the shorter sequence, so unmatched names or URIs
        # are silently ignored.
        split_groups = split_into_groups(
            list(zip(self.config.names, self.config.message_uris)), num_splits)
        # Guard against split_ind past the last group (fewer groups than jobs).
        split_group = (split_groups[split_ind]
                       if split_ind < len(split_groups) else [])

        for name, message_uri in split_group:
            # Unlike before, we use the message_maker to make the message.
            message = message_maker.make_message(name)
            str_to_file(message, message_uri)
            print('Saved message to {}'.format(message_uri))
예제 #4
0
    def save_messages(self, split_ind=0, num_splits=1):
        """Save a greeting file for each name in this job's split.

        Args:
            split_ind: index of the parallel job whose share of the work to do.
            num_splits: total number of parallel jobs sharing the work.

        The (name, message_uri) pairs are divided with ``split_into_groups``.
        That function can return fewer groups than ``num_splits`` (for
        example when there are fewer names than jobs). A job whose
        ``split_ind`` has no group therefore has nothing to save, and it
        returns without writing anything instead of raising IndexError.
        """
        # Save a file for each name with a message.

        # The num_splits is the number of parallel jobs to use and
        # split_ind tracks the index of the parallel job. In this case
        # we are splitting on the names/message_uris.
        # zip() stops at the shorter sequence, so unmatched names or URIs
        # are silently ignored.
        split_groups = split_into_groups(
            list(zip(self.config.names, self.config.message_uris)), num_splits)
        # Guard against split_ind past the last group (fewer groups than jobs).
        split_group = (split_groups[split_ind]
                       if split_ind < len(split_groups) else [])

        for name, message_uri in split_group:
            message = 'hello {}!'.format(name)
            # str_to_file and most functions in the file_system package can
            # read and write transparently to different file systems based on
            # the URI pattern.
            str_to_file(message, message_uri)
            print('Saved message to {}'.format(message_uri))