コード例 #1
0
def test_smmodelparallel_mnist_multigpu_multinode(ecr_image, instance_type,
                                                  py_version,
                                                  sagemaker_session, tmpdir):
    """
    Run the SageMaker model-parallel PyTorch MNIST script on two nodes.

    The ``instance_type`` fixture value is ignored; the job always runs on
    ``ml.p3.16xlarge``. ``py_version`` and ``tmpdir`` are not used.
    Skipped unless the image is PyTorch 1.6/1.7 built for CUDA 11.0 (``cu110``).
    """
    pinned_instance_type = "ml.p3.16xlarge"
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)

    fw_is_supported = Version(fw_version) in SpecifierSet(">=1.6,<1.8")
    if not fw_is_supported or cuda_tag != "cu110":
        pytest.skip(
            "Model Parallelism only supports CUDA 11 on PyTorch 1.6 and PyTorch 1.7"
        )

    with timeout(minutes=DEFAULT_TIMEOUT):
        estimator = PyTorch(
            entry_point='smmodelparallel_pt_mnist_multinode.sh',
            role='SageMakerRole',
            image_uri=ecr_image,
            source_dir=mnist_path,
            instance_count=2,
            instance_type=pinned_instance_type,
            sagemaker_session=sagemaker_session,
        )
        estimator.fit()
コード例 #2
0
def test_ecs_pytorch_training_dgl_gpu(gpu_only, py3_only,
                                      ecs_container_instance, pytorch_training,
                                      training_cmd, ecs_cluster_name):
    """
    GPU DGL test for PyTorch Training

    Instance Type - p3.8xlarge

    DGL is only supported in py3, hence we have used the "py3_only" fixture to ensure py2 images don't run
    on this function.

    Given above parameters, registers a task with family named after this test, runs the task, and waits for
    the task to be stopped before doing teardown operations of instance and cluster.

    Skipped for:
      * PyTorch 1.6 images built for CUDA 11.0 (DGL does not support them), and
      * PyTorch >= 1.10 images built for CUDA 11.3 (known ECS failure, see TODO).
    """
    _, image_framework_version = get_framework_and_version_from_tag(
        pytorch_training)
    image_cuda_version = get_cuda_version_from_tag(pytorch_training)
    if Version(image_framework_version) == Version(
            "1.6") and image_cuda_version == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")
    # TODO: Remove when DGL gpu test on ecs get fixed
    if Version(image_framework_version) >= Version(
            "1.10") and image_cuda_version == "cu113":
        pytest.skip("ecs test for DGL gpu fails since pt 1.10")

    instance_id, cluster_arn = ecs_container_instance

    # Hand the executor the GPU count reported for this container instance.
    num_gpus = ec2_utils.get_instance_num_gpus(instance_id)

    ecs_utils.ecs_training_test_executor(ecs_cluster_name,
                                         cluster_arn,
                                         training_cmd,
                                         pytorch_training,
                                         instance_id,
                                         num_gpus=num_gpus)
コード例 #3
0
def test_hf_smdp_multi(instance_types, ecr_image, py_version, sagemaker_session, tmpdir, framework_version):
    """
    Run the HuggingFace BERT ``train.py`` job on two nodes with SMDataParallel.

    Exercises the smddprun launcher through the Estimator API
    ``distribution`` parameter. Both nodes are ``ml.p3.16xlarge``; the
    ``instance_types`` fixture is not used. Skipped unless the image is
    TensorFlow >= 2.3.1 built for CUDA 11.0.
    """
    _, tf_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    if Version(tf_version) < Version("2.3.1") or cuda_tag != "cu110":
        pytest.skip("Data Parallelism is only supported on CUDA 11, and on TensorFlow 2.3.1 or higher")

    estimator_config = {
        "entry_point": 'train.py',
        "source_dir": BERT_PATH,
        "role": 'SageMakerRole',
        "instance_type": "ml.p3.16xlarge",
        "instance_count": 2,
        "image_uri": ecr_image,
        "framework_version": framework_version,
        "py_version": py_version,
        "sagemaker_session": sagemaker_session,
        # NOTE(review): `hyperparameters` is not defined in this function;
        # presumably a module-level constant.
        "hyperparameters": hyperparameters,
        "distribution": {"smdistributed": {"dataparallel": {"enabled": True}}},
        "debugger_hook_config": False,  # currently needed
    }
    estimator = HuggingFace(**estimator_config)

    estimator.fit(job_name=unique_name_from_base('test-tf-hf-smdp-multi'))
コード例 #4
0
def test_smmodelparallel_smdataparallel_mnist(instance_types, ecr_image,
                                              py_version, sagemaker_session,
                                              tmpdir):
    """
    Run the combined SM Distributed DataParallel + ModelParallel MNIST script on one node.

    Added for the re:Invent SM DataParallelism / ModelParallelism launch.
    TODO: Consider reworking these tests after re:Invent releases are done

    Skipped unless the image is PyTorch 1.6/1.7 built for CUDA 11.0. The
    estimator goes through ``_disable_sm_profiler`` (presumably switches the
    SageMaker profiler off) before ``fit``.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    fw_is_supported = Version(fw_version) in SpecifierSet(">=1.6,<1.8")
    if not fw_is_supported or cuda_tag != "cu110":
        pytest.skip(
            "Model Parallelism only supports CUDA 11 on PyTorch 1.6 and PyTorch 1.7"
        )

    with timeout(minutes=DEFAULT_TIMEOUT):
        base_estimator = PyTorch(
            entry_point='smdataparallel_smmodelparallel_mnist_script_mode.sh',
            role='SageMakerRole',
            image_uri=ecr_image,
            source_dir=mnist_path,
            instance_count=1,
            instance_type=instance_types,
            sagemaker_session=sagemaker_session,
        )
        region_name = sagemaker_session.boto_region_name
        estimator = _disable_sm_profiler(region_name, base_estimator)
        estimator.fit()
コード例 #5
0
def test_smmodelparallel_multinode(sagemaker_session, instance_type, ecr_image,
                                   tmpdir, framework_version, test_script):
    """
    Tests SM Modelparallel in sagemaker on two nodes.

    Runs ``test_script`` from the ``smmodelparallel`` resource directory as a
    TensorFlow job on two ``ml.p3.16xlarge`` instances (the ``instance_type``
    fixture is overridden), with MPI enabled and two processes per host.
    Skipped unless the image is exactly TensorFlow 2.3.1 built for CUDA 11.0.
    """
    instance_type = "ml.p3.16xlarge"
    _, image_framework_version = get_framework_and_version_from_tag(ecr_image)
    image_cuda_version = get_cuda_version_from_tag(ecr_image)
    if Version(image_framework_version) != Version(
            "2.3.1") or image_cuda_version != "cu110":
        pytest.skip(
            "Model Parallelism only supports CUDA 11 on TensorFlow 2.3")
    smmodelparallel_path = os.path.join(RESOURCE_PATH, 'smmodelparallel')
    estimator = TensorFlow(
        entry_point=test_script,
        role='SageMakerRole',
        instance_count=2,
        instance_type=instance_type,
        source_dir=smmodelparallel_path,
        # SageMaker Python SDK v2 names this argument `distribution`;
        # `distributions` is the deprecated v1 spelling.
        distribution={
            "mpi": {
                "enabled": True,
                "processes_per_host": 2,
                "custom_mpi_options":
                    "-verbose --mca orte_base_help_aggregate 0 --mca btl_vader_single_copy_mechanism none ",
            }
        },
        sagemaker_session=sagemaker_session,
        image_uri=ecr_image,
        framework_version=framework_version,
        py_version='py3')
    estimator.fit()
コード例 #6
0
def test_hf_smdp(sagemaker_session, instance_type, ecr_image, tmpdir, framework_version, py_version):
    """
    Tests SMDataParallel single-node command via script mode.

    Runs the HuggingFace BERT ``train.py`` example on one ``ml.p3.16xlarge``
    instance (the ``instance_type`` fixture is overridden) with smdistributed
    data parallelism enabled. Skipped unless the image is TensorFlow >= 2.3.1
    built for CUDA 11.0.

    ``py_version`` is requested as a fixture because the HuggingFace estimator
    below needs it.
    """
    _, image_framework_version = get_framework_and_version_from_tag(ecr_image)
    image_cuda_version = get_cuda_version_from_tag(ecr_image)
    if Version(image_framework_version) < Version("2.3.1") or image_cuda_version != "cu110":
        pytest.skip("Data Parallelism is only supported on CUDA 11, and on TensorFlow 2.3.1 or higher")

    # configuration for running training on smdistributed Data Parallel
    distribution = {'smdistributed': {'dataparallel': {'enabled': True}}}

    instance_type = "ml.p3.16xlarge"
    instance_count = 1

    estimator = HuggingFace(
        entry_point='train.py',
        source_dir=BERT_PATH,
        role='SageMakerRole',
        instance_type=instance_type,
        instance_count=instance_count,
        image_uri=ecr_image,
        framework_version=framework_version,
        py_version=py_version,
        distribution=distribution,
        sagemaker_session=sagemaker_session,
        hyperparameters=hyperparameters,
        debugger_hook_config=False,  # currently needed
    )

    estimator.fit(job_name=unique_name_from_base('test-tf-hf-smdp'))
コード例 #7
0
def test_distributed_training_smdataparallel_script_mode(
        sagemaker_session, instance_type, ecr_image, tmpdir,
        framework_version):
    """
    Run the single-node SMDataParallel MNIST script-mode job on TensorFlow.

    Always uses one ``ml.p3.16xlarge`` instance, whatever the
    ``instance_type`` fixture says. Skipped unless the image is
    TensorFlow >= 2.3.1 built for CUDA 11.0.
    """
    _, tf_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    below_minimum = Version(tf_version) < Version("2.3.1")
    if below_minimum or cuda_tag != "cu110":
        pytest.skip(
            "Data Parallelism is only supported on CUDA 11, and on TensorFlow 2.3.1 or higher"
        )

    pinned_instance_type = "ml.p3.16xlarge"
    estimator = TensorFlow(
        entry_point='smdataparallel_mnist_script_mode.sh',
        source_dir=MNIST_PATH,
        role='SageMakerRole',
        instance_type=pinned_instance_type,
        instance_count=1,
        image_uri=ecr_image,
        framework_version=framework_version,
        py_version='py3',
        sagemaker_session=sagemaker_session,
    )
    estimator.fit(job_name=unique_name_from_base('test-tf-smdataparallel'))
コード例 #8
0
def test_smdataparallel_smmodelparallel_mnist(n_virginia_sagemaker_session,
                                              instance_type,
                                              n_virginia_ecr_image, tmpdir,
                                              framework_version):
    """
    Run the combined SMDataParallel + SMModelParallel MNIST script on one node.

    Added for the re:Invent SM DataParallelism / ModelParallelism launch.
    TODO: Consider reworking these tests after re:Invent releases are done

    Uses the us-east-1 ("n_virginia") session and image fixtures and a single
    ``ml.p3.16xlarge`` instance (the ``instance_type`` fixture is overridden).
    Skipped unless the image is TensorFlow >= 2.3.1 built for CUDA 11.0.
    """
    pinned_instance_type = "ml.p3.16xlarge"
    _, tf_version = get_framework_and_version_from_tag(n_virginia_ecr_image)
    cuda_tag = get_cuda_version_from_tag(n_virginia_ecr_image)
    below_minimum = Version(tf_version) < Version("2.3.1")
    if below_minimum or cuda_tag != "cu110":
        pytest.skip(
            "SMD Model and Data Parallelism are only supported on CUDA 11, and on TensorFlow 2.3.1 or higher"
        )

    resource_dir = os.path.join(RESOURCE_PATH, 'smmodelparallel')
    base_estimator = TensorFlow(
        entry_point="smdataparallel_smmodelparallel_mnist_script_mode.sh",
        role='SageMakerRole',
        instance_count=1,
        instance_type=pinned_instance_type,
        source_dir=resource_dir,
        sagemaker_session=n_virginia_sagemaker_session,
        image_uri=n_virginia_ecr_image,
        framework_version=framework_version,
        py_version='py3',
    )
    region_name = n_virginia_sagemaker_session.boto_region_name
    estimator = _disable_sm_profiler(region_name, base_estimator)
    estimator.fit()
コード例 #9
0
def _test_smdataparallel_smmodelparallel_mnist_function(
        ecr_image, sagemaker_session, instance_type, tmpdir,
        framework_version):
    """
    Shared body for the SMDataParallel + SMModelParallel MNIST single-node test.

    Always runs on one ``ml.p3.16xlarge`` instance (the ``instance_type``
    argument is overridden) and skips unless the image is TensorFlow >= 2.3.1
    built for CUDA 11.0. The estimator goes through ``_disable_sm_profiler``
    before ``fit``.
    """
    pinned_instance_type = "ml.p3.16xlarge"
    _, tf_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    below_minimum = Version(tf_version) < Version("2.3.1")
    if below_minimum or cuda_tag != "cu110":
        pytest.skip(
            "SMD Model and Data Parallelism are only supported on CUDA 11, and on TensorFlow 2.3.1 or higher"
        )

    resource_dir = os.path.join(RESOURCE_PATH, 'smmodelparallel')
    base_estimator = TensorFlow(
        entry_point="smdataparallel_smmodelparallel_mnist_script_mode.sh",
        role='SageMakerRole',
        instance_count=1,
        instance_type=pinned_instance_type,
        source_dir=resource_dir,
        sagemaker_session=sagemaker_session,
        image_uri=ecr_image,
        framework_version=framework_version,
        py_version='py3',
    )
    region_name = sagemaker_session.boto_region_name
    estimator = _disable_sm_profiler(region_name, base_estimator)
    estimator.fit()
コード例 #10
0
def test_smclarify_metrics_gpu(training, ec2_connection, region,
                               ec2_instance_type, gpu_only, py3_only):
    """Run the SmClarify bias-metrics check on the image; only CUDA 11.0 (``cu110``) images qualify."""
    if get_cuda_version_from_tag(training) == "cu110":
        run_smclarify_bias_metrics(training, ec2_connection, region,
                                   ec2_instance_type)
    else:
        pytest.skip("SmClarify is currently installed in cuda 11 gpu images")
コード例 #11
0
def _test_hf_smdp_function(ecr_image, sagemaker_session, instance_type,
                           framework_version, py_version, tmpdir,
                           instance_count):
    """
    Launch a HuggingFace BERT ``train.py`` job with SMDataParallel enabled.

    :param ecr_image: image URI to train with
    :param sagemaker_session: session used to launch the job
    :param instance_type: ignored; the job always runs on ``ml.p3.16xlarge``
    :param framework_version: framework version passed to the estimator
    :param py_version: python version passed to the estimator
    :param tmpdir: unused
    :param instance_count: number of training instances

    No framework/CUDA compatibility check is done here.
    NOTE(review): presumably callers skip unsupported images before calling.
    """
    instance_type = "ml.p3.16xlarge"
    distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}

    estimator = HuggingFace(
        entry_point='train.py',
        source_dir=BERT_PATH,
        role='SageMakerRole',
        instance_type=instance_type,
        instance_count=instance_count,
        image_uri=ecr_image,
        framework_version=framework_version,
        py_version=py_version,
        sagemaker_session=sagemaker_session,
        hyperparameters=hyperparameters,
        distribution=distribution,
        debugger_hook_config=False,  # currently needed
    )

    estimator.fit(job_name=unique_name_from_base("test-tf-hf-smdp-multi"))
コード例 #12
0
def test_dlc_standard_labels(image, region):
    """
    Verify the standard ``com.amazonaws.ml.engines.*`` labels on a DLC image.

    The expected label set is built from the image URI/tag: framework and
    version, device type (``gpu.<cuda version>`` for GPU images), python
    version, job type, architecture, OS version, plus contributor and
    transformers version when they are present. Dots in versions and
    underscores in the framework name become dashes.

    SageMaker images must carry every expected label. EC2 images must carry
    none of them yet (see TODO below).

    :param image: ECR image URI under test
    :param region: AWS region used to read the image labels
    """
    customer_type_label_prefix = "ec2" if test_utils.is_ec2_image(
        image) else "sagemaker"
    label_prefix = f"com.amazonaws.ml.engines.{customer_type_label_prefix}.dlc"

    framework, fw_version = test_utils.get_framework_and_version_from_tag(
        image)
    framework = framework.replace('_', '-')
    fw_version = fw_version.replace('.', '-')
    device_type = test_utils.get_processor_from_image_uri(image)
    if device_type == "gpu":
        cuda_version = test_utils.get_cuda_version_from_tag(image)
        device_type = f"{device_type}.{cuda_version}"
    python_version = test_utils.get_python_version_from_image_uri(image)
    job_type = test_utils.get_job_type_from_image(image)
    # The helper may return a falsy value (the `if transformers_version`
    # check below relies on that); coerce None to "" before `.replace`.
    transformers_version = (
        test_utils.get_transformers_version_from_image_uri(image) or ""
    ).replace('.', '-')
    os_version = test_utils.get_os_version_from_image_uri(image).replace(
        '.', '-')

    # TODO: Add x86 env variable to check explicitly for x86, instead of assuming that everything not graviton is x86
    arch_type = "graviton" if test_utils.is_graviton_architecture() else "x86"

    contributor = test_utils.get_contributor_from_image_uri(image)

    expected_labels = [
        f"{label_prefix}.framework.{framework}.{fw_version}",
        f"{label_prefix}.device.{device_type}",
        f"{label_prefix}.python.{python_version}",
        f"{label_prefix}.job.{job_type}",
        f"{label_prefix}.arch.{arch_type}",
        f"{label_prefix}.os.{os_version}",
    ]
    if contributor:
        expected_labels.append(f"{label_prefix}.contributor.{contributor}")
    if transformers_version:
        expected_labels.append(
            f"{label_prefix}.lib.transformers.{transformers_version}")

    actual_labels = test_utils.get_labels_from_ecr_image(image, region)

    missing_labels = [
        label for label in expected_labels if label not in actual_labels
    ]

    # TODO: Remove this when ec2 labels are added. For now, ensure they are not added.
    if customer_type_label_prefix == "ec2":
        assert set(missing_labels) == set(expected_labels), \
            f"EC2 labels are not supported yet, and should not be added to containers. " \
            f"{set(expected_labels) - set(missing_labels)} should not be present."
    else:
        assert not missing_labels, \
            f"Labels {missing_labels} are expected in image {image}, but cannot be found. " \
            f"All labels on image: {actual_labels}"
コード例 #13
0
def test_pytorch_train_dgl_gpu(pytorch_training, ec2_connection, gpu_only,
                               py3_only):
    """
    Run the PyTorch DGL training command (``PT_DGL_CMD``) on an EC2 GPU host.

    Skipped for PyTorch 1.6 images built for CUDA 11.0, which DGL does not
    support.
    """
    _, image_framework_version = get_framework_and_version_from_tag(
        pytorch_training)
    image_cuda_version = get_cuda_version_from_tag(pytorch_training)
    if Version(image_framework_version) == Version(
            "1.6") and image_cuda_version == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")
    execute_ec2_training_test(ec2_connection, pytorch_training, PT_DGL_CMD)
コード例 #14
0
def test_pytorch_train_dgl_gpu(pytorch_training, ec2_connection, gpu_only, py3_only, ec2_instance_type):
    """
    Run the PyTorch DGL training command (``PT_DGL_CMD``) on an EC2 GPU host.

    Skipped for PyTorch >= 1.10 images built for CUDA 11.3, and for images
    that ``test_utils`` reports as incompatible with ``ec2_instance_type``.
    """
    _, fw_version = get_framework_and_version_from_tag(pytorch_training)
    cuda_tag = get_cuda_version_from_tag(pytorch_training)

    skip_reason = None
    # TODO: Remove when DGL gpu test on ecs get fixed
    if Version(fw_version) >= Version("1.10") and cuda_tag == "cu113":
        skip_reason = "ecs test for DGL gpu fails since pt 1.10"
    elif test_utils.is_image_incompatible_with_instance_type(pytorch_training, ec2_instance_type):
        skip_reason = f"Image {pytorch_training} is incompatible with instance type {ec2_instance_type}"
    if skip_reason is not None:
        pytest.skip(skip_reason)

    execute_ec2_training_test(ec2_connection, pytorch_training, PT_DGL_CMD)
コード例 #15
0
def test_dgl_gcn_training_gpu(sagemaker_session, ecr_image, instance_type):
    """
    Run the DGL GCN training job on a SageMaker GPU instance.

    Uses ``instance_type`` when it is truthy, otherwise ``ml.p2.xlarge``.
    Skipped for PyTorch 1.6 images built for CUDA 11.0.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    dgl_unsupported = Version(fw_version) == Version("1.6") and cuda_tag == "cu110"
    if dgl_unsupported:
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")

    chosen_instance_type = instance_type if instance_type else 'ml.p2.xlarge'
    _test_dgl_training(sagemaker_session, ecr_image, chosen_instance_type)
コード例 #16
0
def test_dgl_gcn_training_gpu(ecr_image, sagemaker_regions, instance_type):
    """
    Run the DGL GCN training job on GPU through ``invoke_pytorch_helper_function``.

    Uses ``instance_type`` when it is truthy, otherwise ``ml.p2.xlarge``.
    Skipped for PyTorch 1.6 images built for CUDA 11.0.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    if Version(fw_version) == Version("1.6") and cuda_tag == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")

    helper_kwargs = {'instance_type': instance_type if instance_type else 'ml.p2.xlarge'}
    invoke_pytorch_helper_function(
        ecr_image,
        sagemaker_regions,
        _test_dgl_training,
        helper_kwargs,
    )
コード例 #17
0
def test_smclarify_metrics_gpu(
    training,
    ec2_connection,
    ec2_instance_type,
    gpu_only,
    py3_only,
    tf23_and_above_only,
    mx18_and_above_only,
    pt16_and_above_only,
):
    """
    Run the SmClarify bias-metrics check with ``nvidia-docker``.

    Skipped when the image's CUDA tag (e.g. ``cu110``), with the ``c``/``u``
    characters stripped from its ends, compares below ``110``.
    """
    cuda_numeric = get_cuda_version_from_tag(training).strip("cu")
    if Version(cuda_numeric) < Version("110"):
        pytest.skip("SmClarify is currently installed in cuda 11 gpu images and above")
    run_smclarify_bias_metrics(
        training,
        ec2_connection,
        ec2_instance_type,
        docker_executable="nvidia-docker",
    )
コード例 #18
0
def test_pytorch_train_dgl_gpu(pytorch_training, ec2_connection, gpu_only,
                               py3_only, ec2_instance_type):
    """
    Run the PyTorch DGL training command (``PT_DGL_CMD``) on an EC2 GPU host.

    Skipped for PyTorch 1.10 images built for CUDA 11.3, and for images that
    ``test_utils`` reports as incompatible with ``ec2_instance_type``.
    """
    _, fw_version = get_framework_and_version_from_tag(pytorch_training)
    cuda_tag = get_cuda_version_from_tag(pytorch_training)

    skip_reason = None
    # TODO: Remove when DGL with cuda 11.3 support is released
    if Version(fw_version) == Version("1.10") and cuda_tag == "cu113":
        skip_reason = "DGL CUDA 11.3 was not introduced in PyTorch 1.10"
    elif test_utils.is_image_incompatible_with_instance_type(
            pytorch_training, ec2_instance_type):
        skip_reason = (
            f"Image {pytorch_training} is incompatible with instance type {ec2_instance_type}"
        )
    if skip_reason is not None:
        pytest.skip(skip_reason)

    execute_ec2_training_test(ec2_connection, pytorch_training, PT_DGL_CMD)
コード例 #19
0
def test_dgl_gcn_training_gpu(ecr_image, sagemaker_regions, instance_type):
    """
    Run the DGL GCN training job on GPU through ``invoke_pytorch_helper_function``.

    Uses ``instance_type`` when it is truthy, otherwise ``ml.p2.xlarge``.
    Skipped for PyTorch >= 1.10 images built for CUDA 11.3, and for
    PyTorch 1.6 images built for CUDA 11.0.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    parsed_version = Version(fw_version)

    # TODO: Remove when DGL gpu test on ecs get fixed
    if parsed_version >= Version("1.10") and cuda_tag == "cu113":
        pytest.skip("ecs test for DGL gpu fails since pt 1.10")

    if parsed_version == Version("1.6") and cuda_tag == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")

    helper_kwargs = {"instance_type": instance_type if instance_type else "ml.p2.xlarge"}
    invoke_pytorch_helper_function(ecr_image, sagemaker_regions, _test_dgl_training, helper_kwargs)
コード例 #20
0
def test_dgl_gcn_training_gpu(ecr_image, sagemaker_regions, instance_type):
    """
    Run the DGL GCN training job on GPU through ``invoke_pytorch_helper_function``.

    Uses ``instance_type`` when it is truthy, otherwise ``ml.p2.xlarge``.
    Skipped for PyTorch 1.10 images built for CUDA 11.3, and for PyTorch 1.6
    images built for CUDA 11.0.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    parsed_version = Version(fw_version)

    # TODO: Remove condition when DGL is added back to PT 1.10
    if parsed_version == Version("1.10") and cuda_tag == "cu113":
        pytest.skip("DGL CUDA 11.3 was not introduced in PyTorch 1.10")
    if parsed_version == Version("1.6") and cuda_tag == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")

    helper_kwargs = {"instance_type": instance_type if instance_type else "ml.p2.xlarge"}
    invoke_pytorch_helper_function(ecr_image, sagemaker_regions,
                                   _test_dgl_training, helper_kwargs)
コード例 #21
0
def test_smdataparallel_mnist(instance_types, ecr_image, py_version,
                              sagemaker_session, tmpdir):
    """
    Run the SMDataParallel MNIST example on two ``instance_types`` instances.

    Exercises the smddprun launcher through the Estimator API
    ``distribution`` parameter. Skipped unless the image is
    TensorFlow >= 2.3.1 built for CUDA 11.0.
    """
    _, tf_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    below_minimum = Version(tf_version) < Version("2.3.1")
    if below_minimum or cuda_tag != "cu110":
        pytest.skip(
            "Data Parallelism is only supported on CUDA 11, and on TensorFlow 2.3.1 or higher"
        )

    smdp_distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}
    estimator = TensorFlow(
        entry_point='smdataparallel_mnist.py',
        role='SageMakerRole',
        image_uri=ecr_image,
        source_dir=MNIST_PATH,
        instance_count=2,
        instance_type=instance_types,
        sagemaker_session=sagemaker_session,
        distribution=smdp_distribution,
    )
    job_name = unique_name_from_base('test-tf-smdataparallel-multi')
    estimator.fit(job_name=job_name)
コード例 #22
0
def can_run_smdataparallel(ecr_image):
    """
    Return True when ``ecr_image`` has framework version >= 1.6 and a CUDA tag >= 110.

    The CUDA tag (e.g. ``cu110``) has the ``c``/``u`` characters stripped
    from its ends before it is compared as a version.
    NOTE(review): the framework is presumably PyTorch; this is not checked here.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    if Version(fw_version) not in SpecifierSet(">=1.6"):
        return False
    return Version(cuda_tag.strip("cu")) >= Version("110")
コード例 #23
0
def test_eks_pytorch_dgl_single_node_training(pytorch_training, py3_only):
    """
    Function to create a pod using kubectl and given container image, and run
    DGL training with PyTorch backend (the DGL ``gcn`` example on ``cora``).

    The pod spec is rendered from ``eks_utils.SINGLE_NODE_TRAINING_TEMPLATE_PATH``
    into a temporary YAML file. The test passes when the pod completes and its
    logs contain "Test accuracy". The pod is deleted in all cases.

    Args:
        :param pytorch_training: the ECR URI
    """
    _, image_framework_version = get_framework_and_version_from_tag(
        pytorch_training)
    image_cuda_version = get_cuda_version_from_tag(pytorch_training)
    if Version(image_framework_version) == Version(
            "1.6") and image_cuda_version == "cu110":
        pytest.skip("DGL does not support CUDA 11 for PyTorch 1.6")
    # TODO: Remove when DGL gpu test on ecs get fixed
    # NOTE(review): unlike the comment, this skips every PyTorch >= 1.10 image
    # (CPU and GPU) on EKS - confirm that is intended.
    if Version(image_framework_version) >= Version("1.10"):
        pytest.skip("ecs test for DGL gpu fails since pt 1.10")

    training_result = False
    # Random suffix keeps pod names and YAML paths distinct across runs.
    rand_int = random.randint(4001, 6000)

    yaml_path = os.path.join(
        os.sep, "tmp", f"pytorch_single_node_training_dgl_{rand_int}.yaml")
    pod_name = f"pytorch-single-node-training-dgl-{rand_int}"

    # Pick the DGL branch whose examples match the image's PyTorch version.
    if is_below_framework_version("1.7", pytorch_training, "pytorch"):
        dgl_branch = "0.4.x"
    else:
        dgl_branch = "0.7.x"

    args = (
        f"git clone -b {dgl_branch} https://github.com/dmlc/dgl.git && "
        f"cd /dgl/examples/pytorch/gcn/ && DGLBACKEND=pytorch python train.py --dataset cora"
    )

    # TODO: Change hardcoded value to read a mapping from the EKS cluster instance.
    cpu_limit = 72
    # True division: this yields the string "36.0".
    cpu_limit = str(int(cpu_limit) / 2)

    # GPU images train on device 0; anything else runs on CPU (--gpu -1).
    if "gpu" in pytorch_training:
        args = args + " --gpu 0"
    else:
        args = args + " --gpu -1"

    search_replace_dict = {
        "<POD_NAME>": pod_name,
        "<CONTAINER_NAME>": pytorch_training,
        "<ARGS>": args,
        "<CPU_LIMIT>": cpu_limit,
    }

    eks_utils.write_eks_yaml_file_from_template(
        eks_utils.SINGLE_NODE_TRAINING_TEMPLATE_PATH, yaml_path,
        search_replace_dict)

    try:
        run("kubectl create -f {}".format(yaml_path))

        if eks_utils.is_eks_training_complete(pod_name):
            dgl_out = run("kubectl logs {}".format(pod_name)).stdout
            if "Test accuracy" in dgl_out:
                training_result = True
            else:
                # Log the output at the same level as its header so a failure
                # is diagnosable without enabling DEBUG logging.
                eks_utils.LOGGER.info("**** training output ****")
                eks_utils.LOGGER.info(dgl_out)

        assert training_result, "Training failed"
    finally:
        run("kubectl delete pods {}".format(pod_name))
コード例 #24
0
def can_run_smmodelparallel_efa(ecr_image):
    """
    Return True when ``ecr_image`` has framework version in [2.4.1, 2.7.0) and a CUDA tag >= 110.

    The CUDA tag (e.g. ``cu110``) has the ``c``/``u`` characters stripped
    from its ends before it is compared as a version.
    NOTE(review): the framework is presumably TensorFlow; this is not checked here.
    """
    _, fw_version = get_framework_and_version_from_tag(ecr_image)
    cuda_tag = get_cuda_version_from_tag(ecr_image)
    if Version(fw_version) not in SpecifierSet(">=2.4.1,<2.7.0"):
        return False
    return Version(cuda_tag.strip("cu")) >= Version("110")