Example 1
def test_modification_triggers_upload(process_mock, upload_file, inotify_mock,
                                      copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value

    inotify.add_watch.return_value = 1
    inotify.read.return_value = [
        Event(1, flags.CLOSE_WRITE, "cookie", "file_name")
    ]

    def watch():
        # Run the _watch target synchronously with the positional args that
        # start_sync passed to the (mocked) Process constructor.
        _, kwargs = process_mock.call_args
        intermediate_output._watch(kwargs["args"][0], kwargs["args"][1],
                                   kwargs["args"][2], kwargs["args"][3])

    process.start.side_effect = watch

    files.write_success_file()
    intermediate_output.start_sync(S3_BUCKET, REGION)

    inotify.add_watch.assert_called()
    inotify.read.assert_called()
    copy2.assert_called()
    upload_file.assert_called()
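
This test (and the two similar ones that follow) receives `process_mock`, `upload_file`, `inotify_mock`, and `copy2` as pytest fixtures, and assumes imports such as `from inotify_simple import Event, flags` plus `S3_BUCKET`/`REGION` constants that are not shown in the excerpt. Below is a minimal sketch of fixtures with those names; the patch targets are illustrative guesses and may not match the toolkit's real conftest.py.

# Sketch of fixtures with the names used above; patch targets are guesses.
import pytest
from unittest.mock import patch


@pytest.fixture
def process_mock():
    # assumed location of the Process class used by start_sync
    with patch(
            "sagemaker_training.intermediate_output.multiprocessing.Process"
    ) as mock:
        yield mock


@pytest.fixture
def inotify_mock():
    # assumed location of the INotify class used by the watcher
    with patch(
            "sagemaker_training.intermediate_output.inotify_simple.INotify"
    ) as mock:
        yield mock


@pytest.fixture
def upload_file():
    # assumed upload path; the watcher might equally use a boto3 client
    with patch("boto3.s3.transfer.S3Transfer.upload_file") as mock:
        yield mock


@pytest.fixture
def copy2():
    with patch("shutil.copy2") as mock:
        yield mock
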
Example 2
def test_non_write_ignored(process_mock, upload_file, inotify_mock, copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value

    inotify.add_watch.return_value = 1
    # Build an event mask containing every flag except CLOSE_WRITE and ISDIR,
    # so none of the reported events should trigger a copy or an upload.
    mask = flags.CREATE
    for flag in flags:
        if flag is not flags.CLOSE_WRITE and flag is not flags.ISDIR:
            mask = mask | flag
    inotify.read.return_value = [Event(1, mask, "cookie", "file_name")]

    def watch():
        # Run the _watch target synchronously with the positional args that
        # start_sync passed to the (mocked) Process constructor.
        _, kwargs = process_mock.call_args
        intermediate_output._watch(kwargs["args"][0], kwargs["args"][1],
                                   kwargs["args"][2], kwargs["args"][3])

    process.start.side_effect = watch

    files.write_success_file()
    intermediate_output.start_sync(S3_BUCKET, REGION)

    inotify.add_watch.assert_called()
    inotify.read.assert_called()
    copy2.assert_not_called()
    upload_file.assert_not_called()
Example 3
def test_new_folders_are_watched(process_mock, upload_file, inotify_mock,
                                 copy2):
    process = process_mock.return_value
    inotify = inotify_mock.return_value

    new_dir = "new_dir"
    new_dir_path = os.path.join(environment.output_intermediate_dir, new_dir)
    inotify.add_watch.return_value = 1
    inotify.read.return_value = [
        Event(1, flags.CREATE | flags.ISDIR, "cookie", new_dir)
    ]

    def watch():
        # Create the new directory on disk, then run the _watch target
        # synchronously with the args that start_sync passed to the (mocked)
        # Process constructor.
        os.makedirs(new_dir_path)

        _, kwargs = process_mock.call_args
        intermediate_output._watch(kwargs["args"][0], kwargs["args"][1],
                                   kwargs["args"][2], kwargs["args"][3])

    process.start.side_effect = watch

    files.write_success_file()
    intermediate_output.start_sync(S3_BUCKET, REGION)

    watch_flags = flags.CLOSE_WRITE | flags.CREATE
    inotify.add_watch.assert_any_call(environment.output_intermediate_dir,
                                      watch_flags)
    inotify.add_watch.assert_any_call(new_dir_path, watch_flags)
    inotify.read.assert_called()
    copy2.assert_not_called()
    upload_file.assert_not_called()
Example 4
def test_nested_delayed_file():
    os.environ["TRAINING_JOB_NAME"] = _timestamp()
    p = intermediate_output.start_sync(bucket_uri, region)

    dir1 = os.path.join(intermediate_path, "dir1")
    os.makedirs(dir1)

    time.sleep(3)

    dir2 = os.path.join(dir1, "dir2")
    os.makedirs(dir2)

    time.sleep(3)

    file1 = os.path.join(dir2, "file1.txt")
    write_file(file1, "file1")

    dir3 = os.path.join(intermediate_path, "dir3")
    os.makedirs(dir3)

    time.sleep(3)

    file2 = os.path.join(dir3, "file2.txt")
    write_file(file2, "file2")

    files.write_success_file()
    p.join()

    # assert that all files that should be under intermediate are still there
    assert os.path.exists(file1)
    assert os.path.exists(file2)

    # assert the files exist in S3
    key_prefix = os.path.join(os.environ.get("TRAINING_JOB_NAME"), "output",
                              "intermediate")
    client = boto3.client("s3", region)
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file1, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file2, intermediate_path)))
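
The integration-style tests (this one and those that follow) rely on module-level setup and small helpers that are not part of the excerpt: `region`, `bucket`, `bucket_uri`, `intermediate_path`, `_timestamp`, `write_file`, and `_file_exists_in_s3`. The sketch below shows one plausible shape for them; the bucket name, the import path, and the helper bodies are assumptions for illustration only.

# Illustrative setup and helpers assumed by the integration tests; the bucket
# name and the exact bodies are guesses.
import datetime

import boto3
import botocore.exceptions

from sagemaker_training import environment

region = "us-west-2"                       # assumed test region
bucket = "my-intermediate-output-tests"    # assumed, pre-existing bucket
bucket_uri = "s3://" + bucket
intermediate_path = environment.output_intermediate_dir


def _timestamp():
    # unique training job name per test run
    return "test-job-" + datetime.datetime.utcnow().strftime("%Y%m%d%H%M%S%f")


def write_file(path, contents):
    with open(path, "w") as f:
        f.write(contents)


def _file_exists_in_s3(client, key):
    try:
        client.head_object(Bucket=bucket, Key=key)
        return True
    except botocore.exceptions.ClientError:
        return False
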
Example 5
def test_large_files():
    os.environ["TRAINING_JOB_NAME"] = _timestamp()
    p = intermediate_output.start_sync(bucket_uri, region)

    file_size = 1024 * 256 * 17  # 17MB

    file = os.path.join(intermediate_path, "file.npy")
    _generate_large_npy_file(file_size, file)

    file_to_modify = os.path.join(intermediate_path, "file_to_modify.npy")
    _generate_large_npy_file(file_size, file_to_modify)
    # write the file a second time so the watcher picks up the modified
    # content; keep the new content to compare against what ends up in S3
    content_to_assert = _generate_large_npy_file(file_size, file_to_modify)

    files.write_failure_file("Failure!!")
    p.join()

    assert os.path.exists(file)
    assert os.path.exists(file_to_modify)

    key_prefix = os.path.join(os.environ.get("TRAINING_JOB_NAME"), "output",
                              "intermediate")
    client = boto3.client("s3", region)
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix,
                     os.path.relpath(file_to_modify, intermediate_path)))

    # check that the modified file has the expected content
    s3 = boto3.resource("s3", region_name=region)
    key = os.path.join(key_prefix,
                       os.path.relpath(file_to_modify, intermediate_path))
    modified_file = os.path.join(environment.output_dir, "modified_file.npy")
    s3.Bucket(bucket).download_file(key, modified_file)
    assert np.array_equal(np.load(modified_file), content_to_assert)
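
`_generate_large_npy_file` is not shown in the excerpt. A plausible sketch follows: with a 4-byte dtype, `1024 * 256 * 17` elements come to roughly the 17MB mentioned in the comment, and returning the array lets the test compare it with what is later downloaded from S3. The dtype and value range are assumptions.

import numpy as np


def _generate_large_npy_file(size, path):
    # hypothetical helper: write `size` random 4-byte integers to `path`
    # and return the array so callers can compare it later
    data = np.random.randint(0, 100, size=size, dtype=np.int32)
    np.save(path, data)
    return data
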
Example 6
def train():
    """The main function responsible for running training in the container."""
    intermediate_sync = None
    exit_code = SUCCESS_CODE
    try:
        env = environment.Environment()

        region = os.environ.get("AWS_REGION",
                                os.environ.get(params.REGION_NAME_ENV))
        s3_endpoint_url = os.environ.get(params.S3_ENDPOINT_URL, None)
        intermediate_sync = intermediate_output.start_sync(
            env.sagemaker_s3_output(), region, endpoint_url=s3_endpoint_url)

        if env.framework_module:
            framework_name, entry_point_name = env.framework_module.split(":")

            framework = importlib.import_module(framework_name)

            # the logger is configured after importing the framework library, allowing
            # the framework to configure logging at import time.
            logging_config.configure_logger(env.log_level)
            logger.info("Imported framework %s", framework_name)
            entrypoint = getattr(framework, entry_point_name)
            entrypoint()
        else:
            logging_config.configure_logger(env.log_level)

            mpi_enabled = env.additional_framework_parameters.get(
                params.MPI_ENABLED)
            runner_type = (runner.RunnerType.MPI if mpi_enabled and
                           (env.current_instance_group
                            in env.distribution_instance_groups) else
                           runner.RunnerType.Process)

            entry_point.run(
                env.module_dir,
                env.user_entry_point,
                env.to_cmd_args(),
                env.to_env_vars(),
                runner_type=runner_type,
            )
        logger.info("Reporting training SUCCESS")

        files.write_success_file()
    except errors.ClientError as e:
        failure_msg = str(e)
        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")

        logger.error(failure_msg)

        if intermediate_sync:
            intermediate_sync.join()

        exit_code = DEFAULT_FAILURE_CODE
    except Exception as e:  # pylint: disable=broad-except
        failure_msg = "Framework Error: \n%s\n%s" % (traceback.format_exc(),
                                                     str(e))

        files.write_failure_file(failure_msg)
        logger.error("Reporting training FAILURE")

        logger.error(failure_msg)

        error_number = getattr(e, "errno", DEFAULT_FAILURE_CODE)
        exit_code = _get_valid_failure_exit_code(error_number)
    finally:
        if intermediate_sync:
            intermediate_sync.join()
        _exit_processes(exit_code)
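
`train()` calls two small helpers that are not included in this excerpt, `_get_valid_failure_exit_code` and `_exit_processes`. The sketch below shows one plausible shape for them, consistent with how they are used above; the bodies, the default codes, and the use of `sys.exit` are assumptions.

# Plausible sketches of the helpers train() relies on; bodies are guesses.
import sys

SUCCESS_CODE = 0
DEFAULT_FAILURE_CODE = 1   # assumed default


def _get_valid_failure_exit_code(exit_code):
    # fall back to the default failure code when the value (e.g. e.errno)
    # cannot be used as a process exit status
    try:
        return int(exit_code)
    except (ValueError, TypeError):
        return DEFAULT_FAILURE_CODE


def _exit_processes(exit_code):
    # flush output streams and exit with the resolved code
    sys.stdout.flush()
    sys.stderr.flush()
    sys.exit(exit_code)
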
Example 7
def test_intermediate_upload():
    os.environ["TRAINING_JOB_NAME"] = _timestamp()
    p = intermediate_output.start_sync(bucket_uri, region)

    file1 = os.path.join(intermediate_path, "file1.txt")
    write_file(file1, "file1!")

    os.makedirs(os.path.join(intermediate_path, "dir1", "dir2", "dir3"))

    dir1 = os.path.join(intermediate_path, "dir1")
    dir2 = os.path.join(dir1, "dir2")
    dir3 = os.path.join(dir2, "dir3")
    file2 = os.path.join(dir1, "file2.txt")
    file3 = os.path.join(dir2, "file3.txt")
    file4 = os.path.join(dir3, "file4.txt")
    write_file(file2, "dir1_file2!")
    write_file(file3, "dir2_file3!")
    write_file(file4, "dir1_file4!")

    dir_to_delete1 = os.path.join(dir1, "dir4")
    file_to_delete1 = os.path.join(dir_to_delete1, "file_to_delete1.txt")
    os.makedirs(dir_to_delete1)
    write_file(file_to_delete1, "file_to_delete1!")
    os.remove(file_to_delete1)
    os.removedirs(dir_to_delete1)

    file_to_delete2_but_copy = os.path.join(intermediate_path,
                                            "file_to_delete2_but_copy.txt")
    write_file(file_to_delete2_but_copy, "file_to_delete2!")
    time.sleep(1)
    os.remove(file_to_delete2_but_copy)

    file_to_modify1 = os.path.join(dir3, "file_to_modify1.txt")
    write_file(file_to_modify1, "dir3_file_to_modify1_1!")
    write_file(file_to_modify1, "dir3_file_to_modify1_2!")
    write_file(file_to_modify1, "dir3_file_to_modify1_3!")
    content_to_assert = "dir3_file_to_modify1_4!"
    write_file(file_to_modify1, content_to_assert)

    # the last file written before signalling success; it should still be uploaded
    file5 = os.path.join(intermediate_path, "file5.txt")
    write_file(file5, "file5!")

    files.write_success_file()

    p.join()

    # written after the sync process has finished; it should not be uploaded
    file6 = os.path.join(intermediate_path, "file6.txt")
    write_file(file6, "file6!")

    # assert that all files that should be under intermediate are still there
    assert os.path.exists(file1)
    assert os.path.exists(file2)
    assert os.path.exists(file3)
    assert os.path.exists(file4)
    assert os.path.exists(file5)
    assert os.path.exists(file6)
    assert os.path.exists(file_to_modify1)
    # and all the deleted folders and files aren't there
    assert not os.path.exists(dir_to_delete1)
    assert not os.path.exists(file_to_delete1)
    assert not os.path.exists(file_to_delete2_but_copy)

    # assert files exist in S3
    key_prefix = os.path.join(os.environ.get("TRAINING_JOB_NAME"), "output",
                              "intermediate")
    client = boto3.client("s3", region)
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file1, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file2, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file3, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file4, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file5, intermediate_path)))
    assert _file_exists_in_s3(
        client,
        os.path.join(key_prefix,
                     os.path.relpath(file_to_modify1, intermediate_path)))
    deleted_file = os.path.join(
        key_prefix, os.path.relpath(file_to_delete2_but_copy,
                                    intermediate_path))
    assert _file_exists_in_s3(client, deleted_file)
    assert not _file_exists_in_s3(
        client,
        os.path.join(key_prefix,
                     os.path.relpath(dir_to_delete1, intermediate_path)))
    assert not _file_exists_in_s3(
        client,
        os.path.join(key_prefix,
                     os.path.relpath(file_to_delete1, intermediate_path)))
    assert not _file_exists_in_s3(
        client,
        os.path.join(key_prefix, os.path.relpath(file6, intermediate_path)))

    # check that the modified file has the expected content
    s3 = boto3.resource("s3", region_name=region)
    key = os.path.join(key_prefix,
                       os.path.relpath(file_to_modify1, intermediate_path))
    modified_file = os.path.join(environment.output_dir, "modified_file.txt")
    s3.Bucket(bucket).download_file(key, modified_file)
    with open(modified_file) as f:
        content = f.read()
        assert content == content_to_assert
Example 8
def test_daemon_process():
    intermediate_sync = intermediate_output.start_sync(S3_BUCKET, REGION)
    assert intermediate_sync.daemon is True
Example 9
def test_wrong_output():
    with pytest.raises(ValueError) as e:
        intermediate_output.start_sync("tcp://my/favorite/url", REGION)
    assert "Expecting 's3' scheme" in str(e)
Example 10
def test_accept_file_output_no_process():
    intermediate_sync = intermediate_output.start_sync(
        "file://my/favorite/file", REGION)
    assert intermediate_sync is None
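
Taken together, Examples 8-10 pin down the contract of `start_sync`: an `s3://` output location yields a daemon process that mirrors the intermediate directory to S3, a `file://` location returns `None`, and any other scheme raises a `ValueError` mentioning the expected 's3' scheme. The sketch below illustrates that dispatch only; the `_watch` arguments and the overall structure are assumptions, not the toolkit's actual implementation.

# Rough sketch of the scheme dispatch asserted by Examples 8-10; not the
# real implementation.
import multiprocessing
from urllib.parse import urlparse


def _watch(bucket, prefix, region, endpoint_url):
    # stand-in for the module's real directory watcher (see Examples 1-3)
    pass


def start_sync(s3_output_location, region, endpoint_url=None):
    url = urlparse(s3_output_location)

    if url.scheme == "file":
        # local output needs no background upload
        return None
    if url.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: %s" % s3_output_location)

    # watch the intermediate directory from a daemon process
    process = multiprocessing.Process(
        target=_watch, args=(url.netloc, url.path, region, endpoint_url))
    process.daemon = True
    process.start()
    return process
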