Example #1
0
def test_sync_checkpoint_previous(tmpdir: py.path.local):
    """Restoring a previous checkpoint populates the target once and then reuses that location."""
    inputs, _input_file, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs), checkpoint_src=str(inputs))
    restore_dir = tmpdir.mkdir("user_scratch")

    restored = checkpoint.restore(str(restore_dir))
    assert restored == restore_dir
    assert restore_dir.listdir() == [restore_dir.join(CHECKPOINT_FILE)]

    # A second restore must not download again: it hands back the first location
    # even though a different path is passed in.
    assert checkpoint.restore("x") == restore_dir
Example #2
0
def test_sync_checkpoint_save_filepath(tmpdir):
    """
    Saving a single file path copies that file into the checkpoint destination under its own name.

    The source file lives in a directory separate from the destination. Otherwise writing the input
    would itself create the expected destination file, and the final checks would pass even if
    save() did nothing.
    """
    td_path = Path(tmpdir)
    src_dir = td_path.joinpath("src")
    src_dir.mkdir()
    dest_dir = td_path.joinpath("dest")
    dest_dir.mkdir()
    cp = SyncCheckpoint(checkpoint_dest=str(dest_dir))

    dst_path = dest_dir.joinpath("test")
    assert not dst_path.exists()

    inp = src_dir.joinpath("test")
    with inp.open("wb") as f:
        f.write(b"blah")
    cp.save(inp)

    assert dst_path.exists()
    # Confirm the contents were copied, not just that some file appeared.
    assert dst_path.read_bytes() == b"blah"
Example #3
0
def test_sync_checkpoint_folder(tmpdir: py.path.local):
    """Saving a directory path uploads its contents into the checkpoint destination."""
    inputs, _input_file, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs))

    # No checkpoint source was configured, so restore must come back empty.
    assert not checkpoint.restore("/tmp")

    checkpoint.save(Path(str(inputs)))

    # Exactly the checkpoint file is expected in the destination.
    assert outputs.listdir() == [outputs.join(CHECKPOINT_FILE)]
Example #4
0
def test_sync_checkpoint_reader(tmpdir: py.path.local):
    """Saving an open binary reader stores it under SyncCheckpoint.TMP_DST_PATH."""
    _inputs, input_file, outputs = create_folder_write_file(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=str(outputs))

    # No checkpoint source was configured, so restore must come back empty.
    assert not checkpoint.restore("/tmp")

    with input_file.open(mode="rb") as reader:
        checkpoint.save(reader)

    assert outputs.listdir() == [outputs.join(SyncCheckpoint.TMP_DST_PATH)]
Example #5
0
def test_sync_checkpoint_restore_default_path(tmpdir):
    """restore() without a path returns the location an earlier read() downloaded into."""
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()
    payload = b"prev-bytes"
    (src_dir / "prev").write_bytes(payload)

    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))
    assert checkpoint.read() == payload
    # read() must have recorded where the previous checkpoint was placed...
    assert checkpoint._prev_download_path is not None
    # ...and an argument-less restore() reuses exactly that location.
    assert checkpoint.restore() == checkpoint._prev_download_path
Example #6
0
def test_sync_checkpoint_read_multiple_files(tmpdir):
    """
    read() refuses a previous checkpoint that contains more than one file.
    """
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()
    payload = b"prev-bytes"
    for file_name in ("prev", "prev2"):
        (src_dir / file_name).write_bytes(payload)

    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))

    with pytest.raises(ValueError, match="Expected exactly one checkpoint - found 2"):
        checkpoint.read()
Example #7
0
 def with_task_sandbox(self) -> Builder:
     """
     Return a builder (from ``self.new_builder(self)``) set up to run inside a fresh sandbox directory.

     A new directory is created with ``tempfile.mkdtemp``, using the working directory as the name prefix.
     A ``__cp`` subdirectory inside it becomes the SyncCheckpoint destination. The returned builder's
     ``checkpoint`` and ``working_dir`` are set to that checkpoint and to the sandbox directory.
     """
     wd = self.working_directory
     # For an AutoDeletingTempDir, use its ``.name`` as the prefix.
     name_prefix = wd.name if isinstance(wd, utils.AutoDeletingTempDir) else wd
     # NOTE(review): mkdtemp joins ``prefix`` onto the temp dir, so an absolute prefix yields a sibling
     # "<working dir>XXXXXX" rather than a child of the working directory. Confirm this is intended.
     sandbox_dir = tempfile.mkdtemp(prefix=name_prefix)
     checkpoint_dir = pathlib.Path(sandbox_dir) / "__cp"
     checkpoint_dir.mkdir(exist_ok=True)
     checkpoint = SyncCheckpoint(checkpoint_dest=str(checkpoint_dir))
     builder = self.new_builder(self)
     builder.checkpoint = checkpoint
     builder.working_dir = sandbox_dir
     return builder
Example #8
0
def test_sync_checkpoint_write(tmpdir):
    """write() stores raw bytes under SyncCheckpoint.TMP_DST_PATH in the destination."""
    root = Path(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=tmpdir)
    # No checkpoint source was given, so there is nothing to read or restore.
    assert checkpoint.read() is None
    assert checkpoint.restore() is None

    target = root / SyncCheckpoint.TMP_DST_PATH
    assert not target.exists()
    checkpoint.write(b"bytes")
    assert target.exists()
Example #9
0
def test_sync_checkpoint_save_file(tmpdir):
    """save() accepts an open binary file and raises ValueError for unsupported objects."""
    root = Path(tmpdir)
    checkpoint = SyncCheckpoint(checkpoint_dest=tmpdir)
    target = root / SyncCheckpoint.TMP_DST_PATH
    assert not target.exists()

    source = root / "test"
    source.write_bytes(b"blah")
    with source.open("rb") as reader:
        checkpoint.save(reader)
    assert target.exists()

    with pytest.raises(ValueError):
        # A class object is neither a path nor a reader.
        checkpoint.save(SyncCheckpoint)  # noqa
Example #10
0
def test_sync_checkpoint_restore(tmpdir):
    """restore() needs an existing target directory and then caches the first restore location."""
    root = Path(tmpdir)
    dest_dir = root / "dest"
    dest_dir.mkdir()
    src_dir = root / "src"
    src_dir.mkdir()
    (src_dir / "prev").write_bytes(b"prev-bytes")
    checkpoint = SyncCheckpoint(checkpoint_dest=str(dest_dir), checkpoint_src=str(src_dir))
    restore_dir = root / "user_dest"

    # The target has not been created yet, so restore must refuse it.
    with pytest.raises(ValueError):
        checkpoint.restore(restore_dir)

    restore_dir.mkdir()
    assert checkpoint.restore(restore_dir) == restore_dir
    # Later calls return the cached location regardless of the path passed in.
    assert checkpoint.restore("other_path") == restore_dir
Example #11
0
def setup_execution(
    raw_output_data_prefix: str,
    checkpoint_path: Optional[str] = None,
    prev_checkpoint: Optional[str] = None,
    dynamic_addl_distro: Optional[str] = None,
    dynamic_dest_dir: Optional[str] = None,
):
    """
    Build a task-execution FlyteContext from the environment and yield it once.

    Execution and task identity are read with ``get_one_of`` from FLYTE_INTERNAL_* environment variables
    (or their short ``_F_*`` aliases). This is a generator that yields from inside
    ``FlyteContextManager.with_context``. It is presumably wrapped with ``contextlib.contextmanager``
    where it is defined; confirm at the decorator site.

    :param raw_output_data_prefix: Prefix for raw output data. Used both as ExecutionParameters'
      ``raw_output_prefix`` and for the FileAccessProvider installed in the new context.
    :param checkpoint_path: Checkpoint destination. When None, no checkpointer is attached.
    :param prev_checkpoint: Source of a previous checkpoint. Only used when ``checkpoint_path`` is set.
    :param dynamic_addl_distro: Works in concert with the other dynamic arg. If present (and serialized
      settings are found in the environment), indicates that if a dynamic task were to run, it should set
      fast serialize to true and use these values in FastSerializationSettings.
    :param dynamic_dest_dir: See above.
    :raises TypeError: Re-raised after logging if constructing the FileAccessProvider fails with TypeError
      (e.g. no data plugin handles ``raw_output_data_prefix``).
    """
    # Each value comes from the long env var name, falling back to the short alias.
    exe_project, exe_domain, exe_name, exe_wf, exe_lp = (
        get_one_of(long_name, short_name)
        for long_name, short_name in (
            ("FLYTE_INTERNAL_EXECUTION_PROJECT", "_F_PRJ"),
            ("FLYTE_INTERNAL_EXECUTION_DOMAIN", "_F_DM"),
            ("FLYTE_INTERNAL_EXECUTION_ID", "_F_NM"),
            ("FLYTE_INTERNAL_EXECUTION_WORKFLOW", "_F_WF"),
            ("FLYTE_INTERNAL_EXECUTION_LAUNCHPLAN", "_F_LP"),
        )
    )
    tk_project, tk_domain, tk_name, tk_version = (
        get_one_of(long_name, short_name)
        for long_name, short_name in (
            ("FLYTE_INTERNAL_TASK_PROJECT", "_F_TK_PRJ"),
            ("FLYTE_INTERNAL_TASK_DOMAIN", "_F_TK_DM"),
            ("FLYTE_INTERNAL_TASK_NAME", "_F_TK_NM"),
            ("FLYTE_INTERNAL_TASK_VERSION", "_F_TK_V"),
        )
    )

    serialized_settings_blob = os.environ.get(SERIALIZED_CONTEXT_ENV_VAR, "")

    current_ctx = FlyteContextManager.current_context()
    # Scratch directory exposed to user code as ExecutionParameters.tmp_dir.
    workspace_dir = current_ctx.file_access.get_random_local_directory()
    logger.info(f"Using user directory {workspace_dir}")
    pathlib.Path(workspace_dir).mkdir(parents=True, exist_ok=True)
    from flytekit import __version__ as _api_version

    checkpointer = None
    if checkpoint_path is not None:
        checkpointer = SyncCheckpoint(checkpoint_dest=checkpoint_path, checkpoint_src=prev_checkpoint)
        logger.debug(f"Checkpointer created with source {prev_checkpoint} and dest {checkpoint_path}")

    execution_id = _identifier.WorkflowExecutionIdentifier(
        project=exe_project,
        domain=exe_domain,
        name=exe_name,
    )
    # NOTE(review): utcnow() returns a naive datetime and is deprecated since Python 3.12.
    execution_date = _datetime.datetime.utcnow()
    # Stats metric path will be:
    # registration_project.registration_domain.app.module.task_name.user_stats
    # and it will be tagged with execution-level values for project/domain/wf/lp
    user_stats = _get_stats(
        cfg=StatsConfig.auto(),
        prefix=f"{tk_project}.{tk_domain}.{tk_name}.user_stats",
        tags={
            "exec_project": exe_project,
            "exec_domain": exe_domain,
            "exec_workflow": exe_wf,
            "exec_launchplan": exe_lp,
            "api_version": _api_version,
        },
    )
    task_id = _identifier.Identifier(_identifier.ResourceType.TASK, tk_project, tk_domain, tk_name, tk_version)
    execution_parameters = ExecutionParameters(
        execution_id=execution_id,
        execution_date=execution_date,
        stats=user_stats,
        logging=user_space_logger,
        tmp_dir=workspace_dir,
        raw_output_prefix=raw_output_data_prefix,
        checkpoint=checkpointer,
        task_id=task_id,
    )

    try:
        file_access = FileAccessProvider(
            local_sandbox_dir=tempfile.mkdtemp(prefix="flyte"),
            raw_output_prefix=raw_output_data_prefix,
        )
    except TypeError:  # would be thrown from DataPersistencePlugins.find_plugin
        logger.error(f"No data plugin found for raw output prefix {raw_output_data_prefix}")
        raise

    execution_state = current_ctx.new_execution_state().with_params(
        mode=ExecutionState.Mode.TASK_EXECUTION,
        user_space_params=execution_parameters,
    )
    builder = current_ctx.new_builder().with_file_access(file_access).with_execution_state(execution_state)

    if serialized_settings_blob:
        settings_builder = SerializationSettings.from_transport(serialized_settings_blob).new_builder()
        settings_builder.project = exe_project
        settings_builder.domain = exe_domain
        settings_builder.version = tk_version
        if dynamic_addl_distro:
            # A dynamic task run from here should be fast-serialized with the supplied distribution.
            settings_builder.fast_serialization_settings = FastSerializationSettings(
                enabled=True,
                destination_dir=dynamic_dest_dir,
                distribution_location=dynamic_addl_distro,
            )
        builder = builder.with_serialization_settings(settings_builder.build())

    with FlyteContextManager.with_context(builder) as task_ctx:
        yield task_ctx