Example #1
import collections
import warnings

from mxnet import amp
# get_all_registered_operators_grouped ships with MXNet's Python package
# (mxnet.operator in recent versions); adjust the import to your installation.
from mxnet.operator import get_all_registered_operators_grouped


# lp_dtype / lp_name identify the low-precision type under test
# (e.g. float16 / "fp16" or bfloat16 / "bf16"); pytest is expected to
# parametrize them (see the sketch after this example).
def test_amp_coverage(lp_dtype, lp_name):
    conditional = [item[0] for item in amp.list_conditional_fp32_ops(lp_dtype)]
    lp16_ops = amp.list_lp16_ops(lp_dtype)
    lp16_fp32_ops = amp.list_lp16_fp32_ops(lp_dtype)
    fp32_ops = amp.list_fp32_ops(lp_dtype)
    widest_ops = amp.list_widest_type_cast(lp_dtype)
    all_lp_lists = [lp16_ops, lp16_fp32_ops, fp32_ops, widest_ops, conditional]

    # Check for duplicates
    for op_list in all_lp_lists:
        ret = [
            op for op, count in collections.Counter(op_list).items()
            if count > 1
        ]
        assert ret == [], "Elements " + str(
            ret) + " are duplicated in the AMP lists."

    all_lp_ops = [op for op_list in all_lp_lists for op in op_list]
    ret = [
        op for op, count in collections.Counter(all_lp_ops).items()
        if count > 1
    ]
    assert ret == [], "Elements " + str(
        ret) + " exist in more than 1 AMP list."

    # Check the coverage
    covered_ops = set(all_lp_ops)
    all_mxnet_ops = get_all_registered_operators_grouped()
    required_ops = {op for op in all_mxnet_ops if "backward" not in op}

    extra_ops = covered_ops - required_ops
    assert not extra_ops, f"{len(extra_ops)} operators are not needed in the AMP lists: {sorted(extra_ops)}"

    guidelines = f"""Please follow these guidelines for choosing a proper list:
    - if your operator is not to be used in a computational graph
      (e.g. image manipulation operators, optimizers) or does not have
      inputs, put it in {lp_name.upper()}_FP32_FUNCS list,
    - if your operator requires FP32 inputs or is not safe to use with lower
      precision, put it in FP32_FUNCS list,
    - if your operator supports both FP32 and lower precision, has
      multiple inputs and expects all inputs to be of the same
      type, put it in WIDEST_TYPE_CASTS list,
    - if your operator supports both FP32 and lower precision and has
      either a single input or supports inputs of different types,
      put it in {lp_name.upper()}_FP32_FUNCS list,
    - if your operator is both safe to use in lower precision and
      it is highly beneficial to use it in lower precision, then
      put it in {lp_name.upper()}_FUNCS (this is unlikely for new operators)
    - If you are not sure which list to choose, FP32_FUNCS is the
      safest option"""
    missing_ops = required_ops - covered_ops

    if len(missing_ops) > 0:
        warnings.warn(
            f"{len(missing_ops)} operators {sorted(missing_ops)} do not exist in AMP lists "
            f"(in python/mxnet/amp/lists/symbol_{lp_name.lower()}.py) - please add them. \n{guidelines}"
        )
Example #2
import collections
import warnings

from mxnet import amp
import mxnet.amp.lists.symbol_bf16  # noqa: F401 -- makes amp.lists.symbol_bf16 resolvable
# get_all_registered_operators_grouped ships with MXNet's Python package
# (mxnet.operator in recent versions); adjust the import to your installation.
from mxnet.operator import get_all_registered_operators_grouped


def test_amp_coverage():
    conditional = [item[0] for item in amp.lists.symbol_bf16.CONDITIONAL_FP32_FUNCS]

    # Check for duplicates within each list
    bf16_lists = [amp.lists.symbol_bf16.BF16_FUNCS,
                  amp.lists.symbol_bf16.BF16_FP32_FUNCS,
                  amp.lists.symbol_bf16.FP32_FUNCS,
                  amp.lists.symbol_bf16.WIDEST_TYPE_CASTS,
                  conditional]
    for a in bf16_lists:
        ret = [item for item, count in collections.Counter(a).items() if count > 1]
        assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."

    # Check that no operator appears in more than one list
    t = [item for a in bf16_lists for item in a]
    ret = [item for item, count in collections.Counter(t).items() if count > 1]
    assert ret == [], "Elements " + str(ret) + " exist in more than 1 AMP list."

    # Check the coverage
    covered = set(t)
    ops = get_all_registered_operators_grouped()
    required = set(k for k in ops
                   if not k.startswith(("_backward", "_contrib_backward", "_npi_backward")) and
                   not k.endswith("_backward"))

    extra = covered - required
    assert not extra, f"{len(extra)} operators are not needed in the AMP lists: {sorted(extra)}"

    guidelines = """Please follow these guidelines for choosing a proper list:
    - if your operator is not to be used in a computational graph
      (e.g. image manipulation operators, optimizers) or does not have
      inputs, put it in BF16_FP32_FUNCS list,
    - if your operator requires FP32 inputs or is not safe to use with lower
      precision, put it in FP32_FUNCS list,
    - if your operator supports both FP32 and lower precision, has
      multiple inputs and expects all inputs to be of the same
      type, put it in WIDEST_TYPE_CASTS list,
    - if your operator supports both FP32 and lower precision and has
      either a single input or supports inputs of different types,
      put it in BF16_FP32_FUNCS list,
    - if your operator is both safe to use in lower precision and
      it is highly beneficial to use it in lower precision, then
      put it in BF16_FUNCS (this is unlikely for new operators)
    - If you are not sure which list to choose, FP32_FUNCS is the
      safest option"""
    diff = required - covered

    if len(diff) > 0:
        warnings.warn(f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in "
                      f"python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
                      f"\n{guidelines}")