コード例 #1
0
    def test_open_as_container(self):
        """Zip ``self.dir_name`` and exercise the container API on the result.

        Covers ``list()`` at the archive root and inside the directory,
        ``isdir()``, ``info()`` and reading a member back via ``open()``.
        The archive is written to the current working directory as
        ``test.zip`` and is removed even when an assertion fails.
        """
        # Create a container for testing
        chainerio.set_root("posix")
        zip_file_name = "test"
        zip_file_path = zip_file_name + ".zip"

        shutil.make_archive(zip_file_name, "zip", base_dir=self.dir_name)

        # Clean up in ``finally`` so a failing assertion does not leave the
        # archive behind for later tests.
        try:
            with chainerio.open_as_container(zip_file_path) as container:
                # Root listing: the directory entry is present (compared as
                # ``self.dir_name`` minus its last character -- presumably a
                # trailing '/'), nested files and empty names are not.
                file_list = list(container.list())
                self.assertIn(self.dir_name[:-1], file_list)
                self.assertNotIn(self.tmpfile_path, file_list)
                self.assertNotIn("", file_list)

                # Listing inside the directory yields bare member names.
                file_list = list(container.list(self.dir_name))
                self.assertNotIn(self.dir_name[:-1], file_list)
                self.assertIn(os.path.basename(self.tmpfile_path), file_list)
                self.assertNotIn("", file_list)

                self.assertTrue(container.isdir(self.dir_name))
                self.assertFalse(container.isdir(self.tmpfile_path))

                self.assertIsInstance(container.info(), str)
                with container.open(self.tmpfile_path, "r") as f:
                    self.assertEqual(f.read(), self.test_string_str)
        finally:
            chainerio.remove(zip_file_path)
コード例 #2
0
    def test_fs_detection_on_container_hdfs(self):
        """Upload a zip to HDFS and open it as a container via its HDFS path.

        Needs a reachable HDFS (``pyarrow.hdfs.connect()`` is called with no
        arguments). The local ``test.zip`` and its HDFS copy are both removed
        even if the test fails part-way.
        """
        # Create a container for testing
        zip_file_name = "test"
        zip_file_path = zip_file_name + ".zip"

        # TODO(tianqi): add functionality to chainerio
        from pyarrow import hdfs

        # Look up the HDFS path of '.', used as the upload directory; close
        # the connection even if ``info`` raises.
        conn = hdfs.connect()
        try:
            hdfs_home = conn.info('.')['path']
        finally:
            conn.close()

        hdfs_file_path = os.path.join(hdfs_home, zip_file_path)

        shutil.make_archive(zip_file_name, "zip", base_dir=self.dir_name)

        try:
            # Copy the local archive to HDFS.
            with chainerio.open(hdfs_file_path, "wb") as hdfs_file:
                with chainerio.open(zip_file_path, "rb") as posix_file:
                    hdfs_file.write(posix_file.read())

            # Only remove the HDFS copy once the upload succeeded, so a
            # failed upload is not masked by a failing remove.
            try:
                with chainerio.open_as_container(hdfs_file_path) as container:
                    with container.open(self.tmpfile_path, "r") as f:
                        self.assertEqual(f.read(), self.test_string_str)
            finally:
                chainerio.remove(hdfs_file_path)
        finally:
            chainerio.remove(zip_file_path)
コード例 #3
0
ファイル: test_context.py プロジェクト: shu65/chainerio
    def test_open_as_container(self):
        """Zip ``self.tmpdir`` and exercise the container API on the result.

        Member names inside the zip lose their leading '/', so the expected
        names are derived from the absolute temp paths with that slash
        stripped. The archive (``test.zip`` in the current working directory)
        is removed even when an assertion fails.
        """
        # Create a container for testing
        chainerio.set_root("posix")
        zip_file_name = "test"
        zip_file_path = zip_file_name + ".zip"

        # in the zip, the leading slash will be removed
        # TODO(tianqi): related to issue #61
        dirname_zip = self.tmpdir.name.lstrip('/') + '/'
        file_name_zip = self.tmpfile_path.lstrip('/')
        # First path component, i.e. what the root listing should contain.
        first_level_dir = dirname_zip.split('/')[0]

        shutil.make_archive(zip_file_name, "zip", base_dir=self.tmpdir.name)

        # Clean up in ``finally`` so a failing assertion does not leave the
        # archive behind for later tests.
        try:
            with chainerio.open_as_container(zip_file_path) as container:
                # Root listing: only the first-level directory is present.
                file_list = list(container.list())
                self.assertIn(first_level_dir, file_list)
                self.assertNotIn(file_name_zip, file_list)
                self.assertNotIn("", file_list)

                # Listing inside the temp dir yields bare member names.
                file_list = list(container.list(dirname_zip))
                self.assertNotIn(first_level_dir, file_list)
                self.assertIn(os.path.basename(file_name_zip), file_list)
                self.assertNotIn("", file_list)

                self.assertTrue(container.isdir(dirname_zip))
                self.assertFalse(container.isdir(file_name_zip))

                self.assertIsInstance(container.info(), str)
                with container.open(file_name_zip, "r") as f:
                    self.assertEqual(f.read(), self.test_string_str)
        finally:
            chainerio.remove(zip_file_path)
コード例 #4
0
    def init_dataloader(self, epoch, pool, rng=None):
        """Choose this rank's data file and submit its loader to ``pool``.

        On a fresh start (resume flag unset, ``epoch > 0``, or phase 2) the
        entries of the ``self.args.input_file`` container whose names contain
        ``"training"`` are collected, sorted and shuffled with ``rng``, and
        the start index is 0. Otherwise the file list and start index come
        from ``self.snapshot`` and ``self.args.resume_from_checkpoint`` is
        cleared.

        Args:
            epoch: Current epoch number; only ``epoch > 0`` is inspected.
            pool: Executor whose ``submit`` schedules
                ``create_pretraining_dataset`` for the chosen file.
            rng: Object providing ``shuffle``; defaults to the ``random``
                module.

        Returns:
            Tuple ``(future, f_start_id, files, data_file)`` where ``future``
            is whatever ``pool.submit`` returned.

        Raises:
            ValueError: If there are no data files to choose from (previously
                this surfaced as an opaque ``ZeroDivisionError``).
        """
        rng = rng or random
        if not self.args.resume_from_checkpoint or epoch > 0 or \
                self.args.phase2:
            with chio.open_as_container(self.args.input_file) as input_file:
                files = [f for f in input_file.list() if "training" in f]
            # Sort first so the shuffled order depends only on the rng state
            # and the set of files, not on the container's listing order.
            files.sort()
            rng.shuffle(files)
            f_start_id = 0
        else:
            f_start_id = self.snapshot.f_id
            files = self.snapshot.files
            self.args.resume_from_checkpoint = False

        num_files = len(files)
        if num_files == 0:
            raise ValueError(
                "no training data files available in {!r}".format(
                    self.args.input_file))

        if torch.distributed.is_initialized() and \
                self.team_size > num_files:
            # More ranks than files, so several ranks share a file. The extra
            # ``remainder * f_start_id`` term presumably rotates the
            # rank -> file assignment as f_start_id advances -- confirm.
            remainder = self.team_size % num_files
            data_file = files[(f_start_id * self.team_size + self.team_rank +
                               remainder * f_start_id) % num_files]
        else:
            data_file = files[(f_start_id * self.team_size + self.team_rank) %
                              num_files]

        return pool.submit(create_pretraining_dataset, self.args.input_file,
                           data_file, self.args.max_predictions_per_seq,
                           self.args), f_start_id, files, data_file
コード例 #5
0
    def test_fs_detection_on_container_posix(self):
        """Open a zip addressed with a ``file://`` prefix as a container.

        Verifies that the POSIX filesystem is picked from the URI prefix and a
        member can be read back. The local ``test.zip`` is removed even when
        the assertion fails.
        """
        # Create a container for testing
        zip_file_name = "test"
        zip_file_path = zip_file_name + ".zip"
        posix_file_path = "file://" + zip_file_path

        shutil.make_archive(zip_file_name, "zip", base_dir=self.dir_name)

        try:
            with chainerio.open_as_container(posix_file_path) as container:
                with container.open(self.tmpfile_path, "r") as f:
                    self.assertEqual(f.read(), self.test_string_str)
        finally:
            chainerio.remove(zip_file_path)
コード例 #6
0
def create_pretraining_dataset(container_file, input_file, max_pred_length,
                               args):
    """Build a shuffled DataLoader over one member of a container archive.

    The member ``input_file`` is opened in binary mode inside
    ``container_file`` and handed to ``pretraining_dataset``. Both the member
    and the container are closed before the DataLoader is built.

    Args:
        container_file: Path of the container passed to
            ``chio.open_as_container``.
        input_file: Name of the member to read inside the container.
        max_pred_length: Forwarded to ``pretraining_dataset`` as
            ``max_pred_length``.
        args: Namespace providing ``train_batch_size``.

    Returns:
        Tuple ``(dataloader, input_file)``.
    """
    # The member handle is derived from the archive in the same ``with``;
    # on exit the member closes first, then the archive.
    with chio.open_as_container(container_file) as archive, \
            archive.open(input_file, "rb") as member:
        dataset = pretraining_dataset(input_file, member,
                                      max_pred_length=max_pred_length)

    loader = DataLoader(dataset,
                        sampler=RandomSampler(dataset),
                        batch_size=args.train_batch_size,
                        num_workers=0,
                        pin_memory=True)
    return loader, input_file
コード例 #7
0
ファイル: test_context.py プロジェクト: shu65/chainerio
    def test_fs_detection_on_container_posix(self):
        """Open a zip addressed with a ``file://`` prefix as a container.

        The member name inside the zip is the temp file's absolute path with
        the leading '/' stripped. The local ``test.zip`` is removed even when
        the assertion fails.
        """
        # Create a container for testing
        zip_file_name = "test"
        zip_file_path = zip_file_name + ".zip"
        posix_file_path = "file://" + zip_file_path

        # in the zip, the leading slash will be removed
        file_name_zip = self.tmpfile_path.lstrip('/')

        shutil.make_archive(zip_file_name, "zip", base_dir=self.tmpdir.name)

        try:
            with chainerio.open_as_container(posix_file_path) as container:
                with container.open(file_name_zip, "r") as f:
                    self.assertEqual(f.read(), self.test_string_str)
        finally:
            chainerio.remove(zip_file_path)