Example 1
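A commit() method, apparently from Ray Train/AIR checkpoint handling, that persists a directory-backed checkpoint to path. When the checkpoint data already lives on the local node, its contents are simply moved into place; otherwise sync_dir_between_nodes copies the directory over from the source node and delete_on_node removes the remote copy afterwards.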
    def commit(self, path: Optional[Path] = None) -> None:
        if (self.storage_mode == CheckpointStorage.MEMORY or not path
                or not isinstance(self.dir_or_data, dict)):
            return

        source_ip = self.dir_or_data[NODE_IP_KEY]
        source_path = self.dir_or_data[CHECKPOINT_PATH_ON_NODE_KEY]
        target_ip = get_node_ip_address()

        if source_ip == target_ip:
            # Move contents of source_path, but not source_path
            # itself. shutil.move is already recursive.
            for inner in Path(source_path).iterdir():
                shutil.move(str(inner.absolute()), str(path))
            shutil.rmtree(source_path, ignore_errors=True)
        else:
            sync_dir_between_nodes(
                source_ip=source_ip,
                source_path=source_path,
                target_ip=target_ip,
                target_path=str(path),
                return_futures=False,
                max_size_bytes=None,
            )
            delete_on_node(node_ip=source_ip, path=source_path)
        save_preprocessor_to_dir(self.dir_or_data.pop(PREPROCESSOR_KEY, None),
                                 path)
        # add tune checkpoint id
        with open(path.joinpath(TUNE_CHECKPOINT_ID), "w") as f:
            f.write(str(self.id))
Example 2
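A write_checkpoint() method that saves through tune.checkpoint_dir(). The same pattern as Example 1 applies: a checkpoint already on the local node is moved into the Tune checkpoint directory, while a remote one is synced over with sync_dir_between_nodes and then deleted on the source node.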
    def write_checkpoint(self, checkpoint: Dict):
        # If inside a Tune Trainable, then checkpoint with Tune.
        with tune.checkpoint_dir(
                step=self._latest_checkpoint_id) as checkpoint_dir:
            source_ip = checkpoint[NODE_IP_KEY]
            source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY]
            target_ip = get_node_ip_address()
            if source_ip == target_ip:
                # Move contents of source_path, but not source_path
                # itself. shutil.move is already recursive.
                for path in Path(source_path).iterdir():
                    shutil.move(str(path.absolute()), checkpoint_dir)
                shutil.rmtree(source_path, ignore_errors=True)
            else:
                sync_dir_between_nodes(
                    source_ip=source_ip,
                    source_path=source_path,
                    target_ip=target_ip,
                    target_path=checkpoint_dir,
                    return_futures=False,
                    max_size_bytes=None,
                )
                delete_on_node(node_ip=source_ip, path=source_path)
            checkpoint_dir = Path(checkpoint_dir)
            save_preprocessor_to_dir(self.preprocessor, checkpoint_dir)
            # add tune checkpoint id
            with open(checkpoint_dir.joinpath(TUNE_CHECKPOINT_ID), "w") as f:
                f.write(str(self._latest_checkpoint_id))
Example 3
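An _execute_sync() method that starts a non-blocking sync: with return_futures=True, sync_dir_between_nodes returns a sync future along with the pack actor and file-stats handles, which the caller can optionally keep around instead of waiting for the copy to finish.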
    def _execute_sync(
        self,
        source_tuple: Tuple[str, str],
        target_tuple: Tuple[str, str],
    ) -> bool:
        source_ip, source_path = source_tuple
        target_ip, target_path = target_tuple

        self._sync_future, pack_actor, files_stats = sync_dir_between_nodes(
            source_ip=source_ip,
            source_path=source_path,
            target_ip=target_ip,
            target_path=target_path,
            return_futures=True,
            max_size_bytes=self._max_size_bytes,
        )

        if self._store_remotes:
            self._stored_pack_actor_ref = pack_actor
            self._stored_files_stats = files_stats

        return True
Example 4
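A per-worker training loop for HuggingFace Transformers. If the checkpoint to resume from lives on a different node, it is first synced into a temporary directory with sync_dir_between_nodes, passed to trainer.train(), and removed once training completes.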
def _huggingface_train_loop_per_worker(config):
    """Per-worker training loop for HuggingFace Transformers."""
    trainer_init_per_worker = config.pop("_trainer_init_per_worker")

    # Env vars necessary for HF to set up DDP
    os.environ["RANK"] = str(train.world_rank())
    os.environ["WORLD_SIZE"] = str(train.world_size())
    os.environ["LOCAL_RANK"] = str(train.local_rank())

    train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY)
    eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY)

    train_torch_dataset, eval_torch_dataset = process_datasets(
        train_dataset,
        eval_dataset,
    )

    trainer: transformers.trainer.Trainer = trainer_init_per_worker(
        train_torch_dataset, eval_torch_dataset, **config)

    if trainer.args.push_to_hub and not trainer.args.hub_token:
        warnings.warn(
            "You have set `push_to_hub=True` but didn't specify `hub_token`. "
            "Pushing to hub will most likely fail, as the credentials will not "
            "be automatically propagated from the local enviroment to the Ray Actors. "
            "If that happens, specify `hub_token` in `TrainingArguments`.")

    if (trainer.args.evaluation_strategy == "steps"
            or trainer.args.save_strategy == "steps"
            or trainer.args.logging_strategy == "steps"):
        raise ValueError(
            "'steps' value for `evaluation_strategy`, `logging_strategy` "
            "or `save_strategy` is not yet supported.")

    trainer = wrap_transformers_trainer(trainer)

    # Ensure no HF logging callbacks are added. Aside from duplicating
    # functionality with our own callbacks, the Wandb callback causes
    # training to freeze.
    integration_callbacks = transformers.trainer.get_reporting_integration_callbacks(
        trainer.args.report_to)
    for callback in integration_callbacks:
        trainer.pop_callback(callback)

    trainer.add_callback(TrainReportCallback)

    checkpoint = session.get_checkpoint()
    checkpoint_path = None
    remove_checkpoint_path = False
    if checkpoint:
        assert isinstance(checkpoint, Checkpoint)
        checkpoint_dict = checkpoint.to_dict()
        source_ip = checkpoint_dict[NODE_IP_KEY]
        source_path = checkpoint_dict[CHECKPOINT_PATH_ON_NODE_KEY]
        target_ip = get_node_ip_address()
        if source_ip == target_ip:
            checkpoint_path = source_path
        else:
            checkpoint_path = tempfile.mkdtemp(
                suffix=Path(trainer.args.output_dir).name)
            remove_checkpoint_path = True
            sync_dir_between_nodes(
                source_ip=source_ip,
                source_path=source_path,
                target_ip=target_ip,
                target_path=checkpoint_path,
                return_futures=False,
                max_size_bytes=None,
            )
    trainer.train(resume_from_checkpoint=checkpoint_path)
    if remove_checkpoint_path:
        shutil.rmtree(checkpoint_path, ignore_errors=True)
Example 5
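A unit test exercising a same-node sync, the max_size_bytes limit (which is expected to surface as a RayTaskError on the receiving side), and cleanup via delete_on_node.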
    def testSyncBetweenNodesAndDelete(self):
        temp_source = tempfile.mkdtemp()
        temp_up_target = tempfile.mkdtemp()
        temp_down_target = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_source)
        self.addCleanup(shutil.rmtree, temp_up_target, ignore_errors=True)
        self.addCleanup(shutil.rmtree, temp_down_target)

        os.makedirs(os.path.join(temp_source, "dir_level0", "dir_level1"))
        with open(os.path.join(temp_source, "dir_level0", "file_level1.txt"),
                  "w") as f:
            f.write("Data\n")

        def check_dir_contents(path: str):
            assert os.path.exists(os.path.join(path, "dir_level0"))
            assert os.path.exists(
                os.path.join(path, "dir_level0", "dir_level1"))
            assert os.path.exists(
                os.path.join(path, "dir_level0", "file_level1.txt"))
            with open(os.path.join(path, "dir_level0", "file_level1.txt"),
                      "r") as f:
                assert f.read() == "Data\n"

        # Sanity check
        check_dir_contents(temp_source)

        sync_dir_between_nodes(
            source_ip=ray.util.get_node_ip_address(),
            source_path=temp_source,
            target_ip=ray.util.get_node_ip_address(),
            target_path=temp_up_target,
        )

        # Check sync up
        check_dir_contents(temp_up_target)

        # Max size exceeded
        with self.assertRaises(RayTaskError):
            sync_dir_between_nodes(
                source_ip=ray.util.get_node_ip_address(),
                source_path=temp_up_target,
                target_ip=ray.util.get_node_ip_address(),
                target_path=temp_down_target,
                max_size_bytes=2,
            )

        assert not os.listdir(temp_down_target)

        sync_dir_between_nodes(
            source_ip=ray.util.get_node_ip_address(),
            source_path=temp_up_target,
            target_ip=ray.util.get_node_ip_address(),
            target_path=temp_down_target,
        )

        # Check sync down
        check_dir_contents(temp_down_target)

        # Delete in some dir
        delete_on_node(node_ip=ray.util.get_node_ip_address(),
                       path=temp_up_target)

        assert not os.path.exists(temp_up_target)
Example 6
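A thin wrapper method that simply forwards its arguments to sync_dir_between_nodes.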
    def _sync_function(self, *args, **kwargs):
        return sync_dir_between_nodes(*args, **kwargs)
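
Taken together, the examples suggest the basic call pattern. Below is a minimal, hypothetical sketch of calling sync_dir_between_nodes and delete_on_node directly on a single node. The import path is an assumption (none of the examples above show it, and it may differ across Ray versions), and a running Ray cluster is required; everything else mirrors the keyword arguments used in the examples.

import tempfile

import ray
# Assumed import path -- adjust to match your Ray version.
from ray.util.ml_utils.node import delete_on_node, sync_dir_between_nodes

ray.init()

source_dir = tempfile.mkdtemp()
target_dir = tempfile.mkdtemp()
node_ip = ray.util.get_node_ip_address()

# Put something in the source directory so the sync has work to do.
with open(f"{source_dir}/data.txt", "w") as f:
    f.write("hello\n")

# Blocking copy of source_dir into target_dir. Both IPs point at this
# node here, but source_ip/target_ip may name different cluster nodes.
sync_dir_between_nodes(
    source_ip=node_ip,
    source_path=source_dir,
    target_ip=node_ip,
    target_path=target_dir,
)

# Once the copy has landed, remove the source directory on its node,
# mirroring the cleanup step in Examples 1 and 2.
delete_on_node(node_ip=node_ip, path=source_dir)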