def test_sync_checkpoint_previous(tmpdir: py.path.local):
    """Restore a previous checkpoint into a scratch dir and check it is not fetched twice."""
    inputs, _, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs), checkpoint_src=str(inputs))

    scratch_dir = tmpdir.mkdir("user_scratch")
    assert checkpoint.restore(str(scratch_dir)) == scratch_dir
    assert scratch_dir.listdir() == [scratch_dir.join(CHECKPOINT_FILE)]

    # A second restore with a different path must hand back the first
    # location, i.e. the download is not performed again.
    assert checkpoint.restore("x") == scratch_dir
def test_sync_checkpoint_save_filepath(tmpdir):
    """Save a checkpoint by passing ``save`` a file path instead of an open reader."""
    root = Path(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=tmpdir)

    expected = root / "test"
    assert not expected.exists()

    # NOTE(review): the input file is created at the very same path as the
    # expected destination, so the final assertion already holds as soon as the
    # file is written - confirm whether a separate source directory was intended.
    source_file = root / "test"
    source_file.write_bytes(b"blah")

    checkpoint.save(source_file)
    assert expected.exists()
def test_sync_checkpoint_folder(tmpdir: py.path.local):
    """Save a whole folder as the checkpoint and check what lands in the destination."""
    inputs, _, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs))

    # The checkpoint was built without a checkpoint_src, so restoring is
    # expected to yield a falsy result.
    assert not checkpoint.restore("/tmp")

    # Save the input folder ...
    checkpoint.save(Path(str(inputs)))

    # ... and expect the destination to hold exactly the checkpoint file.
    assert outputs.listdir() == [outputs.join(CHECKPOINT_FILE)]
def test_sync_checkpoint_reader(tmpdir: py.path.local):
    """Save a checkpoint from an open binary reader and check the destination file name."""
    _, input_file, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs))

    # The checkpoint was built without a checkpoint_src, so restoring is
    # expected to yield a falsy result.
    assert not checkpoint.restore("/tmp")

    # Save straight from the opened file handle.
    with input_file.open(mode="rb") as reader:
        checkpoint.save(reader)

    # The destination should contain only the TMP_DST_PATH entry.
    assert outputs.listdir() == [outputs.join(SyncCheckpoint.TMP_DST_PATH)]
def test_sync_checkpoint_restore_default_path(tmpdir):
    """Call ``restore`` without a path after ``read`` and expect the previous download path."""
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()

    payload = b"prev-bytes"
    (src_dir / "prev").write_bytes(payload)

    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))
    assert checkpoint.read() == payload
    assert checkpoint._prev_download_path is not None
    assert checkpoint.restore() == checkpoint._prev_download_path
def test_sync_checkpoint_read_multiple_files(tmpdir):
    """
    ``read`` must raise a ValueError when the source holds more than one file.
    """
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()

    # Two files with identical content in the checkpoint source.
    payload = b"prev-bytes"
    for file_name in ("prev", "prev2"):
        (src_dir / file_name).write_bytes(payload)

    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))

    with pytest.raises(ValueError, match="Expected exactly one checkpoint - found 2"):
        checkpoint.read()
def with_task_sandbox(self) -> Builder:
    """
    Return a new builder that points at a freshly created sandbox directory.

    The sandbox is created with ``tempfile.mkdtemp``, using the working directory (or its ``name`` when it is a
    ``utils.AutoDeletingTempDir``) as the ``prefix`` argument. A ``__cp`` sub-directory inside the sandbox is used
    as the destination of the ``SyncCheckpoint`` attached to the returned builder.

    :return: builder obtained from ``self.new_builder(self)`` with ``checkpoint`` and ``working_dir`` set
    """
    working_directory = self.working_directory
    if isinstance(working_directory, utils.AutoDeletingTempDir):
        sandbox_prefix = working_directory.name
    else:
        sandbox_prefix = working_directory
    sandbox = tempfile.mkdtemp(prefix=sandbox_prefix)

    checkpoint_dir = pathlib.Path(sandbox) / "__cp"
    checkpoint_dir.mkdir(exist_ok=True)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(checkpoint_dir))

    builder = self.new_builder(self)
    builder.checkpoint = checkpoint
    builder.working_dir = sandbox
    return builder
def test_sync_checkpoint_write(tmpdir):
    """Write raw bytes as a checkpoint and check the file appears under TMP_DST_PATH."""
    root = Path(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=tmpdir)

    # No checkpoint_src was given: both read and restore are expected to be None.
    assert checkpoint.read() is None
    assert checkpoint.restore() is None

    written = root / SyncCheckpoint.TMP_DST_PATH
    assert not written.exists()

    checkpoint.write(b"bytes")
    assert written.exists()
def test_sync_checkpoint_save_file(tmpdir):
    """Save from an open reader, then check that an unsupported object is rejected."""
    root = Path(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=tmpdir)

    expected = root / SyncCheckpoint.TMP_DST_PATH
    assert not expected.exists()

    source_file = root / "test"
    source_file.write_bytes(b"blah")
    with source_file.open("rb") as reader:
        checkpoint.save(reader)
    assert expected.exists()

    with pytest.raises(ValueError):
        # Unsupported object
        checkpoint.save(SyncCheckpoint)  # noqa
def test_sync_checkpoint_restore(tmpdir):
    """Restore into a user supplied directory; later restores return that same directory."""
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()

    payload = b"prev-bytes"
    (src_dir / "prev").write_bytes(payload)

    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))

    # The target directory has not been created yet, which is expected to raise.
    user_dest = root / "user_dest"
    with pytest.raises(ValueError):
        checkpoint.restore(user_dest)

    user_dest.mkdir()
    assert checkpoint.restore(user_dest) == user_dest
    # A further restore with another path still reports the first location.
    assert checkpoint.restore("other_path") == user_dest
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    Assemble the context used for a task execution and yield it.

    Execution and task identifiers are read through ``get_one_of``, a user workspace directory is created, the
    ``ExecutionParameters`` and a ``FileAccessProvider`` are built, and the resulting context builder is activated
    with ``FlyteContextManager.with_context``.

    :param raw_output_data_prefix: Passed as ``raw_output_prefix`` to both the ``ExecutionParameters`` and the
        ``FileAccessProvider`` created here.
    :param checkpoint_path: When not None, a ``SyncCheckpoint`` with this destination is attached to the execution
        parameters; otherwise the checkpoint is None.
    :param prev_checkpoint: Used as ``checkpoint_src`` of that ``SyncCheckpoint``.
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present, indicates that if a dynamic
        task were to run, it should set fast serialize to true and use these values in FastSerializationSettings
    :param dynamic_dest_dir: See above.
    :return: Yields the context entered via ``FlyteContextManager.with_context``.
    """
    # Identity of the execution this task belongs to.
    execution_project = get_one_of("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ")
    execution_domain = get_one_of("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM")
    execution_name = get_one_of("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM")
    execution_workflow = get_one_of("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF")
    execution_launchplan = get_one_of("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP")

    # Identity of the task itself.
    task_project = get_one_of("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ")
    task_domain = get_one_of("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM")
    task_name = get_one_of("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM")
    task_version = get_one_of("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V")

    serialized_settings = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    ctx = FlyteContextManager.current_context()

    # Create directories
    workspace_dir = ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {workspace_dir}")
    pathlib.Path(workspace_dir).mkdir(parents=True, exist_ok=True)

    from flytekit import __version__ as _api_version

    if checkpoint_path is None:
        checkpointer = None
    else:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_id = _identifier.WorkflowExecutionIdentifier(
        project=execution_project,
        domain=execution_domain,
        name=execution_name,
    )
    execution_date = _datetime.datetime.utcnow()

    # Stats metric path will be:
    # registration_project.registration_domain.app.module.task_name.user_stats
    # and it will be tagged with execution-level values for project/domain/wf/lp
    stats = _get_stats(
        cfg=StatsConfig.auto(),
        prefix=f"{task_project}.{task_domain}.{task_name}.user_stats",
        tags={
            "exec_project": execution_project,
            "exec_domain": execution_domain,
            "exec_workflow": execution_workflow,
            "exec_launchplan": execution_launchplan,
            "api_version": _api_version,
        },
    )

    execution_parameters = ExecutionParameters(
        execution_id=execution_id,
        execution_date=execution_date,
        stats=stats,
        logging=user_space_logger,
        tmp_dir=workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=_identifier.Identifier(
            _identifier.ResourceType.TASK, task_project, task_domain, task_name, task_version
        ),
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    execution_state = ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    context_builder = ctx.new_builder().with_file_access(file_access).with_execution_state(execution_state)

    if serialized_settings:
        # Override project/domain/version of the transported settings with the
        # values of the current execution and task.
        settings_builder = SerializationSettings.from_transport(serialized_settings).new_builder()
        settings_builder.project = execution_project
        settings_builder.domain = execution_domain
        settings_builder.version = task_version
        if dynamic_addl_distro:
            settings_builder.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        context_builder = context_builder.with_serialization_settings(settings_builder.build())

    with FlyteContextManager.with_context(context_builder) as task_ctx:
        yield task_ctx