def test_cpu_amp_precision_context_manager(tmpdir):
    """Check that a bf16 plugin on CPU builds a bfloat16 ``torch.autocast`` context and has no scaler."""
    bf16_plugin = NativeMixedPrecisionPlugin("bf16", "cpu")

    # The plugin must report the CPU device and must not hold a scaler object.
    assert bf16_plugin.device == "cpu"
    assert bf16_plugin.scaler is None

    autocast_ctx = bf16_plugin.autocast_context_manager()
    assert isinstance(autocast_ctx, torch.autocast)

    # Compare string forms instead of the dtype objects directly, to work
    # around an upstream PyTorch bug:
    # https://github.com/pytorch/pytorch/issues/65786
    expected_dtype_name = str(torch.bfloat16)
    assert str(autocast_ctx.fast_dtype) == expected_dtype_name
# Example 2
# (score: 0)
def test_cpu_amp_precision_context_manager(tmpdir):
    """Check that a CPU bf16 plugin yields a bfloat16 ``torch.cpu.amp.autocast`` context and defines no ``scaler``."""

    cpu_plugin = NativeMixedPrecisionPlugin(precision="bf16", use_cpu=True)
    assert cpu_plugin.use_cpu

    # The attribute must be absent altogether, not merely set to None.
    assert not hasattr(cpu_plugin, "scaler")

    ctx = cpu_plugin.autocast_context_manager()
    assert isinstance(ctx, torch.cpu.amp.autocast)
    assert ctx.dtype == torch.bfloat16