def test_validate_smdataparallel_args_not_raises():
    """Valid smdataparallel configurations must pass validation without raising.

    Cases (PyTorch and TensorFlow 2):
      1. SM Distributed dataparallel disabled (all other args None).
      2. SM Distributed dataparallel enabled on ml.p3.16xlarge / py3 with
         each supported framework version.
    """
    enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
    disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}

    # Framework -> versions expected to be accepted when dataparallel is on.
    supported_versions = (
        ("tensorflow", ("2.3.1", "2.3.2", "2.3", "2.4.1", "2.4")),
        ("pytorch", ("1.6.0", "1.6", "1.7.1", "1.7", "1.8.0", "1.8")),
    )

    # Case 1: disabled distribution is accepted regardless of other args.
    cases = [(None, None, None, None, disabled)]
    # Case 2: every supported framework/version combination.
    cases.extend(
        ("ml.p3.16xlarge", framework, version, "py3", enabled)
        for framework, versions in supported_versions
        for version in versions
    )

    for case in cases:
        fw_utils._validate_smdataparallel_args(*case)
def test_validate_smdataparallel_args_raises():
    """Unsupported smdataparallel configurations must raise ValueError.

    For each framework (PyTorch and TensorFlow 2) the following are rejected:
      1. None instance type
      2. incorrect instance type
      3. incorrect python version
      4. incorrect framework version
    """
    # TODO: add validation for dataparallel in mxnet
    enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}

    # (framework, a version that is otherwise valid, a version that is not)
    framework_versions = (
        ("tensorflow", "2.3.1", "1.3.1"),
        ("pytorch", "1.6.0", "1.5.0"),
    )

    for framework, supported, unsupported in framework_versions:
        # Each tuple breaks exactly one constraint, in the order listed above.
        invalid_calls = (
            (None, framework, supported, "py3"),
            ("ml.p3.2xlarge", framework, supported, "py3"),
            ("ml.p3dn.24xlarge", framework, supported, "py2"),
            ("ml.p3.16xlarge", framework, unsupported, "py3"),
        )
        for instance_type, fw_name, fw_version, py_version in invalid_calls:
            with pytest.raises(ValueError):
                fw_utils._validate_smdataparallel_args(
                    instance_type, fw_name, fw_version, py_version, enabled
                )