def test_norm(device, norm_type, mixed_precision):
    """Test checkpoint_wrapper with different norm layers."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("Skip due to lack of GPU")

    # Get input, ref, checkpoint models and make them equal.
    in_data = torch.rand(2, 2, 3, 3).to(device)
    m_ref = get_model(norm_type, False, mixed_precision).to(device)
    m_cpt = get_model(norm_type, True, mixed_precision).to(device)
    m_cpt.load_state_dict(m_ref.state_dict())

    if torch_version() >= (1, 6, 0):
        # This assert fails on 1.5.1.
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())

    if mixed_precision != "fp32":
        in_data = in_data.half()

    # Needed due to checkpointing.
    in_data.requires_grad = True
    for model in (m_ref, m_cpt):
        optim = SGD(model.parameters(), lr=0.1)
        if device == "cpu" and mixed_precision != "fp32":
            # Got: RuntimeError: "batch_norm"/"layer_norm" not implemented for 'Half'.
            with pytest.raises(RuntimeError):
                out = model(in_data)
            return
        else:
            # Everything else works.
            out = model(in_data)
        out.sum().backward()
        optim.step()

    if torch_version() >= (1, 6, 0):
        assert objects_are_equal(m_ref.state_dict(), m_cpt.state_dict())
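Note: get_model is a helper defined elsewhere in this test file. A minimal sketch consistent with how it is used above (the signature, layer sizes and norm selection are assumptions, not the real helper):

import torch
from torch.nn import BatchNorm2d, Conv2d, LayerNorm, Sequential
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper

def get_model(norm_type, checkpointed, mixed_precision):
    # The input above is (2, 2, 3, 3): batch=2, channels=2, spatial 3x3.
    norm = BatchNorm2d(2) if norm_type == "batch_norm" else LayerNorm([2, 3, 3])
    model = Sequential(Conv2d(2, 2, kernel_size=1), norm)
    if mixed_precision != "fp32":
        model = model.half()
    if checkpointed:
        # Wrap the stack so activations are recomputed during the backward pass.
        model = checkpoint_wrapper(model)
    return model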
Example #2
def test_single_run():
    if torch_version() < (1, 8, 0):
        pytest.skip("requires torch version >= 1.8.0")
    from fairscale.experimental.nn.auto_shard import shard_model

    model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
    sharded_model = shard_model(model)
    assert len(sharded_model) == 2, "Length of sharded model is incorrect."
    expected_param_nums = [5998600, 5785383]
    for i, model in enumerate(sharded_model):
        param_count = {}
        for name, module in model.named_modules():
            if "." in name:
                continue

            param_count[name] = sum([x.numel() for x in module.parameters()])
        assert expected_param_nums[i] == param_count[""]

    src_mask = torch.randn((35, 35), dtype=torch.float32)
    src = torch.randint(1, ntokens, (35, 20))
    input = [src, src_mask]
    for model in sharded_model:
        if type(input) == list:
            input = model(*input)
        else:
            input = model(input)

    assert input.size() == torch.Size([35, 20, 28783])
Example #3
def test_correctness(use_fp16, checkpoint_activation, num_microbatches,
                     use_auto_shard):
    if use_auto_shard and torch_version() < (1, 8, 0):
        pytest.skip("auto_shard requires torch version >= 1.8.0")

    if (use_fp16 or checkpoint_activation) and not hasattr(
            torch.cuda.amp, "custom_fwd"):
        pytest.skip(
            f"AMP APIs are not supported in torch version {torch.__version__}")

    if not checkpoint_activation and num_microbatches > 1:
        pytest.skip("We only support microbatches with activation offloading.")

    device, offload_device = _init()
    model = _get_model()
    if use_auto_shard:
        offload_model = shard_model(model)
    else:
        offload_model = model

    rmodel, ropt, rloss = _train_reg_model(model, device, offload_device)
    omodel, oopt, oloss = _train_offload_model(
        offload_model,
        device,
        offload_device,
        use_fp16=use_fp16,
        checkpoint_activation=checkpoint_activation,
        num_microbatches=num_microbatches,
    )
    _check_parity(rmodel.cpu(), omodel.cpu(), ropt, oopt, rloss, oloss)
Example #4
def test_smaller_than_world_size(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter in gloo backend")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    model = Sequential(
        Linear(3, 3, bias=False),
        Linear(3, 4, bias=False),
        Linear(4, 5, bias=False),
        Linear(5, 4, bias=False),
        Linear(4, 3, bias=False),
        Linear(3, 1, bias=False),
        Linear(1, 1, bias=False),  # param here is smaller than world_size if unflattened.
    )
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
        nprocs=world_size,
        join=True,
    )
Example #5
def setUp(self):
    if torch_version() < (1, 6, 0):
        raise unittest.SkipTest("Need pytorch version >= 1.6 due to lack of reduce_scatter")
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA not available, skipping test")
    if sys.platform == "win32":
        raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
    if torch.cuda.device_count() < 2:
        raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
Example #6
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref

    state_after = state_after[(precision, sync_bn)]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # When linear bias is True, DDP's AMP O1 and FSDP's default AMP O1.5 differ, so
    # we force FSDP to use AMP O1 here by setting compute_dtype to float32.
    if linear_bias:
        fsdp_config["compute_dtype"] = torch.float32

    if fsdp_config["mixed_precision"] and torch_cuda_version() < (11, 0):
        pytest.skip("Only CUDA 11 is supported with AMP equivalency")

    # Wrap BN half of the time.
    wrap_bn = True
    if random.randint(0, 1) == 0:
        wrap_bn = False
    # Except that we always wrap BN in mixed precision + sync_bn mode, regardless of
    # compute_dtype, due to an error related to sync_bn wrapping.
    if fsdp_config["mixed_precision"] and sync_bn != "none":
        wrap_bn = True

    # When BN is not wrapped (i.e. not in full precision), FSDP's compute_dtype needs to
    # be fp32 to match DDP (otherwise, numerical errors happen on BN's running_mean/running_var
    # buffers).
    if fsdp_config["mixed_precision"] and not wrap_bn:
        fsdp_config["compute_dtype"] = torch.float32

    world_size = _world_size
    mp.spawn(
        _distributed_worker,
        args=(
            world_size,
            fsdp_config,
            wrap_bn,
            None,
            temp_files[0],
            temp_files[1],
            state_before,
            inputs,
            None,
            state_after,
            sync_bn,
            conv_bias,
            linear_bias,
        ),
        nprocs=world_size,
        join=True,
    )
Example #7
def test_basic(device):
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    input = torch.rand(2, 16, 32).requires_grad_(True)
    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))

    # Check for correctness.
    for key in "loss", "gnorm":
        if not (no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key]):
            print(no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
            assert 0
        del no_cpt[key]
        del pyt_cpt[key]
        del fairscale_cpt[key]
        del fairscale_cpt_offload[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    assert no_cpt == {"mem_0": 38912, "mem_peak": mem_peaks[0], "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
    assert pyt_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[1],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, pyt_cpt
    assert fairscale_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[2],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt
    assert fairscale_cpt_offload == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[3],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt_offload
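Note: get_loss_and_gnorm is defined elsewhere in this test file. A rough sketch consistent with the dictionary keys asserted above (the exact gradient-norm definition and memory bookkeeping are assumptions):

import torch

def get_loss_and_gnorm(model, input):
    ret = {}
    use_cuda = input.is_cuda
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        ret["mem_0"] = torch.cuda.memory_allocated()
    loss = model(input).sum()
    if use_cuda:
        ret["mem_after_fwd"] = torch.cuda.memory_allocated()
    loss.backward()
    if use_cuda:
        ret["mem_after_bwd"] = torch.cuda.memory_allocated()
        ret["mem_peak"] = torch.cuda.max_memory_allocated()
    ret["loss"] = loss.item()
    # One possible gradient-norm definition; the real helper may differ.
    grads = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    ret["gnorm"] = torch.norm(torch.stack(grads)).item()
    return ret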
Example #8
def test_torch_version():
    assert torch_version("") == tuple()
    assert torch_version("bad format") == tuple()
    assert torch_version("1.9.0") == (1, 9, 0)
    assert torch_version("1.10.0a0+gitbc6fc3e") == (1, 10, 0)
    assert torch_version("1.7.0+cu102") == (1, 7, 0)
    assert torch_version("1.10.0a0+fb") == (1, 10, 0)
Example #9
def test_input_type(temp_files, fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    if torch_version() < (1, 7, 0):
        # This test runs multiple test cases in a single process. On 1.6.0 it
        # throws an error like this:
        #     RuntimeError: Container is already initialized! Cannot initialize it twice!
        pytest.skip(
            "older pytorch doesn't work well with single process dist_init multiple times"
        )

    result = dist_init(rank=0,
                       world_size=1,
                       filename=temp_files[0],
                       filename_rpc=temp_files[1])
    assert result, "Dist init failed"

    assert isinstance(fsdp_config, dict), str(fsdp_config)

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model(), **fsdp_config).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            assert input_cls is dict
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    model.assert_state(TrainingState.IDLE)

    teardown()
Example #10
def test_train_and_eval_with_checkpointing():
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    world_size = 2

    with temp_files_ctx(2) as (temp_file_name, unused):
        mp.spawn(
            _test_func,
            args=(world_size, temp_file_name, unused),
            nprocs=world_size,
            join=True,
        )
Example #11
def test_dynaimc_conditionals_auto_wrapped():
    if torch_version() < (1, 8, 0):
        pytest.skip("requires torch version >= 1.8.0")
    from fairscale.experimental.nn.auto_shard import shard_model

    features = 10

    model = BranchedNetwork(features)
    sharded_model = shard_model(model, 3)
    assert len(sharded_model) == 3

    input_ = torch.randn(3, features)
    model_output = model(input_)
    sharded_model_output = input_
    for shard in sharded_model:
        sharded_model_output = shard(sharded_model_output)
    assert torch.allclose(model_output, sharded_model_output)
Example #12
def check(exp, res):
    assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
    rtol = 1e-4
    atol = 1e-5
    if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
        # On CI, with longer model2, mixed precision and 1.9, even ddp vs. ddp+ckpt has
        # larger errors.
        rtol = 1e-3
        atol = 1e-4
    for key in exp.keys():
        exp_loss = exp[key]
        res_loss = res[key]
        torch.testing.assert_allclose(exp_loss, res_loss, rtol=rtol, atol=atol)
Example #13
def test1(precision, flatten):
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    fsdp_config = {}
    fsdp_config["mixed_precision"] = precision == "mixed"
    fsdp_config["flatten_parameters"] = flatten == "flatten"

    # Some bugs only show up when we are in world_size > 1 due to sharding changing
    # the tensor dimensions.
    world_size = 2
    mp.spawn(
        _test_func,
        args=(world_size, fsdp_config, temp_file_name, unused),
        nprocs=world_size,
        join=True,
    )
Example #14
def test_one_iteration(world_size, test_case, fsdp_config):
    """Test FSDP with uneven divide of parameter shards."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    if world_size > torch.cuda.device_count():
        pytest.skip("Not enough GPUs.")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    # TODO (Min): we may want to extend this to a simple 2 layer model so that it covers
    #             more cases in FSDP. Also, assert_ref_out can be extended to multiple
    #             iterations. This could be a good bootcamp task. I should file a github
    #             issue once we merge.
    model = Linear(3, 3, bias=False)
    mp.spawn(
        _test_func,
        args=(world_size, model, fsdp_config, temp_file_name, unused, test_case),
        nprocs=world_size,
        join=True,
    )
Example #15
def test_ddp_parity(
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    if torch_version() < (1, 8, 0):
        pytest.skip("pytorch version >= 1.8.0 required")
    if manual_reduction and change_train_graph:
        pytest.skip(
            "Skipping changing model and grad accumulation combination, makes little sense"
        )

    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL
    with temp_files_ctx(num=1) as temp_files:
        mp.spawn(
            run_ddp_parity,
            args=(
                world_size,
                backend,
                temp_files[0],
                reduce_buffer_size,
                grad_accumulation,
                change_train_graph,
                fp16_reduction,
                clip_grad_norm,
                amp,
                manual_reduction,
                multiple_fw,
            ),
            nprocs=world_size,
            join=True,
        )
Example #16
def test(world_size, precision, flatten):
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped again in FSDP but
    contains nested FSDP wrappers within the module.
    """
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    temp_file_name = tempfile.mkstemp()[1]
    unused = tempfile.mkstemp()[1]

    fsdp_config = {
        "mixed_precision": precision == "mixed",
        "flatten_parameters": flatten == "flatten",
    }

    mp.spawn(
        _test_func,
        args=(world_size, fsdp_config, temp_file_name, unused),
        nprocs=world_size,
        join=True,
    )
Example #17
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import functools
import tempfile

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.nn import MOELayer, Top2Gate
from fairscale.utils import torch_version

pytestmark = pytest.mark.skipif(
    not (torch.cuda.is_available() and torch_version() >= (1, 8, 0)),
    reason="cuda and torch>=1.8.0 required")

devices = ["cuda"]


def pg_worker(rank, world_size, init_file, func, *args):
    init_url = "file://" + init_file
    dist.init_process_group(backend=dist.Backend.NCCL,
                            rank=rank,
                            world_size=world_size,
                            init_method=init_url)
    torch.cuda.set_device(rank)
    dist.all_reduce(torch.zeros(1).cuda())
    func(*args)
    dist.destroy_process_group()
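A sketch of how pg_worker is typically driven from a test in this file (the wrapper name spawn_for_pg and the per-rank body are assumptions, not part of this snippet):

def spawn_for_pg(world_size, func, *args):
    # One process per rank; all ranks rendezvous through the same temporary file.
    init_file = tempfile.mkstemp()[1]
    mp.spawn(pg_worker, args=(world_size, init_file, func, *args),
             nprocs=world_size, join=True)

A test would then call something like spawn_for_pg(2, moe_body), where moe_body runs the MOELayer/Top2Gate assertions on each rank (moe_body is hypothetical here).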
Example #18
import sys
import unittest

from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed

from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel, OffloadConfig, TrainingState
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, spawn_for_all_world_sizes

# Note: We need the nightly version for SSD offload to work. Hence I am checking for the next PyTorch release.
print(f"torch version {torch_version()}")
pytestmark = pytest.mark.skipif(torch_version() < (1, 11, 0),
                                reason="requires torch version >= 1.11.0")

# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod


class DistributedTest(unittest.TestCase):
    def setUp(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest(
                "NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")
Example #19
def check_pytorch_version() -> None:
    if torch_version() < (1, 9, 0):
        raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
Example #20
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        super().__init__()

        if torch_version()[:2] >= (1, 8):
            warnings.warn(
                "fairscale.nn.Pipe has been upstreamed to PyTorch as torch.distributed.pipeline.sync.Pipe. "
                "It is now deprecated and will be removed in a future version of fairscale. "
                "The PyTorch API has minor changes. Please see https://pytorch.org/docs/stable/pipeline.html for details.",
                DeprecationWarning,
            )

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance("balance is required"))
        if chunks <= 0:
            raise ValueError("number of chunks must be positive integer")
        if checkpoint not in ["always", "except_last", "never"]:
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, chunks)

        if devices is None:
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]
        devices = cast(List[torch.device], devices)

        try:
            self.partitions, self.balance, self.devices = split_module(
                module, balance, devices)
        except BalanceError as exc:
            raise ValueError(recommend_auto_balance(str(exc)))

        verify_splitting(module, self.partitions, self.balance, self.devices)

        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)

        # Separate CUDA streams for copy.
        copy_streams = self._ensure_copy_streams()

        # The micro-batch index where the checkpointing stops.
        checkpoint_stop = {
            "always": self.chunks,
            "except_last": self.chunks - 1,
            "never": 0
        }[self.checkpoint]

        self.pipeline = Pipeline(self.partitions, self.devices, copy_streams,
                                 self._skip_layout, checkpoint_stop)
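A minimal usage sketch of this constructor, added for illustration and not taken from the snippet (it assumes at least two visible CUDA devices):

import torch
import torch.nn as nn
from fairscale.nn import Pipe  # deprecated upstream, as the warning above notes

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
# balance=[2, 1]: the first two layers form partition 0, the last layer partition 1.
pipe = Pipe(model, balance=[2, 1], chunks=4, checkpoint="except_last")
out = pipe(torch.randn(8, 16).to(pipe.devices[0]))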
Example #21
    def run(compute_cycles, all_gather_cycles):
        has_params = all_gather_cycles > 0
        model = _create_model(fsdp_config, compute_cycles, has_params)

        # Get the input and set its requires_grad to True because
        # we have a fake compute in the forward pass.
        batch = torch.rand(1).cuda()
        batch.requires_grad = True

        # We run 20 iterations but only collect timing data from the 10 smallest
        # data points because nondeterministic system events can disturb the timing.
        cpu_iter = Min10()
        cpu_wait = Min10()
        gpu_compute = Min10()
        gpu_total = Min10()
        for _ in range(20):
            # Get two events for measuring the overall time.
            e1 = Event(enable_timing=True)
            e2 = Event(enable_timing=True)

            cpu_start = time.process_time()

            all_gather_called = False
            all_gather_base_called = False

            def _delayed_all_gather(*args, **kwargs):
                nonlocal all_gather_called
                all_gather_called = True
                torch.cuda._sleep(all_gather_cycles)
                return orig_all_gather(*args, **kwargs)

            def _delayed_all_gather_base(*args, **kwargs):
                nonlocal all_gather_base_called
                all_gather_base_called = True
                torch.cuda._sleep(all_gather_cycles)
                assert orig_all_gather_base
                return orig_all_gather_base(*args, **kwargs)

            method_string_all_gather_base = "torch.distributed._all_gather_base"
            if hasattr(torch.distributed, "_all_gather_base") is False:
                # no such method, to make mock_all_gather_base 0 invocation, use an impossible name
                method_string_all_gather_base = "math.nan"
                pass
            # forward pass
            #
            # Even though both e1 & e2 are on the compute stream, since
            # compute depends on all_gather, e2-e1 includes all_gather time.
            e1.record()
            with patch("torch.distributed.all_gather", _delayed_all_gather):
                with patch(method_string_all_gather_base,
                           _delayed_all_gather_base):
                    out = model(batch)
                    if has_params and world_size > 1:
                        assert all_gather_called or all_gather_base_called
                    else:
                        assert not all_gather_called and not all_gather_base_called
            e2.record()

            # backward pass
            out.backward()
            if torch_version() >= (1, 7, 0):
                model.zero_grad(set_to_none=True)
            else:
                for p in model.parameters():
                    p.grad = None

            cpu_iter_time = time.process_time() - cpu_start

            # wait for gpu
            out.item()
            cpu_wait_for_gpu_time = time.process_time() - cpu_start - cpu_iter_time

            # get sum of the compute time
            times = []
            for mod in model.modules():
                if not isinstance(mod, Layer):
                    continue
                times.append(mod.get_time())

            # get gpu compute + all_gather time
            overall_gpu_time = e1.elapsed_time(e2)

            cpu_iter.add(cpu_iter_time)
            cpu_wait.add(cpu_wait_for_gpu_time)
            gpu_compute.add(sum(times))
            gpu_total.add(overall_gpu_time)

        del model
        return {
            "cpu_iter": cpu_iter.avg(),
            "cpu_wait": cpu_wait.avg(),
            "gpu_compute": gpu_compute.avg(),
            "gpu_total": gpu_total.avg(),
        }
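Min10 is a small helper defined elsewhere in this test; a sketch consistent with its use above (keep the 10 smallest samples and average them) might be:

class Min10:
    def __init__(self):
        self.data = []

    def add(self, new_data):
        # Keep only the 10 smallest observations to filter out system noise.
        self.data = sorted(self.data + [new_data])[:10]

    def avg(self):
        return sum(self.data) / len(self.data)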
Example #22
def run_ddp_parity(
    rank,
    world_size,
    backend,
    temp_file_name,
    reduce_buffer_size,
    grad_accumulation,
    change_train_graph,
    fp16_reduction,
    clip_grad_norm,
    amp,
    manual_reduction,
    multiple_fw,
):
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)
    NUMBER_BATCHS = 5

    # Test all combinations: AMP, Accumulate, Change train graph, reduce buckets
    print(
        f"{rank}: Checking configuration: accumulate {grad_accumulation}" +
        f" - change train graph {change_train_graph}" + f" - amp {amp}" +
        f" - manual reduction {manual_reduction}" +
        f" - buffers {reduce_buffer_size}" + f" - multiple FW {multiple_fw}",
        flush=True,
    )

    # The API should be exactly the same between the sharded and non-sharded variants, hence a generic closure.
    def closure(model,
                scaler,
                input_tensor,
                should_accumulate,
                _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1

        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()

            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb(multiple_fw)
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with a non-trainable parameter, so that we check that
    # the buckets are properly reassigned when/if this changes.
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(),
                            optim=torch.optim.SGD,
                            lr=1e-4,
                            momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(),
                                    lr=1e-4,
                                    momentum=0.99)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None,
                                     hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized between the ranks at construction time; check that.
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = _get_random_inputs(device)

        def ddp_closure(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor,
                           grad_accumulation)

        def sharded_closure(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        for _scaler, _closure, _optimizer in (
            (ddp_scaler, ddp_closure, ddp_optimizer),
            (sharded_scaler, sharded_closure, sharded_optimizer),
        ):
            if _scaler is not None:
                _ = _closure(input_tensor)
                _scaler.step(_optimizer)
                _scaler.update()
            else:
                _optimizer.step(_closure())

        check_same_model_params(sharded_ddp_model, ddp_model,
                                f"Rank: {rank} - Step {i} broke")

        # Check that the two grad norm are equivalent
        # NOTE: The grads can occasionally be NaNs, the scaler will skip the step in that case
        # This is not ShardedDDP specific. If the grads are not NaN for DDP then they should also
        # be valid for ShardedDDP
        # NOTE: DDP does not handle parameters trainability being changed after the fact, see
        # https://github.com/pytorch/pytorch/blob/5781aec74ef00284e0262817a649278c2e8072bf/torch/nn/parallel/distributed.py#L471
        if clip_grad_norm and not change_train_graph:
            if torch_version() >= (1, 9, 0):
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(),
                    0.3,
                    norm_type=2.0,
                    error_if_nonfinite=False)  # type: ignore
            else:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    ddp_model.parameters(), 0.3, norm_type=2.0)  # type: ignore
            if not torch.isnan(total_norm):
                oss_total_norm = sharded_optimizer.clip_grad_norm(
                    0.3, norm_type=2.0)
                allclose = torch.allclose(oss_total_norm,
                                          total_norm,
                                          atol=1e-2 if amp else 1e-8)

                if not allclose:
                    # Debug helper if this unit test does not pass, compare the gradients in between DDP and ShardedDDP
                    for idx, (p_ddp, p_sdp) in enumerate(
                            zip(ddp_model.parameters(),
                                sharded_ddp_model.parameters())):
                        if p_ddp.grad is not None:
                            if p_sdp.grad is not None:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      torch.norm(p_sdp.grad),
                                      flush=True)
                            else:
                                print(rank,
                                      idx,
                                      torch.norm(p_ddp.grad),
                                      "not owned",
                                      flush=True)

                assert (
                    allclose
                ), f"torch and fairscale should return the same grad norm\n {oss_total_norm} vs {total_norm}"
            else:
                print(rank, "NaN grad norm in DDP", flush=True)

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(
                sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(
                ddp_model.parameters()).requires_grad
            check_same_model_params(
                sharded_ddp_model, ddp_model,
                f"Rank: {rank} - Trainability refresh {i} broke")

    dist.destroy_process_group()
Example #23
import numpy as np
import pytest
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from torch.nn.parallel import DistributedDataParallel as DDP

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.utils import torch_version
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

if torch_version() >= (1, 8, 0):
    from fairscale.optim.grad_scaler import ShardedGradScaler
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarios
"""

_test_fp16_reduction = [False]

if hasattr(dist, "algorithms.ddp_com_hooks.default_hooks"):
    _test_fp16_reduction.append(True)

_test_amp = [False]
if hasattr(torch.cuda.amp, "autocast"):
    _test_amp.append(True)

EMB_SIZE = 32
Example #24
def dist_init(rank: int,
              world_size: int,
              filename: str,
              filename_rpc: str = "") -> bool:
    """
    Initialize torch distributed, based on a temporary file shared across ranks, which makes it possible for unrelated
    tests to be run concurrently.

    Returns False if there are not enough GPUs present in the system.

    .. warning: This limits the use case to all ranks being on the same node
    """

    try:
        torch.distributed.rpc.shutdown()
    except Exception:
        pass

    print(f"dist init r={rank}, world={world_size}")

    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    url = "file://" + filename
    url_rpc = "file://" + filename_rpc

    if torch_version() >= (1, 6, 0):
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        if backend == "nccl" and torch.cuda.device_count() < world_size:
            logging.warning(
                "Requested world size cannot be reached on this machine, not enough GPUs"
            )
            return False

        torch.distributed.init_process_group(backend=backend,
                                             rank=rank,
                                             world_size=world_size,
                                             init_method=url)

        tp_options = {"init_method": url_rpc}
        # Workaround for bug in torch v1.8.0. Should be fixed in v1.8.1
        if torch_version() == (1, 8, 0):
            if torch.cuda.is_available():
                # Workaround for https://github.com/pytorch/pytorch/issues/53844
                tp_options["_transports"] = ["ibv", "uv"]  # type: ignore
            else:
                # Workaround for https://github.com/pytorch/pytorch/issues/54266
                tp_options["_channels"] = [
                    "mpt_uv", "basic", "cuda_ipc", "cuda_gdr", "cuda_xth",
                    "cuda_basic"
                ]  # type: ignore

        rpc.init_rpc(
            f"Test{rank}",
            rank=rank,
            world_size=world_size,
            backend=rpc.BackendType.TENSORPIPE,
            rpc_backend_options=rpc.TensorPipeRpcBackendOptions(**tp_options),
        )

    else:
        if world_size > 1:
            # TensorPipe is not available in Torch 1.5
            rpc.init_rpc(
                name=f"Test{rank}",
                rank=rank,
                world_size=world_size,
                rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
                    init_method=url_rpc),
            )
        elif torch.cuda.is_available():
            torch.distributed.init_process_group(backend="nccl",
                                                 rank=rank,
                                                 world_size=world_size,
                                                 init_method=url)
        else:
            return False

    if torch.cuda.is_available() and torch.cuda.device_count():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    return True
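A sketch of how the tests on this page typically drive dist_init from spawned workers (the worker body is a placeholder; temp_files_ctx, teardown and mp.spawn are used the same way as in the examples above):

def _worker(rank, world_size, tempfile_name, unused):
    assert dist_init(rank, world_size, tempfile_name, unused), "Dist init failed"
    # ... per-rank assertions go here ...
    teardown()

def run_in_two_processes(world_size=2):
    with temp_files_ctx(2) as (tempfile_name, unused):
        mp.spawn(_worker, args=(world_size, tempfile_name, unused),
                 nprocs=world_size, join=True)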
Example #25
from typing import Any, Dict, List, NamedTuple, Tuple

import pytest
import torch
import torch.distributed.autograd as dist_autograd
from torch.distributed.nn import RemoteModule
from torch.distributed.optim import DistributedOptimizer
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import torch.nn as nn

from fairscale.experimental.nn.distributed_pipeline import DistributedLoss, DistributedPipeline, PipelineModulesGraph
from fairscale.utils import torch_version

pytestmark = pytest.mark.skipif(
    not torch.cuda.is_available() or torch_version() < (1, 9, 0),
    reason="CPU tests fail right now and all tests require torch version >= 1.9.0.",
)

CPU_DEVICES = ["worker0/cpu", "worker1/cpu"]
GPU_DEVICES = ["worker0/cuda:0", "worker1/cuda:1"]
if torch.cuda.is_available():
    DEVICES = [CPU_DEVICES, GPU_DEVICES]
else:
    DEVICES = [CPU_DEVICES]


def rpc_worker(rank, world_size, init_file, func, *args):
    options = rpc.TensorPipeRpcBackendOptions(init_method="file://" + init_file)
    for i in range(world_size):
        options.set_device_map("worker" + str(i), {rank: i})
Example #26
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type,
                                     bn_type):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None
    with_model2 = model_type == "model2"
    with_sync_bn = bn_type == "sync_bn"

    if torch_version() >= (1, 7, 0) and torch_version() < (1, 8, 0) and with_sync_bn:
        # SyncBN is buggy in 1.7, errors like:
        # E         File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
        # E           dtype=running_mean.dtype,
        # E       AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if with_sync_bn and not wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None
    # Ensure ddp == ddp+ckpt == fsdp == fsdp+ckpt.
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            if not with_fsdp and with_checkpoint:
                continue
            final_losses = _get_cached_results(
                world_size,
                with_model2,
                with_sync_bn,
                with_fsdp,
                with_checkpoint,
                mixed_precision,
                flatten,
                wrap_bn,
                fp32_reduce_scatter,
            )
            if expected_losses is None:
                expected_losses = final_losses
            else:
                print(
                    f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} with ddp+no_ckpt"
                )

                def check(exp, res):
                    assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
                    rtol = 1e-4
                    atol = 1e-5
                    if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
                        # On CI, with longer model2, mixed precision and 1.9, even ddp vs. ddp+ckpt has
                        # larger errors.
                        rtol = 1e-3
                        atol = 1e-4
                    for key in exp.keys():
                        exp_loss = exp[key]
                        res_loss = res[key]
                        torch.testing.assert_allclose(exp_loss,
                                                      res_loss,
                                                      rtol=rtol,
                                                      atol=atol)

                check(expected_losses, final_losses)
Example #27
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type,
                                     bn_type):
    mixed_precision = precision == "mixed"
    flatten = flatten == "flatten"
    wrap_bn = wrap_bn == "auto_wrap_bn"
    fp32_reduce_scatter = True if mixed_precision else None
    with_model2 = model_type == "model2"
    with_sync_bn = bn_type == "sync_bn"

    if torch_version() >= (1, 7, 0) and torch_version() < (1, 8, 0) and with_sync_bn:
        # SyncBN is buggy in 1.7, errors like:
        # E         File "/home/circleci/venv/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 13, in forward
        # E           dtype=running_mean.dtype,
        # E       AttributeError: 'NoneType' object has no attribute 'dtype'
        pytest.skip("SyncBatchNorm in 1.7 is buggy")

    if with_sync_bn and not wrap_bn:
        pytest.skip("SyncBatchNorm requires auto_wrap_bn")

    if torch_version() < (1, 8, 0) and flatten:
        # 1.6 and 1.7 throw this error:
        #   RuntimeError: Trying to backward through the graph a second time, but the saved
        #   intermediate results have already been freed. Specify retain_graph=True when calling
        #   backward the first time.
        pytest.skip("older pytorch throws error when flatten is used")

    world_size = 2
    expected_losses = None

    # Ensure ddp == fsdp when modules are called multiple times per forward pass with/without checkpointing, forward
    # counters and reducer bucketing.
    #
    # The bucketing check exists because the asynchronous gradient reduction it induces can interact with multiple
    # forward passes in complex ways. For example, in the midst of a sharded backward pass, `parameter.grad` may only be
    # `None` or an unsharded gradient tensor. The sharded tensor is then set at the end of the backwards pass. But a
    # unit test with bucketing enabled might not catch violations of this invariant. For very small models, like the
    # kind used in this unit test, bucketing will delay gradient reduction until after all the gradient computation is
    # done. If the reduction incorrectly sets `.grad` to the _sharded_ variant, the test might not fail, since the
    # gradient computations have already happened. Toggling bucketing helps verify that gradient reduction and
    # computation interact correctly.
    combinations = []
    for with_fsdp in [False, True]:
        for with_checkpoint in [False, True]:
            if not with_fsdp and with_checkpoint:
                continue
            for with_bucketing in [False, True]:
                if not with_fsdp and with_bucketing:
                    continue
                combinations.append(
                    (with_fsdp, with_checkpoint, with_bucketing))
    print("")
    print("Testing the following configurations:")
    for with_fsdp, with_checkpoint, with_bucketing in combinations:
        print(
            f"  fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing}"
        )

    for with_fsdp, with_checkpoint, with_bucketing in combinations:
        if with_bucketing:
            bucket_cap_mb = 25
        else:
            bucket_cap_mb = 0
        final_losses = _get_cached_results(
            world_size,
            with_model2,
            with_sync_bn,
            with_fsdp,
            with_checkpoint,
            mixed_precision,
            flatten,
            wrap_bn,
            fp32_reduce_scatter,
            bucket_cap_mb,
        )
        if expected_losses is None:
            expected_losses = final_losses
        else:
            print(
                f"checking: fsdp {with_fsdp} ckpt {with_checkpoint} bucketing {with_bucketing} with ddp+no_ckpt"
            )

            def check(exp, res):
                assert list(exp.keys()) == list(res.keys()), f"{list(exp.keys())} vs. {list(res.keys())}"
                rtol = 1e-4
                atol = 1e-5
                if with_model2 and mixed_precision and torch_version() >= (1, 9, 0):
                    # On CI, with longer model2, mixed precision and 1.9, even ddp vs. ddp+ckpt has
                    # larger errors.
                    rtol = 1e-3
                    atol = 1e-4
                for key in exp.keys():
                    exp_loss = exp[key]
                    res_loss = res[key]
                    torch.testing.assert_allclose(exp_loss,
                                                  res_loss,
                                                  rtol=rtol,
                                                  atol=atol)

            check(expected_losses, final_losses)
Example #28
def _distributed_worker(gpu_id, world_size, with_fsdp, with_checkpoint,
                        filename, filename_rpc, expected, model_hidden_dim,
                        fsdp_config):
    torch.cuda.set_device(gpu_id)

    rank = gpu_id
    result = dist_init(rank, world_size, filename, filename_rpc)
    assert result, "Dist init failed"

    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True

    # Note that FSDP auto-casts the input in AMP mode, so we don't need to call half() here.
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(with_fsdp, with_checkpoint, model_hidden_dim,
                         fsdp_config)
    model = model.cuda()
    if with_fsdp:
        model = to_fsdp(model, fsdp_config)
    else:
        model = DistributedDataParallel(model,
                                        device_ids=[gpu_id],
                                        bucket_cap_mb=500)

    # We enable momentum so that after the first iteration, the optimizer state is added
    # to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    # Set AMP context if needed.
    context = contextlib.suppress()
    if "mixed_precision" in fsdp_config and fsdp_config["mixed_precision"]:
        context = torch.cuda.amp.autocast(enabled=True)

    # We have observed that sometimes after the 3rd iteration, the 4th one can fail (not in this
    # test but in much bigger scale tests). We run 4 iterations here just in case it happens.
    iterations = 4

    results = {}  # results of memory stats
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        with context:
            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, not optimizer.zero_grad() to reclaim memory.
        if torch_version() >= (1, 7, 0):
            model.zero_grad(set_to_none=True)
        else:
            for p in model.parameters():
                p.grad = None
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    dump_all_tensors(gpu_id)
    print(results)

    def cmp(results, expected):
        ret = ""
        assert results.keys() == expected.keys(), f"{list(results.keys())} vs. {list(expected.keys())}"
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    assert not output, output

    teardown()
Example #29
    all_reduce_handle = dist.all_reduce(total_count,
                                        group=process_group,
                                        async_op=True)
    mean = torch.mean(input, dim=dim, keepdim=True)
    meansqr = torch.mean(input * input, dim=dim, keepdim=True)
    vec = torch.cat([mean, meansqr])
    all_reduce_handle.wait()
    vec = vec * (count / total_count)
    dist.all_reduce(vec, group=process_group)
    mean, meansqr = vec.chunk(2)
    var = meansqr - mean * mean
    invstd = torch.rsqrt(var + eps)
    return mean, var, invstd, total_count


if torch_version()[:2] >= (1, 7):
    _forward = torch.jit.script(_forward)  # type: ignore
    _track_running_stats = torch.jit.script(
        _track_running_stats)  # type: ignore


class _SyncBatchNormFunction(torch.autograd.Function):
    """
    An autograd function used to avoid storing activations for intermediate results.

    NOTE: Even though the mean and var are passed into this function, we do the entire
    backward, including mean and var, here. We have to calculate statistics outside
    this function in order to avoid multiple all_reduces when using checkpointing.
    """
    @staticmethod
    # type: ignore
Example #30
@torch_spawn([3])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ,
                    reason="mpi required")
def rpc_multiple_tensors():
    class FuseTwo(nn.Module):
        def forward(self, left, right):
            return left + right

    class SplitTwo(nn.Module):
        def forward(self, inputs):
            return (inputs, 2 * inputs)


@torch_spawn([2])
@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
# TODO(msb) Fix this
@pytest.mark.skipif(torch_version() >= (1, 8, 0),
                    reason="disabled for torch 1.8.0")
def construct_only_rank_zero():
    model = [nn.Linear(10, 10), nn.ReLU()]
    if torch.distributed.get_rank() == 0:
        PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
        rpc.shutdown()
    else:
        # Must enter the rpc loop to complete the PipeRPCWrapper constructor above
        rpc.shutdown()

        with pytest.raises(AssertionError):
            PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())