def test_validate_smdataparallel_args_not_raises():
    """Supported smdataparallel configurations must pass validation without raising.

    Cases, for PyTorch and TensorFlow 2:
      1. SM Distributed dataparallel disabled (every other argument ``None``).
      2. SM Distributed dataparallel enabled on ``ml.p3.16xlarge`` with ``py3``
         and each framework version listed in ``supported_versions``.
    """
    enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
    disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}

    # Framework -> versions expected to be accepted, checked in this order.
    supported_versions = {
        "tensorflow": ("2.3.1", "2.3.2", "2.3", "2.4.1", "2.4"),
        "pytorch": ("1.6.0", "1.6", "1.7.1", "1.7", "1.8.0", "1.8"),
    }

    # Each case is (instance_type, framework_name, framework_version, py_version, distribution).
    cases = [(None, None, None, None, disabled)]
    cases.extend(
        ("ml.p3.16xlarge", framework, version, "py3", enabled)
        for framework, versions in supported_versions.items()
        for version in versions
    )

    for case in cases:
        # Any exception escaping the validator fails the test.
        fw_utils._validate_smdataparallel_args(*case)


def test_validate_smdataparallel_args_raises():
    """Unsupported smdataparallel configurations must raise ``ValueError``.

    Cases, for PyTorch and TensorFlow 2, all with dataparallel enabled:
      1. ``None`` instance type
      2. incorrect instance type (``ml.p3.2xlarge``)
      3. incorrect Python version (``py2``)
      4. incorrect framework version

    Every case is checked, and a failure lists all argument tuples that did not
    raise. With a bare ``pytest.raises`` inside the loop, only a generic
    "DID NOT RAISE" for the first offending case would be reported.
    Exceptions other than ``ValueError`` still propagate and fail the test.
    """
    # TODO: add validation for dataparallel in mxnet
    smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}

    # Each case is (instance_type, framework_name, framework_version, py_version, distribution).
    bad_args = [
        (None, "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
        ("ml.p3.2xlarge", "tensorflow", "2.3.1", "py3", smdataparallel_enabled),
        ("ml.p3dn.24xlarge", "tensorflow", "2.3.1", "py2", smdataparallel_enabled),
        ("ml.p3.16xlarge", "tensorflow", "1.3.1", "py3", smdataparallel_enabled),
        (None, "pytorch", "1.6.0", "py3", smdataparallel_enabled),
        ("ml.p3.2xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
        ("ml.p3dn.24xlarge", "pytorch", "1.6.0", "py2", smdataparallel_enabled),
        ("ml.p3.16xlarge", "pytorch", "1.5.0", "py3", smdataparallel_enabled),
    ]

    not_raised = []
    for args in bad_args:
        try:
            fw_utils._validate_smdataparallel_args(*args)
        except ValueError:
            continue
        # Validation accepted a configuration it should have rejected.
        not_raised.append(args)

    assert not not_raised, "Expected ValueError for smdataparallel args: {!r}".format(
        not_raised
    )