Ejemplo n.º 1
0
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    """Drive a ShardedDataParallel model whose parameters are split between two OSS optimizers.

    Every rank joins a file-based process group, builds a ``_DoubleInput`` model, gives all
    parameters but the last ten to a first OSS/SGD optimizer and the last ten to a second
    one, then runs five closure-driven steps on each optimizer.
    """
    init_url = "file://" + temp_file_name
    dist.init_process_group(init_method=init_url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    # Seed with the rank so that each rank draws different random inputs
    torch.manual_seed(rank)
    np.random.seed(rank)
    model = _DoubleInput().to(device)

    all_params = list(model.parameters())
    head_optimizer = OSS(params=all_params[:-10], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    tail_optimizer = OSS(params=all_params[-10:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    optimizers = (head_optimizer, tail_optimizer)
    sharded_model = ShardedDataParallel(model, [head_optimizer, tail_optimizer])

    def closure():
        # The same random batch is fed to both inputs of the model
        batch = torch.rand((64, 2)).to(device)
        loss = sharded_model(batch, batch).abs().sum()
        loss.backward()
        return loss

    for _step in range(5):
        for opt in optimizers:
            opt.zero_grad()
        for opt in optimizers:
            opt.step(closure=closure)

    dist.destroy_process_group()
Ejemplo n.º 2
0
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name):
    """Check that ShardedDataParallel driven by two OSS optimizers matches DDP driven by two SGDs.

    The model parameters are split into two halves, each owned by its own optimizer on
    both the sharded side and the reference (DDP) side. After every step the parameters
    of *both* halves, as well as every buffer, are compared between the two models.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    n_half_params = len(list(model.parameters())) // 2

    sharded_optimizer = OSS(
        params=list(model.parameters())[:n_half_params], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )
    sharded_optimizer_2 = OSS(
        params=list(model.parameters())[n_half_params:], optim=torch.optim.SGD, lr=1e-3, momentum=0.99
    )

    # Hand *both* optimizers to ShardedDataParallel, as the other two-optimizer tests do,
    # so that the second half of the parameters is covered by the sharded wrapper as well
    sharded_ddp_model = ShardedDataParallel(
        module=model, sharded_optimizer=[sharded_optimizer, sharded_optimizer_2], broadcast_buffers=True
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(list(ddp_model_single.parameters())[:n_half_params], lr=1e-3, momentum=0.99)
    ddp_optimizer_2 = torch.optim.SGD(list(ddp_model_single.parameters())[n_half_params:], lr=1e-3, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    # Each sharded optimizer is paired with the DDP optimizer owning the same half
    optimizer_pairs = ((sharded_optimizer, ddp_optimizer), (sharded_optimizer_2, ddp_optimizer_2))

    def check_same_model_params():
        # Compare both halves of the parameters, not only the first optimizer's
        for sharded_optim, reference_optim in optimizer_pairs:
            for pg, ddp_pg in zip(sharded_optim.param_groups, reference_optim.param_groups):
                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                    assert torch.allclose(
                        p, ddp_p, atol=1e-3
                    ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"
        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(b, ddp_b, atol=1e-3), "Model buffers differ in between DDP and ShardedDDP"

    check_same_model_params()  # The models should stay the same in between the ranks

    for _ in range(20):
        input_tensor = torch.rand((64, 2)).to(device)

        # Run DDP
        ddp_optimizer.zero_grad()
        ddp_optimizer_2.zero_grad()
        ddp_loss = ddp_model(input_tensor).abs().sum()
        ddp_loss.backward()
        ddp_optimizer.step()
        ddp_optimizer_2.step()

        # Run Sharded
        sharded_optimizer.zero_grad()
        sharded_optimizer_2.zero_grad()
        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
        sharded_loss.backward()
        sharded_optimizer.step()
        sharded_optimizer_2.step()
        check_same_model_params()

    dist.destroy_process_group()
Ejemplo n.º 3
0
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
    """Train a two-input model through ShardedDataParallel with parameters split over two OSS optimizers.

    The model applies one shared MLP to both inputs and concatenates the outputs. All
    parameters but the last ten go to a first OSS/SGD optimizer and the last ten to a
    second one; each optimizer then takes five closure-driven steps.
    """
    dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    class _DoubleInput(torch.nn.Module):
        """Run the same MLP on two inputs and concatenate the results along dim 1."""

        def __init__(self):
            super().__init__()
            layers = [Linear(2, 3)] + [Linear(3, 3) for _ in range(5)]
            self.mlp = Sequential(*layers)

        def forward(self, x, y):
            return torch.cat((self.mlp(x), self.mlp(y)), dim=1)

    model = _DoubleInput().to(device)

    all_params = list(model.parameters())
    sgd_settings = dict(optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    head_optimizer = OSS(params=all_params[:-10], **sgd_settings)
    tail_optimizer = OSS(params=all_params[-10:], **sgd_settings)
    optimizers = (head_optimizer, tail_optimizer)
    sharded_model = ShardedDataParallel(model, [head_optimizer, tail_optimizer])

    def closure():
        batch = torch.rand((64, 2)).to(device)
        loss = sharded_model(batch, batch).abs().sum()
        loss.backward()
        return loss

    for _step in range(5):
        for opt in optimizers:
            opt.zero_grad()
        for opt in optimizers:
            opt.step(closure=closure)

    dist.destroy_process_group()
Ejemplo n.º 4
0
def run_one_step(rank, world_size, backend, device, temp_file_name):
    """Run one forward/backward/step cycle through ``OssDdp`` on this rank.

    ``ddp.reduce()`` is called explicitly between the backward pass and
    ``optimizer.step()``. The process group is torn down at the end, as in every other
    worker in this file.
    """
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, world_size)
    input_tensor = torch.rand((64, 2)).to(device)
    output = ddp(input_tensor).sum()
    output.backward()
    ddp.reduce()
    optimizer.step()

    # Previously missing: release the process group so the worker does not exit with it still alive
    dist.destroy_process_group()
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name,
                        reduce_buffer_size):
    """Train a two-input model wrapped in ShardedDataParallel for five closure-driven steps.

    Args:
        device: either a device string (``"cpu"``/``"cuda"``) or a ``torch.device``.
        reduce_buffer_size: forwarded to ``ShardedDataParallel``.
    """
    dist.init_process_group(init_method="file://" + temp_file_name,
                            backend=backend,
                            rank=rank,
                            world_size=world_size)
    # Normalise ``device`` before testing for CUDA so that both "cuda" and
    # torch.device("cuda") are recognised (the sibling workers compare against a
    # torch.device, this one used to compare against a plain string)
    if torch.device(device).type == "cuda":
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    model = _DoubleInput().to(device)
    optimizer = OSS(params=model.parameters(),
                    optim=torch.optim.SGD,
                    lr=1e-3,
                    momentum=0.99)
    ddp_model = ShardedDataParallel(model,
                                    optimizer,
                                    reduce_buffer_size=reduce_buffer_size)

    # Optim loop
    def closure():
        optimizer.zero_grad()
        input_tensor = torch.rand((64, 2)).to(device)
        loss = ddp_model(input_tensor, input_tensor).abs().sum()
        loss.backward()
        return loss

    for _ in range(5):
        _ = optimizer.step(closure=closure)

    dist.destroy_process_group()
Ejemplo n.º 6
0
def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
    """Train a GPT2 model wrapped in ShardedDataParallel for a handful of steps.

    Integer token batches are generated in the closure; running several steps exercises
    the gradient bucketing of the sharded wrapper.
    """
    input_dim = 32
    batch_size = 10
    n_steps = 10

    dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=512,
        num_heads=2,
        num_layers=24,
        num_positions=input_dim * input_dim,
        num_vocab=512,
        num_classes=2,
    ).to(device)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
    sharded_model = ShardedDataParallel(model, optimizer)

    def closure():
        optimizer.zero_grad()
        # Force int inputs to prevent the first grad from firing
        tokens = torch.randint(10, (batch_size, input_dim)).to(device)
        loss = sharded_model(tokens).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for _ in range(n_steps):
        optimizer.step(closure=closure)

    dist.destroy_process_group()
Ejemplo n.º 7
0
def run_test_gpt2(rank, world_size, backend, device, temp_file_name,
                  reduce_buffer_size):
    """Stress ShardedDataParallel with a GPT2 model that is moved and re-cast after wrapping.

    The model is built and wrapped before being moved to ``device``. During training,
    ``zero_grad`` alternates ``set_to_none`` between True and False, and after every
    step the wrapper is cast to float16 and back to float32.
    """
    input_dim, batch_size, n_steps = 16, 10, 10

    dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)
    model = GPT2(
        embed_dim=256,
        num_heads=2,
        num_layers=12,
        num_positions=input_dim * input_dim,
        num_vocab=512,
        num_classes=2,
    )
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=reduce_buffer_size)

    # Move the model to another device post-construction
    model = model.to(device)

    # One-slot holder for the set_to_none flag, flipped after every zero_grad call
    next_set_to_none = [True]

    def closure():
        ddp_model.zero_grad(set_to_none=next_set_to_none[0])
        next_set_to_none[0] = not next_set_to_none[0]

        # Force int inputs to prevent the first grad from firing
        tokens = torch.randint(10, (batch_size, input_dim)).to(device)
        loss = ddp_model(tokens).abs().sum()
        loss.backward()
        return loss

    # Check for bucketing overflows
    for _ in range(n_steps):
        optimizer.step(closure=closure)

        # Stress test the .to() method
        for dtype in (torch.float16, torch.float32):
            ddp_model.to(device=device, dtype=dtype)

    dist.destroy_process_group()
def run_one_step(
    rank,
    world_size,
    backend,
    device,
    temp_file_name,
    broadcast_buffers,
    grad_accumulation,
    reduce_buffer_size,
):
    """Run five OSS steps through ShardedDataParallel and check cross-rank consistency.

    Args:
        broadcast_buffers: forwarded to ``ShardedDataParallel``; also selects whether
            buffers are part of the cross-rank check.
        grad_accumulation: when True, forward/backward run under ``ddp_model.no_sync()``.
        reduce_buffer_size: forwarded to ``ShardedDataParallel``.
    """
    dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    if device == torch.device("cuda"):
        torch.cuda.set_device(rank)

    torch.manual_seed(rank)
    np.random.seed(rank)

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Test non-trainable parameters: freeze the very first one
    first_param = next(model.parameters())
    first_param.requires_grad = False

    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(
        model,
        optimizer,
        broadcast_buffers=broadcast_buffers,
        reduce_buffer_size=reduce_buffer_size,
    )

    def assert_ranks_consistent(params_should_be_equal):
        check_same_models_across_ranks(
            ddp_model,
            dist.group.WORLD,
            params_should_be_equal=params_should_be_equal,
            check_broadcast_buffers=broadcast_buffers,
        )

    # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
    assert_ranks_consistent(True)

    def closure():
        optimizer.zero_grad()

        sync_context = ddp_model.no_sync() if grad_accumulation else suppress()
        with sync_context:
            batch = torch.rand((64, 2)).to(device)
            loss = ddp_model(batch).abs().sum()
            loss.backward()
        return loss

    # when running on cpu/gloo the "nodes" are not really different
    same_params = device == torch.device("cpu") or grad_accumulation

    # The models should stay the same in between the ranks
    for _ in range(5):
        optimizer.step(closure=closure)
        assert_ranks_consistent(same_params)

    dist.destroy_process_group()
Ejemplo n.º 9
0
    def check(broadcast_buffers: bool,
              grad_accumulation: bool = False) -> None:
        """Train a small MLP through ShardedDataParallel and check that the ranks agree.

        ``rank``, ``world_size`` and ``device`` are read from the enclosing scope.

        Args:
            broadcast_buffers: forwarded to ``ShardedDataParallel``; also enables the
                buffer comparison inside ``check_same_model_params``.
            grad_accumulation: when True, the closure's forward/backward runs under
                ``ddp_model.no_sync()``.
        """
        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3),
                           Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        next(model.parameters()
             ).requires_grad = False  # Test non-trainable parameters

        optimizer = OSS(params=model.parameters(),
                        optim=torch.optim.SGD,
                        lr=0.01,
                        momentum=0.99)
        ddp_model = ShardedDataParallel(model,
                                        optimizer,
                                        broadcast_buffers=broadcast_buffers)

        def check_same_model_params(same_params: bool):
            """Gather each parameter (and buffer, if broadcast) on rank 0 and compare it across ranks.

            Args:
                same_params: if True, every rank's copy must equal rank 0's copy; if False,
                    every rank's copy must differ from rank 0's copy.
            """
            # The parameter check runs whether or not broadcast_buffers is set; the buffer
            # check further down only runs when broadcast_buffers is set.
            receptacle: List[torch.Tensor] = []

            # NOTE(review): the whole check is skipped on the nccl backend, presumably
            # because dist.gather was not available there — confirm.
            if dist.get_backend() != "nccl":
                for pg in optimizer.param_groups:
                    for p in pg["params"]:
                        # Check the params: only rank 0 provides a gather list, one slot per rank
                        receptacle = [p.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(p, receptacle, dst=0)
                        if rank == 0:
                            for sync_p in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Models differ in between ranks"
                                else:
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_p)
                                    ), "Gradients should not have been synced"

                # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
                if broadcast_buffers:
                    for b in ddp_model.buffers():
                        receptacle = [b.clone() for _ in range(world_size)
                                      ] if rank == 0 else []
                        dist.gather(b, receptacle, dst=0)
                        if rank == 0:
                            for sync_b in receptacle[1:]:
                                if same_params:
                                    assert torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Models differ in between ranks"
                                else:
                                    # NOTE(review): this branch expects buffers to differ, yet the
                                    # assertion below requires every rank's buffer to be 0.0 — with
                                    # world_size > 1 both cannot hold. Presumably same_params is never
                                    # False when this code runs (e.g. only on nccl); confirm.
                                    assert not torch.all(
                                        torch.eq(receptacle[0], sync_b)
                                    ), "Gradients should not have been synced"

                        # Rank 0 registered ones * 0 as its buffer, so after broadcasting every rank holds 0
                        assert b.cpu().item() == 0.0

        # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
        check_same_model_params(same_params=True)

        # Optim loop
        def closure():
            optimizer.zero_grad()

            # With grad_accumulation, forward/backward run inside ddp_model.no_sync()
            with ddp_model.no_sync() if grad_accumulation else suppress():
                input_tensor = torch.rand((64, 2)).to(device)
                loss = ddp_model(input_tensor).abs().sum()
                loss.backward()
            return loss

        # The models should stay the same in between the ranks
        for i in range(5):
            _ = optimizer.step(closure=closure)
            # when running on cpu/gloo the "nodes" are not really different
            same_params = device == torch.device("cpu") or grad_accumulation
            check_same_model_params(same_params=same_params)
Ejemplo n.º 10
0
def run_ddp_parity_two_optim(rank, world_size, backend, temp_file_name,
                             reduce_buffer_size):
    """Check that ShardedDataParallel with two OSS optimizers tracks DDP with two SGD optimizers.

    The parameters are split into halves, each owned by its own optimizer on both sides.
    Parity is checked at start-up and after each of the training batches.

    Args:
        reduce_buffer_size: forwarded to ``ShardedDataParallel`` and quoted in failure messages.
    """
    dist.init_process_group(
        init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size
    )
    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    n_batches = 20

    # Any model works. Add one different buffer per rank
    model = _get_mlp_emb()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)
    params = list(model.parameters())
    n_half_params = len(params) // 2
    optim_settings = {"lr": 1e-3, "momentum": 0.99}

    sharded_optimizer = OSS(params=params[:n_half_params], optim=torch.optim.SGD, **optim_settings)
    sharded_optimizer_2 = OSS(params=params[n_half_params:], optim=torch.optim.SGD, **optim_settings)
    sharded_optimizers = (sharded_optimizer, sharded_optimizer_2)

    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=[sharded_optimizer, sharded_optimizer_2],
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_params = list(ddp_model_single.parameters())
    ddp_optimizer = torch.optim.SGD(ddp_params[:n_half_params], **optim_settings)
    ddp_optimizer_2 = torch.optim.SGD(ddp_params[n_half_params:], **optim_settings)
    ddp_optimizers = (ddp_optimizer, ddp_optimizer_2)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

    check_same_model_params(
        sharded_ddp_model,
        ddp_model,
        f"DDP parity two optim test failing. differing at startup, Buffers {reduce_buffer_size}",
    )

    for step in range(n_batches):
        input_tensor = _get_random_inputs(device)

        # Run DDP
        for opt in ddp_optimizers:
            opt.zero_grad()
        ddp_model(input_tensor).abs().sum().backward()
        for opt in ddp_optimizers:
            opt.step()
        torch.cuda.synchronize(device)

        # Run Sharded
        for opt in sharded_optimizers:
            opt.zero_grad()
        sharded_ddp_model(input_tensor).abs().sum().backward()
        for opt in sharded_optimizers:
            opt.step()
        torch.cuda.synchronize(device)

        check_same_model_params(
            sharded_ddp_model,
            ddp_model,
            f"DDP parity two optim test failing, step {step}, buffers {reduce_buffer_size}",
        )

    dist.destroy_process_group()
Ejemplo n.º 11
0
    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
        """Compare ShardedDataParallel + OSS against DDP + SGD over ``NUMBER_BATCHS`` steps.

        ``rank``, ``device``, ``INPUTS``, ``BATCH_SIZE`` and ``NUMBER_BATCHS`` are read from the
        enclosing scope.

        Args:
            amp: run the forward pass under autocast and step through a grad scaler.
            accumulate: accumulate gradients over 3 micro-steps, the first two inside ``no_sync()``.
            change_train_graph: after the first step, flip ``requires_grad`` of the first
                parameter on both models and re-check parity.
        """

        # The API should be the exact same in between the sharded and non-sharded variants, generic closure
        def closure(model, scaler, input_tensor, should_accumulate):
            accumulate_steps = 3 if should_accumulate else 1

            model.zero_grad()

            def step():
                if scaler is not None:
                    # Autocast covers only the forward pass and the loss; the backward pass
                    # is kept outside of it, as the torch.cuda.amp documentation recommends
                    with torch.cuda.amp.autocast():
                        loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
                else:
                    loss = model(input_tensor).abs().sum()
                    loss.backward()
                return loss

            with model.no_sync() if should_accumulate else suppress():
                for _ in range(accumulate_steps - 1):
                    step()

            # The last micro-step synchronizes the gradients; its loss is the closure's result
            return step()

        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        # Make sure that the model starts with non-trainable, so that we check for the buckets to be
        # properly reassigned when/if this changes
        next(model.parameters()).requires_grad = False

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-5, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-5, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params(sharded_ddp_model, ddp_model)

        # Typical training loop, check that we get the exact same results as DDP
        for i in range(NUMBER_BATCHS):
            input_tensor = torch.rand((BATCH_SIZE, INPUTS)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

            def closure_sharded(input_tensor=input_tensor):
                return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

            # Flip the trainability of the first parameter back and forth
            if i == 0 and change_train_graph:
                next(sharded_ddp_model.parameters()).requires_grad = not next(
                    sharded_ddp_model.parameters()
                ).requires_grad
                next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
                check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Trainability refresh {i} broke")
Ejemplo n.º 12
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Train the benchmark problem on one rank, report throughput/memory, optionally check regressions.

    Args:
        rank: this process' rank; also the CUDA device index unless ``args.cpu`` is set.
        args: parsed CLI arguments. Read here: world_size, batch_size, epochs, cpu, amp,
            debug, profile, torchvision_model, reference_speed, reference_memory and
            reference_loss.
        backend: torch.distributed backend handed to ``dist_init``.
        optim_type: selects ShardedDDP + OSS, DDP + OSS, or DDP + plain ``OPTIM``.
        check_regression: on rank 0, assert the measured speed, peak memory and final loss
            against the ``args.reference_*`` values.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)

    # Choose the scaler from the optimizer type this run actually uses (the ``optim_type``
    # argument). Reading ``args.optim_type`` instead could pick the wrong scaler whenever
    # the caller runs several optimizer types from the same ``args``.
    scaler = None
    if args.amp:
        scaler = TorchGradScaler() if optim_type == OptimType.vanilla else ShardedGradScaler()

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids, find_unused_parameters=False)  # type: ignore
        if optim_type == OptimType.oss_ddp:
            optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        else:
            optimizer = OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
    optimizer = cast(torch.optim.Optimizer, optimizer)

    def run_step(step_closure):
        """Perform one optimization step with ``step_closure`` and return the loss it produced.

        AMP's ``scaler.step`` does not support closures, so when a scaler is active the
        closure is called explicitly before stepping.
        """
        if scaler is not None:
            loss = step_closure(grad_scaler=scaler)
            scaler.step(optimizer)
            scaler.update()
            return loss
        return optimizer.step(step_closure)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision. Only the forward
                    # pass and the loss run under autocast; backward stays outside of it, as
                    # recommended by the torch.cuda.amp documentation.
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                    # Accumulates scaled gradients.
                    grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_step(closure)

                prof.export_chrome_trace(f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once
            else:
                final_loss = run_step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2**20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the median and median of absolute differences img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = sorted(abs(x - median) for x in measurements)
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median + 3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
Ejemplo n.º 13
0
    def check_parity(amp: bool):
        """Compare ShardedDataParallel + OSS against DDP + SGD for ten training steps.

        ``rank`` and ``device`` are read from the enclosing scope.

        Args:
            amp: when True, run the forward pass under ``torch.cuda.amp.autocast`` and step
                through ``TorchGradScaler`` (DDP) / ``ShardedGradScaler`` (sharded).
        """
        # Any model works. Add one different buffer per rank
        model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
        sharded_ddp_model = ShardedDataParallel(
            module=model, sharded_optimizer=sharded_optimizer, broadcast_buffers=True
        )

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-3, momentum=0.99)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

        ddp_scaler = TorchGradScaler() if amp else None
        sharded_ddp_scaler = ShardedGradScaler() if amp else None

        def check_same_model_params():
            for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                    assert torch.allclose(
                        p, ddp_p, atol=1e-3
                    ), f"Model parameters differ in between DDP and ShardedDDP {p} {ddp_p}"

            for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                assert torch.allclose(
                    b, ddp_b, atol=1e-3
                ), f"Model buffers differ in between DDP and ShardedDDP. AMP {amp}"

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params()

        # The models should stay the same in between the ranks
        for i in range(10):
            input_tensor = torch.rand((64, 2)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                ddp_optimizer.zero_grad()

                if ddp_scaler is not None:
                    # Autocast covers only the forward pass and loss; backward runs outside
                    # of it, as the torch.cuda.amp documentation recommends
                    with torch.cuda.amp.autocast():
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_scaler.scale(ddp_loss).backward()
                else:
                    ddp_loss = ddp_model(input_tensor).abs().sum()
                    ddp_loss.backward()
                return ddp_loss

            def closure_sharded(input_tensor=input_tensor):
                sharded_optimizer.zero_grad()

                if sharded_ddp_scaler is not None:
                    # Same autocast scoping as in closure_ddp: forward inside, backward outside
                    with torch.cuda.amp.autocast():
                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_ddp_scaler.scale(sharded_loss).backward()
                else:
                    sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                    sharded_loss.backward()
                return sharded_loss

            # Step/scale both
            if ddp_scaler is not None:
                _ = closure_ddp(input_tensor)
                ddp_scaler.step(ddp_optimizer)
                ddp_scaler.update()
            else:
                ddp_optimizer.step(closure=closure_ddp)

            if sharded_ddp_scaler is not None:
                _ = closure_sharded(input_tensor)
                sharded_ddp_scaler.step(sharded_optimizer)
                sharded_ddp_scaler.update()
            else:
                sharded_optimizer.step(closure=closure_sharded)

            check_same_model_params()