コード例 #1
0
def worker():
    """Launch the user command configured through the job's hyperparameters.

    Exactly one of two hyperparameters selects what is run:

    * ``SSM_CMD_LINE`` -- passed directly to ``subprocess.run`` (no shell).
    * ``SSM_SHELL_CMD_LINE`` -- run with ``shell=True`` via ``/bin/bash``.

    Before launching, this launcher file and the ``./worker_toolkit``
    directory (both injected into the job) are deleted.

    Returns:
        The exit code of the launched command.

    Raises:
        ValueError: if neither ``SSM_CMD_LINE`` nor ``SSM_SHELL_CMD_LINE`` is
            present in the hyperparameters.
    """
    # INFO level so the launch / exit-code messages below are actually emitted;
    # the root logger otherwise defaults to WARNING and drops logger.info calls.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # Parse the arguments + initialize state
    worker_config = worker_lib.WorkerConfig()
    hps = worker_config.hps

    # Resolve the command up front: a misconfigured job now fails with a clear
    # error (instead of an UnboundLocalError on the return code) and before
    # anything is deleted.
    if "SSM_CMD_LINE" in hps:
        cmd_line = hps["SSM_CMD_LINE"]
        launch_msg = f"Launching: {cmd_line}"
        run_kwargs = {}
    elif "SSM_SHELL_CMD_LINE" in hps:
        cmd_line = hps["SSM_SHELL_CMD_LINE"]
        launch_msg = f"Launching a shell: {cmd_line}"
        run_kwargs = {"shell": True, "executable": "/bin/bash"}
    else:
        raise ValueError(
            "No command to run: set either the SSM_CMD_LINE or the "
            "SSM_SHELL_CMD_LINE hyperparameter"
        )

    # Delete the current file + toolkit as both got injected
    os.remove(__file__)
    shutil.rmtree("./worker_toolkit")

    # Run the shell / cmd line command
    logger.info(launch_msg)
    shell_cmd = subprocess.run(cmd_line, **run_kwargs)

    logger.info(f"finished with {shell_cmd.returncode} return code!")

    # wait_for_state_sync(worker_config)
    return shell_cmd.returncode
コード例 #2
0
ファイル: cifar10.py プロジェクト: shiftan/simple_sagemaker
def main():
    """Entry point for the CIFAR-10 example.

    Parses the command line and, when running inside a SageMaker job
    (``SAGEMAKER_JOB_NAME`` is set in the environment), overrides the parsed
    paths and distribution settings with values from
    ``worker_lib.WorkerConfig``. If ``args.download_only`` is set, only the
    dataset is downloaded; otherwise single-node or distributed training runs.
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # 1. The worker configuration is only available inside a SageMaker job.
    if "SAGEMAKER_JOB_NAME" not in os.environ:
        worker_config = None
    else:
        from worker_toolkit import worker_lib

        worker_config = worker_lib.WorkerConfig(per_instance_state=False)

    args = parseArgs()

    # 2a. Download-only mode: fetch the data (into the state directory when a
    # worker configuration exists) and stop.
    if args.download_only:
        if worker_config:
            args.data_path = worker_config.state
        download_data(args.data_path)
        return

    # 2b. Override paths and distribution parameters from the configuration.
    if worker_config:
        worker_config.initMultiWorkersState()
        args.state_path = worker_config.instance_state
        args.data_path = worker_config.channel_cifar_data
        args.model_path = worker_config.model_dir
        args.num_nodes = worker_config.num_nodes
        args.host_rank = worker_config.host_rank

    for needed_dir in (args.data_path, args.state_path):
        os.makedirs(needed_dir, exist_ok=True)

    if not args.distributed:
        logger.info("*** Single node training")
    else:
        logger.info("*** Distributed training")
        if not worker_config:
            # 3. Outside SageMaker the rendezvous address must be provided by
            #   hand; inside, the PyTorch framework sets these automatically,
            #   see https://github.com/aws/sagemaker-pytorch-training-toolkit/blob/
            #           master/src/sagemaker_pytorch_container/training.py
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "7777"
        os.environ["WORLD_SIZE"] = str(args.num_nodes)
        dist.init_process_group(backend=args.backend, rank=args.host_rank)

    train(
        args.data_path,
        args.state_path,
        args.model_path,
        num_workers=args.num_workers,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
        epochs=args.epochs,
    )
コード例 #3
0
def worker():
    """Show the job inputs, run the task selected by ``task_type``, show outputs.

    A ``task_type`` hyperparameter of 1 runs ``worker1`` and 2 runs
    ``worker2``; any other value runs neither task.
    """
    # INFO level so logger.info output (e.g. "finished!") is emitted; the root
    # logger otherwise defaults to WARNING and drops it.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    # parse the arguments
    worker_config = worker_lib.WorkerConfig()
    show_inputs(worker_config)

    # Hyperparameter values are converted with int() before comparing.
    task_type = int(worker_config.hps["task_type"])
    if task_type == 1:
        worker1(worker_config)
    elif task_type == 2:
        worker2(worker_config)

    show_output(worker_config)

    logger.info("finished!")
コード例 #4
0
def main():
    """Worker entry point: log the job's inputs and dispatch on ``task_type``.

    Logs the hyperparameters plus recursive listings of the data channel and
    the state directory, then calls ``task1`` or ``task2`` for a ``task_type``
    of 1 or 2 respectively (any other value runs nothing).
    """
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    logger.info("Starting worker...")

    # Build the worker configuration (hyperparameters, channel and state paths).
    config = worker_lib.WorkerConfig()
    logger.info(f"Hyperparams: {config.hps}")

    data_listing = list(Path(config.channel_data).rglob("*"))
    logger.info(f"Input data files: {data_listing}")
    state_listing = list(Path(config.state).rglob("*"))
    logger.info(f"State files: {state_listing}")

    selected = int(config.hps["task_type"])
    if selected == 1:
        task1(config)
    elif selected == 2:
        task2(config)

    logger.info("finished!")
コード例 #5
0
ファイル: example.py プロジェクト: shiftan/simple_sagemaker
def worker():
    """Run the example task selected by the ``task`` hyperparameter.

    * Task 1 writes a per-host ``state_<host>`` file into the instance state
      directory, copies every file found under ``channel_data`` into the model
      directory (name suffixed with ``_proc_by_<host>``, content suffixed with
      `` processed by <host>``), and writes an ``output_<host>`` marker file
      into the model directory.
    * Task 2 only logs the files of the ``task2_data`` and
      ``task2_data_dist`` channels.
    """
    from worker_toolkit import worker_lib

    logger.info("Starting worker...")
    # parse the arguments
    worker_config = worker_lib.WorkerConfig()

    logger.info(f"Hyperparams: {worker_config.hps}")
    logger.info(
        f"Input data files: {list(Path(worker_config.channel_data).rglob('*'))}"
    )
    logger.info(f"State files: { list(Path(worker_config.state).rglob('*'))}")

    task = int(worker_config.hps["task"])
    if task == 1:
        host = worker_config.current_host
        data_dir = Path(worker_config.channel_data)
        model_dir = Path(worker_config.model_dir)

        # update the state per running instance (write_text closes the file
        # deterministically instead of leaving the handle to the GC)
        Path(worker_config.instance_state, f"state_{host}").write_text("state")

        # write to the model output directory
        for file in data_dir.rglob("*"):
            if not file.is_file():
                continue
            relp = file.relative_to(data_dir)
            out_path = model_dir / f"{relp}_proc_by_{host}"
            # rglob is recursive, so relp may include sub-directories that do
            # not exist yet under the model directory.
            out_path.parent.mkdir(parents=True, exist_ok=True)
            out_path.write_text(file.read_text() + " processed by " + host)

        (model_dir / f"output_{host}").write_text("output")
    elif task == 2:
        logger.info(
            f"Input task2_data: {list(Path(worker_config.channel_task2_data).rglob('*'))}"
        )
        logger.info(
            f"Input task2_data_dist: {list(Path(worker_config.channel_task2_data_dist).rglob('*'))}"
        )

    logger.info("finished!")
コード例 #6
0
    listDir("/opt/ml")
    listDir(worker_config.state)


def logAfter(worker_config):
    """Log the final directory structure: first /opt/ml, then the state dir."""
    ml_root = "/opt/ml"
    listDir(ml_root)
    listDir(worker_config.state)


if __name__ == "__main__":
    # Script entry point: log the directory layout, then populate the state
    # directory. NOTE(review): this snippet appears truncated after
    # "put something in the model" — the remaining code is not visible here.
    # NOTE(review): basicConfig is called without a level, so unless logging is
    # configured elsewhere the root logger stays at WARNING and the
    # logger.info calls below are dropped — consider level=logging.INFO.
    logging.basicConfig(stream=sys.stdout)
    logger.info("Starting algo...")

    # parse the arguments
    # (the resulting config supplies hps, output_data_dir, current_host,
    # instance_state and state used below)
    worker_config = worker_lib.WorkerConfig()
    logBefore(worker_config)

    # Per-host sub-directory of the output data dir; not used in the visible
    # part of this snippet.
    output_data_dir = os.path.join(worker_config.output_data_dir,
                                   worker_config.current_host)

    # create some data in the state dir
    # NOTE(review): compares hps["stage"] to the int 1 without conversion; if
    # the hyperparameter arrives as a string this branch never runs — confirm
    # the type (other examples wrap hps values in int(...)).
    if worker_config.hps["stage"] == 1:
        # put some files in the state directory
        # (writes state_<host>_1 ... state_<host>_10)
        for i in range(10):
            # NOTE(review): the file object is never closed explicitly;
            # a `with` block or Path.write_text would be safer.
            open(
                f"{worker_config.instance_state}/state_{worker_config.current_host}_{i+1}",
                "wt",
            ).write("state")

        # put something in the model
コード例 #7
0
import logging
import subprocess
import sys

from worker_toolkit import worker_lib

logger = logging.getLogger(__name__)


def listDir(path):
    """Log a recursive long-format listing (``ls -la -R``) of *path*.

    The listing is framed by START/END marker lines. Only the command's stdout
    is captured and logged; anything ``ls`` writes to stderr (for example for
    a missing path) is not captured by this function.
    """
    logger.info(f"*** START listing files in {path}")
    # "--" ends option parsing, so a path that starts with "-" is listed
    # instead of being interpreted as an ls flag.
    listing = subprocess.run(["ls", "-la", "-R", "--", path],
                             stdout=subprocess.PIPE,
                             universal_newlines=True).stdout
    logger.info(listing)
    logger.info(f"*** END file listing {path}")


if __name__ == "__main__":
    # level=logging.INFO is required: listDir reports through logger.info, and
    # with the root logger left at its default WARNING level every listing
    # would be silently dropped.
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    # NOTE(review): the positional False is presumably per_instance_state —
    # confirm against worker_lib.WorkerConfig.
    worker_config = worker_lib.WorkerConfig(False)
    # List the channel_data and channel_bucket input directories.
    listDir(worker_config.channel_data)
    listDir(worker_config.channel_bucket)