def test_norm(device, norm_type, mixed_precision):
    """Test checkpoint_wrapper with different norm layers."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    # Get input, ref, checkpoint models and make them equal.
    in_data = torch.rand(2, 2, 3, 3).to(device)
    m_ref = get_model(norm_type, False, mixed_precision).to(device)
    m_cpt = get_model(norm_type, True, mixed_precision).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())

    if torch_version() >= (1, 6, 0):
        # This assert fails on 1.5.1.
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    if mixed_precision != "fp32":
        in_data = in_data.half()

    # Needed due to checkpointing.
    in_data.requires_grad = True
    for model in (m_ref, m_cpt):
        optim = SGD(model.parameters(), lr=0.1)
        if device == "cpu" and mixed_precision != "fp32":
            # Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
            with pytest.raises(RuntimeError):
                out = model(in_data)
            return
        else:
            # Everything else works.
            out = model(in_data)
        out.sum().backward()
        optim.step()

    if torch_version() >= (1, 6, 0):
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
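
# Sketch (assumption): `get_model` is defined elsewhere in the original test file and is not
# part of this excerpt. A hypothetical version compatible with the (2, 2, 3, 3) input above,
# where `norm_type` is a norm-layer class (e.g. nn.BatchNorm2d) and `checkpoint` toggles
# fairscale's checkpoint_wrapper, might look roughly like this:
def _get_model_sketch(norm_type, checkpoint, mixed_precision):
    from torch import nn

    from fairscale.nn import checkpoint_wrapper

    norm = nn.LayerNorm([2, 3, 3]) if norm_type is nn.LayerNorm else norm_type(2)
    model = nn.Sequential(nn.Conv2d(2, 2, kernel_size=1), norm)
    if mixed_precision != "fp32":
        model = model.half()
    if checkpoint:
        model = checkpoint_wrapper(model)
    return model
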
def rpc_worker(rank, world_size, init_file, func, *args):
    if torch_version() == (1, 8, 0):
        if torch.cuda.is_available():
            # Workaround for https://github.com/pytorch/pytorch/issues/53844
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file, _transports=["ibv", "uv"])
        else:
            # Workaround for https://github.com/pytorch/pytorch/issues/54266
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file,
                _channels=[
                    "mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth",
                    "cuda_basic"
                ],
            )
    else:
        options = rpc.TensorPipeRpcBackendOptions(init_method="file://" +
                                                  init_file)
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    if rank == 0:
        func(*args)
    rpc.shutdown()
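
# Usage sketch (assumption, not part of the original snippet): `rpc_worker` above is meant to
# be launched once per rank, typically via torch.multiprocessing. The wrapper below and the
# name `spawn_rpc` are hypothetical.
def spawn_rpc(func, *args, world_size=2):
    import tempfile

    import torch.multiprocessing as mp

    init_file = tempfile.mkstemp()[1]
    # mp.spawn passes the process index as the first argument, matching rpc_worker's `rank`.
    mp.spawn(rpc_worker, args=(world_size, init_file, func) + args, nprocs=world_size, join=True)
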
# Example #3
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip(
            "older pytorch doesn't support reduce_scatter in gloo backend")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    model = Sequential(
        Linear(3, 3, bias=False),
        Linear(3, 4, bias=False),
        Linear(4, 5, bias=False),
        Linear(5, 4, bias=False),
        Linear(4, 3, bias=False),
        Linear(3, 1, bias=False),
        Linear(1, 1, bias=False),  # param here is smaller than world_size if unflattened.
    )
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused,
              test_case),
        nprocs=world_size,
        join=True,
    )
# Example #4
    def setUp(self):
        if torch_version() < (1, 6, 0):
            raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
# Example #5
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref

    state_after = state_after[(precision, sync_bn)]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # When linear bias is True, DDP's AMP O1 and FSDP's default AMP O1.5 differ, so we force
    # FSDP to use AMP O1 here by setting compute_dtype to float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time.
    wrap_bn = True
    if random.randint(0, 1) == 0:
        wrap_bn = False
    # However, always wrap BN in mixed precision + sync_bn mode, regardless of compute_dtype,
    # because sync_bn errors out when BN is left unwrapped.
    if fsdp_config["mixed_precision"] and sync_bn != "none":
        wrap_bn = True

    # When BN is not wrapped (i.e. not in full precision), FSDP's compute_dtype needs to
    # be fp32 to match DDP (otherwise, numerical errors happen on BN's running_mean/running_var
    # buffers).
    if fsdp_config["mixed_precision"] and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
            sync_bn,
            conv_bias,
            linear_bias,
        ),
        nprocs=world_size,
        join=True,
    )
def test_basic(device):
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    input = torch.rand(2, 16, 32).requires_grad_(True)
    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))

    # Check for correctness.
    for key in "loss", "gnorm":
        if not (no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key]):
            print(no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
            assert 0
        del no_cpt[key]
        del pyt_cpt[key]
        del fairscale_cpt[key]
        del fairscale_cpt_offload[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    assert no_cpt == {"mem_0": 38912, "mem_peak": mem_peaks[0], "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
    assert pyt_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[1],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, pyt_cpt
    assert fairscale_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[2],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt
    assert fairscale_cpt_offload == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[3],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt_offload
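
# Sketch (assumption): `get_loss_and_gnorm` is defined elsewhere in the original test file.
# Based on the keys asserted above, a hypothetical minimal version would run one fwd/bwd pass
# and record the loss, the gradient norm, and a few CUDA memory counters, roughly like this:
def _get_loss_and_gnorm_sketch(model, input):
    ret = {}
    if input.is_cuda:
        torch.cuda.reset_peak_memory_stats()
    ret["mem_0"] = torch.cuda.memory_allocated() if input.is_cuda else 0

    loss = model(input).sum()
    ret["mem_after_fwd"] = torch.cuda.memory_allocated() if input.is_cuda else 0

    loss.backward()
    ret["mem_after_bwd"] = torch.cuda.memory_allocated() if input.is_cuda else 0
    ret["mem_peak"] = torch.cuda.max_memory_allocated() if input.is_cuda else 0

    ret["loss"] = loss.item()
    grads = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    ret["gnorm"] = torch.norm(torch.stack(grads)).item()
    return ret
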
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
# Example #8
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint, filename, filename_rpc, expected):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model)
    else:
        model = DistributedDataParallel(model, device_ids=[gpu_id], bucket_cap_mb=500)

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4)

    results = {}
    for iteration in range(3):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad(), to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    assert results == expected, f"{results} but expected {expected}"

    teardown()
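
# Sketch (assumption): `get_cur_mem` is a helper defined elsewhere in the original file.
# Something like the hypothetical version below, which records the currently allocated CUDA
# memory (rounded to MB) under the given label, would match how `results` is filled above.
def _get_cur_mem_sketch(gpu_id, results, label):
    # Round to MB so small allocator fluctuations do not dominate the comparison.
    results[label] = round(torch.cuda.memory_allocated(gpu_id) / 1024 / 1024)
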
def test1(precision, flatten):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # Some bugs only show up when we are in world_size > 1 due to sharding changing
    # the tensor dimensions.
    world_size = 2
    mp.spawn(
        _test_func, args=(world_size, fsdp_config, temp_file_name, unused), nprocs=world_size, join=True,
    )
# Example #10
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throws this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            # Get 2 + world_size files: 2 for dist_init and one per rank to save the losses.
            with temp_files_ctx(num=2 + world_size) as temp_files:
                mp.spawn(
                    _distributed_worker,
                    (
                        world_size,
                        with_fsdp,
                        with_checkpoint,
                        temp_files,
                        mixed_precision,
                        flatten,
                        wrap_bn,
                        fp32_reduce_scatter,
                    ),
                    nprocs=world_size,
                )
                final_losses = {}
                for rank in range(world_size):
                    with open(temp_files[2 + rank], "rb") as f:
                        final_losses[f"rank_{rank}"] = pickle.load(f)
                if expected_losses is None:
                    expected_losses = final_losses
                else:
                    print(f"fsdp: {with_fsdp} ckpt: {with_checkpoint}")
                    assert objects_are_equal(expected_losses,
                                             final_losses,
                                             raise_exception=True)
# Example #11
    def check(exp, res):
        assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
        rtol = 1e-4
        atol = 1e-5
        if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
            # On CI, with the longer model2, mixed precision and torch 1.9, even ddp vs. ddp+ckpt
            # has larger errors.
            rtol = 1e-3
            atol = 1e-4
        for key in exp.keys():
            exp_loss = exp[key]
            res_loss = res[key]
            torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
def test1(temp_files, ddp_ref, precision, flatten):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, state_after_fp, state_after_mp = ddp_ref

    if precision == "full":
        state_after = state_after_fp
    else:
        state_after = state_after_mp

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time in full precision mode.
    wrap_bn = True
    if random.randint(0, 1) == 0:
        wrap_bn = False
    # Always wrap BN in mixed precision mode.
    if fsdp_config["mixed_precision"]:
        wrap_bn = True

    world_size = _world_size
    mp.spawn(
        _test_func,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
        ),
        nprocs=world_size,
        join=True,
    )
# Example #13
def test_one_iteration(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
    #             more cases in FSDP. Also, assert_ref_out can be extended to multiple
    #             iterations. This could be a good bootcamp task. I should file a github
    #             issue once we merge.
    model = Linear(3, 3, bias=False)
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused,
              test_case),
        nprocs=world_size,
        join=True,
    )
def test(world_size, precision, flatten):
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    fsdp_config = {
        "mixed_precision": precision == "mixed",
        "flatten_parameters": flatten == "flatten",
    }

    mp.spawn(
        _test_func,
        args=(world_size, fsdp_config, temp_file_name, unused),
        nprocs=world_size,
        join=True,
    )
# Example #15
                return True
        return False

    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")


@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True)
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def checkpoint_non_float_input(pipe_class):
    class ForkNonFloat(nn.Module):
        def forward(self, input):
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input):
            return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
    model = pipe_class(
        model, balance=[1, 1], worker_map=get_worker_map(), chunks=1, checkpoint="always", pipelined_backward=False,
    )
# Example #16
# limitations under the License.

import pytest
import torch
from torch import nn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset

from fairscale.experimental.nn.ampnet_pipe.pipe import AMPnetPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version

# Currently on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
    pytestmark = pytest.mark.skip


class MySGD(Optimizer):
    r"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate (required)
    """

    def __init__(self, params, lr=0.01):
        defaults = dict(lr=lr)
        super(MySGD, self).__init__(params, defaults)
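
    # Sketch (assumption): the rest of MySGD is truncated in this excerpt; a hypothetical
    # plain-SGD step() for such an optimizer could look like this.
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    # Vanilla gradient descent: p <- p - lr * grad
                    p.data.add_(p.grad.data, alpha=-group["lr"])
        return loss
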
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.experimental.nn.multiprocess_pipe import DistributedLoss, MultiProcessPipe
from fairscale.utils.testing import torch_version

BOUNCE_TENSORS = True

CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
    DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
    DEVICES = [CPU_DEVICES]

pytestmark = pytest.mark.skipif(torch_version() < (1, 8, 0),
                                reason="requires torch version >= 1.8.0")


def rpc_worker(rank, world_size, init_file, func, *args):
    if torch_version() == (1, 8, 0):
        if torch.cuda.is_available():
            # Workaround for https://github.com/pytorch/pytorch/issues/53844
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file, _transports=["ibv", "uv"])
        else:
            # Workaround for https://github.com/pytorch/pytorch/issues/54266
            options = rpc.TensorPipeRpcBackendOptions(
                init_method="file://" + init_file,
                _channels=[
                    "mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth",
# Example #18
                return True
        return False

    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, "CheckpointBackward")
    assert find_grad_fn(train_output.grad_fn, "RecomputeBackward")

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, "CheckpointBackward")
    assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward")


@torch_spawn([2])
@pytest.mark.xfail(torch_version() < (1, 6, 0),
                   reason="Doesn't work on torch < 1.6.0",
                   strict=True)
@pytest.mark.parametrize(
    "pipeline_style",
    [MultiProcessPipe.MultiProcess, MultiProcessPipe.AsyncSchedule])
def checkpoint_non_float_input(pipeline_style):
    class ForkNonFloat(nn.Module):
        def forward(self, input):
            return (input * 2, torch.tensor([False]))

    class JoinNonFloat(nn.Module):
        def forward(self, input):
            return input[0] * 2

    model = nn.Sequential(ForkNonFloat(), JoinNonFloat())
# Example #19
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint,
                        filename, filename_rpc, expected, model_hidden_dim,
                        fsdp_config):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim,
                         fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu_id],
                                        bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after the 3rd iteration, the 4th one can fail (not in this
    # test but in much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad(), to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        ret = ""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    teardown()
# Example #20
import torch
from torch import nn

from fairscale.nn.model_parallel.initialize import (
    destroy_model_parallel,
    get_pipeline_parallel_group,
    initialize_model_parallel,
)
from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version

# Currently on CI, there appears to be a bug with torch 1.8
# See:
# https://app.circleci.com/pipelines/github/facebookresearch/fairscale/1892/workflows/8f658bf4-8052-4084-bb3e-4cc2c445c8aa/jobs/10080/parallel-runs/0/steps/0-112
# So we skip this file in that case until it is fixed.
if torch_version() >= (1, 8, 0):
    pytestmark = pytest.mark.skip


@torch_spawn([2])
@pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
def parameters(pipe_class):
    model = nn.Sequential(nn.Linear(1, 1))
    pipe = pipe_class(model,
                      balance=[1],
                      worker_map=get_worker_map(),
                      chunks=1)
    if torch.distributed.get_rank() == 0:
        assert list(pipe.parameters()) != []
    else:
        assert list(pipe.parameters()) == []
# Example #21
    def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer],
                                    change_train_graph: bool = False):
        # Any model works. Add one different buffer per rank
        trunk = torch.nn.Sequential(torch.nn.Linear(in_channels, hidden),
                                    torch.nn.Linear(hidden, hidden),
                                    torch.nn.Linear(hidden, hidden))
        trunk.register_buffer("test_buffer", torch.ones((1)) * rank)
        trunk.to(device)

        head = torch.nn.Linear(hidden, out_channels).to(device)

        # Define a model to be trained by OSS
        oss_module = torch.nn.Sequential(trunk, head)

        # Make sure that the param groups are interleaved, to catch an ordering bug in the state dict
        oss_trainable_params = [
            {
                "params": list(trunk.parameters())[:-1] + list(head.parameters()),
                "lr": 1e-5,
            },
            {
                "params": list(trunk.parameters())[-1],
                "lr": 1e-4,
            },
        ]

        optimizer_settings: Dict[Any, Any] = {}
        if isinstance(optimizer, torch.optim.SGD):
            optimizer_settings["momentum"] = 0.9

        sharded_optimizer = optim.OSS(
            params=oss_trainable_params,
            optim=optimizer,
            group=None,
            broadcast_buffer_size=2**10,
            **optimizer_settings,
        )

        oss_ddp_model = DDP(module=oss_module,
                            device_ids=[rank],
                            broadcast_buffers=True,
                            find_unused_parameters=True)

        # Define a model to be trained by normal pytorch + DDP
        ddp_trunk = copy.deepcopy(trunk)
        ddp_head = copy.deepcopy(head)
        ddp_module = torch.nn.Sequential(ddp_trunk, ddp_head)

        ddp_trainable_params = [
            {
                "params": list(ddp_trunk.parameters())[:-1] + list(ddp_head.parameters()),
                "lr": 1e-5,
            },
            {
                "params": list(ddp_trunk.parameters())[-1],
                "lr": 1e-4,
            },
        ]
        ddp_optimizer = optimizer(ddp_trainable_params,
                                  **optimizer_settings)  # type: ignore
        ddp_model = DDP(module=ddp_module,
                        device_ids=[rank],
                        broadcast_buffers=True,
                        find_unused_parameters=True)

        def check_step():
            input_tensor = torch.rand((batch, in_channels)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                ddp_optimizer.zero_grad()
                ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_loss.backward()
                return ddp_loss

            def closure_sharded(input_tensor=input_tensor):
                sharded_optimizer.zero_grad()
                sharded_loss = oss_ddp_model(input_tensor).abs().sum()
                sharded_loss.backward()
                return sharded_loss

            loss_ddp = cast(torch.Tensor,
                            ddp_optimizer.step(closure=closure_ddp))
            loss_sharded_optim = cast(
                torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

            assert torch.allclose(
                loss_ddp, loss_sharded_optim, rtol=1e-3
            ), f"Losses differ in between Pytorch optim and OSS\n {loss_ddp.item()} - {loss_sharded_optim.item()} - world size {world_size}"

            check_same_model_params(oss_ddp_model, ddp_model)

        # The model should be synchronized between the ranks at construction time; check that it is.
        check_same_model_params(oss_ddp_model, ddp_model)

        # The models should stay the same between ddp and the sharded optimizer
        for i in range(5):
            check_step()

            # Check that altering the trainable parameters does not cause DDP and OSS to diverge
            if change_train_graph:
                # Flip the first parameter from trainable to non-trainable and vice-versa
                next(ddp_module.parameters()).requires_grad = not next(
                    ddp_module.parameters()).requires_grad
                next(oss_module.parameters()).requires_grad = not next(
                    oss_module.parameters()).requires_grad
                # sharded_optimizer.refresh_trainable()

        # Check that the checkpoints are compatible (post pytorch 1.5)
        if torch_version() > (1, 5):
            # - get states
            ddp_state_dict = ddp_optimizer.state_dict()
            sharded_optimizer.consolidate_state_dict(
                recipient_rank=RECIPIENT_RANK)
            sharded_optim_state_dict = sharded_optimizer.state_dict(
            ) if rank == RECIPIENT_RANK else {}
            sharded_optim_state_dict = sync_object_ranks(
                sharded_optim_state_dict, RECIPIENT_RANK, device)

            # - cross load the states
            # run one step and check that the models are still the same
            ddp_state_dict_ref = copy.deepcopy(
                ddp_state_dict)  # OSS will remove some states
            ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mix-up on purpose!
            sharded_optimizer.load_state_dict(ddp_state_dict)
            check_step()

            #  - self load, rewind, check no problem
            # run one step and check that the models are still the same
            ddp_optimizer.load_state_dict(ddp_state_dict_ref)
            sharded_optimizer.load_state_dict(sharded_optim_state_dict)
            check_step()
# Example #22
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils.testing import torch_version

pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)),
    reason="cuda and torch>=1.8.0 required")

devices = ["cuda"]


def pg_worker(rank, world_size, init_file, func, *args):
    init_url = "file://" + init_file
    dist.init_process_group(backend=dist.Backend.NCCL,
                            rank=rank,
                            world_size=world_size,
                            init_method=init_url)
    torch.cuda.set_device(rank)
    dist.all_reduce(torch.zeros(1).cuda())
    func(*args)
    dist.destroy_process_group()
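
# Usage sketch (assumption): like the RPC helper earlier in this listing, `pg_worker` is meant
# to be spawned once per rank; `spawn_pg` below is a hypothetical wrapper, not fairscale API.
def spawn_pg(func, *args, world_size=2):
    init_file = tempfile.mkstemp()[1]
    mp.spawn(pg_worker, args=(world_size, init_file, func) + args, nprocs=world_size, join=True)
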
# Example #23
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type,
                                     bn_type):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None
    with_model2 = model_type == "model2"
    with_sync_bn = bn_type == "sync_bn"

    if torch_version() >= (1, 7, 0) and torch_version() < (1, 8, 0) and with_sync_bn:
        # SyncBN is buggy in 1.7, errors like:
        # E         File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
        # E           dtype=running_mean.dtype,
        # E       AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if with_sync_bn and not wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throws this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            if not with_fsdp and with_checkpoint:
                continue
            final_losses = _get_cached_results(
                world_size,
                with_model2,
                with_sync_bn,
                with_fsdp,
                with_checkpoint,
                mixed_precision,
                flatten,
                wrap_bn,
                fp32_reduce_scatter,
            )
            if expected_losses is None:
                expected_losses = final_losses
            else:
                print(
                    f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt"
                )

                def check(exp, res):
                    assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
                    rtol = 1e-4
                    atol = 1e-5
                    if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
                        # On CI, with the longer model2, mixed precision and torch 1.9, even
                        # ddp vs. ddp+ckpt has larger errors.
                        rtol = 1e-3
                        atol = 1e-4
                    for key in exp.keys():
                        exp_loss = exp[key]
                        res_loss = res[key]
                        torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)

                check(expected_losses, final_losses)
# Example #24
    def run(compute_cycles, all_gather_cycles):
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and set its requires_grad to True because we have
        # a fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the 10 smallest
        # data points, because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Get two events for measuring the overall time.
            e1 = Event(enable_timing=True)
            e2 = Event(enable_timing=True)

            cpu_start = time.process_time()

            all_gather_called = False

            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                out = model(batch)
                if has_params and world_size > 1:
                    assert all_gather_called
                else:
                    assert not all_gather_called
            e2.record()

            # backward pass
            out.backward()
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            cpu_iter_time = time.process_time() - cpu_start

            # wait for gpu
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # get sum of the compute time
            times = []
            for mod in model.modules():
                if not isinstance(mod, Layer):
                    continue
                times.append(mod.get_time())

            # get gpu compute + all_gather time
            overall_gpu_time = e1.elapsed_time(e2)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }
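
# Sketch (assumption): `Min10` is not shown in this excerpt. A hypothetical minimal version
# that keeps only the 10 smallest samples, so slow outliers do not skew the averages, could be:
class _Min10Sketch:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        # Keep the 10 smallest values seen so far.
        self.data = sorted(self.data + [new_data])[:10]

    def avg(self):
        return sum(self.data) / len(self.data)
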
# Example #25
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
from fairscale.utils.testing import torch_version

CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
    DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
    DEVICES = [CPU_DEVICES]


pytestmark = pytest.mark.skipif(torch_version() < (1, 9, 0), reason="requires torch version >= 1.9.0")


def rpc_worker(rank, world_size, init_file, func, *args):
    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
    for i in range(world_size):
        options.set_device_map("worker" + str(i), {rank: i})
    rpc.init_rpc(
        "worker" + str(rank),
        rank=rank,
        world_size=world_size,
        backend=rpc.BackendType.TENSORPIPE,
        rpc_backend_options=options,
    )
    if rank == 0:
        func(*args)
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5
    BATCH_SIZE = 8

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}" +
        f" - change train graph {change_train_graph}" + f" - amp {amp}" +
        f" - manual reduction {manual_reduction}" +
        f" - buffers {reduce_buffer_size}",
        flush=True,
    )

    # The API should be exactly the same between the sharded and non-sharded variants; generic closure.
    def closure(model,
                scaler,
                input_tensor,
                should_accumulate,
                _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with a non-trainable parameter, so that we check that the
    # buckets are properly reassigned when/if this changes.
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(),
                            optim=torch.optim.SGD,
                            lr=1e-4,
                            momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                    lr=1e-4,
                                    momentum=0.99)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None,
                                     hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized between the ranks at construction time; check that it is.
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor,
                           grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model,
                                f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs; the scaler will skip the step in that case.
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(),
                    0.3,
                    norm_type=2.0,
                    error_if_nonfinite=False)  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(
                    0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm,
                                          total_norm,
                                          atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass: compare the gradients between DDP and ShardedDDP.
                    for idx, (p_ddp, p_sdp) in enumerate(
                            zip(ddp_model.parameters(),
                                sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      torch.norm(p_sdp.grad),
                                      flush=True)
                            else:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      "not owned",
                                      flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(
                ddp_model.parameters()).requires_grad
            check_same_model_params(
                sharded_ddp_model, ddp_model,
                f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()