Example #1
    def test_step(self):
        """ Check that the ZeroRedundancyOptimizer wrapper properly exposes the `.step()` interface"""
        if self.rank > 1 or (BACKEND == dist.Backend.NCCL
                             and torch.cuda.device_count() < 2):
            return

        self.dist_init(self.rank, world_size=2)

        context = (suppress() if not torch.cuda.is_available()
                   else torch.cuda.device(self.rank))

        with context:
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            m = torch.nn.Linear(1, 1)
            m.weight.data = torch.tensor([[1.0]])
            m.bias.data = torch.tensor([2.0])
            m.to(self.device)

            o = ZeroRedundancyOptimizer(m.parameters(), optim=SGD, lr=0.1)
            y = m(x)
            y.backward(x)
            for p in m.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o.step()
            self.assertEqual(m.weight,
                             torch.tensor([[0.75]], device=self.device))
            self.assertEqual(m.bias, torch.tensor([1.85], device=self.device))
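The test above averages gradients by hand with all_reduce before calling step(). Outside a test harness that work is normally left to DistributedDataParallel, while ZeroRedundancyOptimizer only shards the optimizer state. A minimal sketch of that pairing, assuming the process group is already initialized and that rank indexes a local GPU; the demo function and its arguments are illustrative, and it uses the newer optimizer_class keyword rather than the older optim spelling seen above:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer

def demo(rank: int, model: torch.nn.Module, batch: torch.Tensor) -> None:
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    optimizer = ZeroRedundancyOptimizer(
        ddp_model.parameters(), optimizer_class=torch.optim.SGD, lr=0.1
    )
    loss = ddp_model(batch.to(rank)).sum()
    loss.backward()   # DDP all-reduces and averages gradients across ranks
    optimizer.step()  # each rank updates its own shard, then parameters are synced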
Example #2
    def test_state_dict(self):
        """Check that the ZeroRedundancyOptimizer exposes the expected state dict interface,
        irrespective of the sharding.
        """
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x],
                                    optimizer_class=SGD,
                                    lr=0.1,
                                    momentum=0.9)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        o.zero_grad()
        o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
        state_dict = o.state_dict()

        # Check that the state dict is PyTorch-compliant key-wise
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the pulled state is what we expect, and that we have all the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the pulled state and the .param_groups attribute are in sync
        for k in state_dict["param_groups"][0].keys():
            if k != "params":
                self.assertEqual(state_dict["param_groups"][0][k],
                                 o.param_groups[0][k])

        # Check that it's correctly loaded
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o.load_state_dict(state_dict)

        # Check that state is correct and on proper device
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        # We should now be using a lr of 0.1, both within the optimizer
        # and as exposed by the .param_groups attribute
        assert o.param_groups[0]["lr"] == 0.1
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.9], device=DEVICE))

        # Check that the exposed param_groups are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)

    def test_step_without_closure(self):
        """Check that the step() method (without closure) is handled as expected"""
        self.dist_init(self.rank)

        class SGDWithoutClosure(torch.optim.SGD):
            def step(self):
                return super().step()

        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithoutClosure, lr=0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))

    def test_step_with_kwargs(self):
        """ Check that the `step(**kwargs)` interface is properly exposed"""
        self.dist_init(self.rank)

        class SGDWithStepKWArg(torch.optim.SGD):
            def step(self, closure=None, kwarg=None):
                super().step()
                kwarg.append(5)

        kwarg: List[Any] = []
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithStepKWArg, lr=0.1)
        x.backward()
        o.step(0, kwarg=kwarg)
        self.assertEqual(kwarg, [5])
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))

    def test_implicit_local_state_dict(self):
        """Check that it's possible to pull a local state dict
        .. warning: probably deprecated in the near future
        """
        self.dist_init(self.rank)

        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.1)
        local_state_dict = o.state_dict()
        o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.01)
        o.load_state_dict(local_state_dict)
        # We should now be using a lr of 0.1.
        self.assertEqual(o.optim.param_groups[0]["lr"], 0.1)
        self.assertEqual(o.param_groups[0]["lr"], 0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
Example #6
    def test_step_with_extra_inner_key(self):
        """Check that an optimizer adding extra keys to the param_groups
        is properly handled, in that the new key is exposed to the user
        """
        self.dist_init(self.rank)

        class SGDWithNewKey(torch.optim.SGD):
            # Dummy optimizer which adds a new key to the param groups
            def step(self, closure=None):
                super().step()
                self.param_groups[0]["new_key"] = 0.1

        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGDWithNewKey, lr=0.1)
        x.backward()
        o.step()
        self.assertEqual(o.param_groups[0]["new_key"], 0.1)
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
Example #7
    def test_step(self):
        """ Check that the ZeroRedundancyOptimizer wrapper properly exposes the `.step()` interface"""

        if self.rank >= self.world_size or (torch.cuda.is_available()
                                            and torch.cuda.device_count() < 2):
            return

        self.dist_init(self.rank, world_size=self.world_size)

        context = (suppress() if not torch.cuda.is_available()
                   else torch.cuda.device(self.rank))

        with context:
            x = torch.tensor([float(self.rank + 1)], device=self.device)
            m = torch.nn.Linear(1, 1)
            m.weight.data = torch.tensor([[1.0]])
            m.bias.data = torch.tensor([2.0])
            m_zero = copy.deepcopy(m)
            m.to(self.device)
            m_zero.to(self.device)

            lr = 0.1
            o = SGD(m.parameters(), lr=lr)
            o_zero = ZeroRedundancyOptimizer(m_zero.parameters(),
                                             optimizer_class=SGD,
                                             lr=lr)

            y = m(x)
            y.backward(x)
            y_zero = m_zero(x)
            y_zero.backward(x)

            for p in m.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o.step()
            for p in m_zero.parameters():
                dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                p.grad.data /= self.world_size
            o_zero.step()

            self.assertEqual(m.weight, m_zero.weight)
            self.assertEqual(m.bias, m_zero.bias)
Example #8
    def test_collect_shards(self):
        """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
        self.dist_init(self.rank)
        RECIPIENT_RANK = 0

        # Run a dummy step so that the optimizer state dict exists
        batch, input_width, hidden, target_width = 3, 20, 10, 5
        target = torch.rand((batch, target_width), device=self.device)
        inputs = torch.rand((batch, input_width), device=self.device)

        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                    torch.nn.Linear(hidden, target_width))
        model.to(self.device)

        loss_fn = torch.nn.L1Loss()
        loss_fn.to(self.device)

        # With SGD, Momentum is required to get a state to shard
        optimizer = ZeroRedundancyOptimizer(model.parameters(),
                                            optimizer_class=SGD,
                                            lr=0.1,
                                            momentum=0.99)

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        _ = optimizer.step(closure=closure)

        # Update the optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

        # Fetch the state on the reference rank
        # - check that it has the correct size
        # - load it again
        if self.rank == RECIPIENT_RANK:
            optimizer_state_dict = optimizer.state_dict()
            self.assertEqual(len(optimizer_state_dict["state"]),
                             len(list(model.parameters())))
        else:
            optimizer_state_dict = {}

        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=RECIPIENT_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )

        # Load the optimizer state dict, check that no exception is raised
        optimizer.load_state_dict(optimizer_state_dict)
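A minimal checkpointing sketch based on the consolidation pattern above; the helper name and path are illustrative and not part of the original test. consolidate_state_dict() is a collective call that must run on every rank, and only the recipient rank ends up with the full state dict:

import torch

CHECKPOINT = "zero_optimizer.pt"  # illustrative path

def save_zero_optimizer_state(optimizer, rank: int, recipient_rank: int = 0) -> None:
    # Collective call: every rank ships its shard to the recipient rank
    optimizer.consolidate_state_dict(to=recipient_rank)
    if rank == recipient_rank:
        # Only the recipient holds the consolidated state dict
        torch.save(optimizer.state_dict(), CHECKPOINT)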
Example #9
    def test_lr_scheduler(self):
        """ Check that a normal torch lr_scheduler is usable with ZeroRedundancyOptimizer"""

        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o2 = torch.optim.SGD([x2], lr=0.01)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(5):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)
Example #10
    def test_step_with_closure(self):
        """ Check that the ZeroRedundancyOptimizer wrapper properly exposes the `.step(closure)` interface"""

        if self.rank >= self.world_size or (BACKEND == dist.Backend.NCCL
                                            and torch.cuda.device_count() < 2):
            return

        self.dist_init(self.rank, world_size=self.world_size)

        context = (suppress() if not torch.cuda.is_available()
                   else torch.cuda.device(self.rank))

        with context:
            for bucket_view in [False, True]:
                x_val = self.rank + 1
                weight = 1.0
                bias = 2.0
                error = 1.0
                target = torch.tensor([x_val * weight + bias + error],
                                      device=self.device)
                loss_fn = torch.nn.L1Loss()

                x = torch.tensor([float(x_val)], device=self.device)
                m = torch.nn.Linear(1, 1)
                m.weight.data = torch.tensor([[weight]])
                m.bias.data = torch.tensor([bias])
                m.to(self.device)

                o = ZeroRedundancyOptimizer(
                    m.parameters(),
                    optimizer_class=SGD,
                    parameters_as_bucket_view=bucket_view,
                    lr=0.1,
                )

                y = m(x)
                y.backward(x)
                for p in m.parameters():
                    dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
                    p.grad.data /= self.world_size

                def closure():
                    o.zero_grad()
                    output = m(x)
                    loss = loss_fn(output, target)
                    loss.backward()
                    return loss

                loss = o.step(closure=closure)

                self.assertEqual(loss, torch.tensor(error))
                self.assertEqual(m.weight, torch.tensor([[1.1]]))
                self.assertEqual(m.bias, torch.tensor([2.1]))
Example #11
def train_distributed(
    rank: int,
    world_size: int,
    lr: float=2e-4,
    batch_size: int=1000,
    epochs: int=500,
    interval: int=10,
    save: int=0,
    num_workers: int=4,
    num_basis: int=100,
    dataset: str='datasets',
    coefficient_noise: bool=False,
    use_zero: bool=False,
    verbose: bool=False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), "Save must be a non-negative integer between 0 and epochs."


    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"
    save_dir = Path('lfigw/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = 'gwpe/config_files/static_args.ini'

    # tensorboard
    if rank == 0:
        tb = SummaryWriter(f'lfigw/runs/{log_dir}')

    # training data
    dataset = lfigwWaveformDataset(
        n=num_basis,
        data_dir='lfigw/data/train',
        basis_dir='lfigw/data/basis/',
        psd_dir='data/events/GW150914',
        data_file='coefficients.npy',
        static_args_ini=static_args_ini,
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        ifos=['H1','L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        persistent_workers=True,
        sampler=sampler,
        num_workers=num_workers,
        prefetch_factor=4,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time 
        context_dim=4*100,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
            'hidden_dim': 512,  # default
            'dropout_probability': 0.0,  # default
            'num_bins': 8,  # default
            'tail_bound': 1.0,  # default
            'apply_unconditional_transform': False,  # default
        }
    )

    flow = flow.to(rank)
    # print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)

    # print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        #https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


    disable_pbar = False if verbose and (rank == 0) else True  # tqdm progress bar
    with tqdm(
        total=len(dataloader)*epochs,
        disable=disable_pbar,
        ncols=150,
        postfix={'epoch': 0},
        desc=f'[{log_dir}] Training'
    ) as progress:

        # run training loop
        train_loss = torch.zeros((1,), device=rank, requires_grad=False)

        for epoch in range(1, 1+epochs):
            flow.train()
            distributed.barrier()

            for coefficients, parameters in dataloader:
                if rank == 0:
                    progress.set_postfix({'epoch': epoch})
                    progress.set_description(f'[{log_dir}] Training', refresh=True)

                optimizer.zero_grad()
                
                coefficients = coefficients.to(rank, non_blocking=True)
                parameters = parameters.to(rank, non_blocking=True)

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()
                
                loss.backward()

                # print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                # print_peak_memory("Max memory allocated after optimizer step()", rank)

                # total loss summed over each sample in batch (scaled to lfigw)
                train_loss += loss.detach() * coefficients.shape[0] * (15/14)
                progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [torch.ones_like(train_loss) for _ in range(world_size)]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if rank == 0:
                epoch_loss = torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                tb.add_scalar('loss/train', epoch_loss, epoch)
                # if (interval != 0) and (epoch % interval == 0):

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
                    torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')


    cleanup_nccl()
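The setup_nccl() and cleanup_nccl() helpers used above are not shown in this example. A plausible minimal implementation, with the rendezvous address and port as assumptions:

import os
import torch.distributed as distributed

def setup_nccl(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "localhost")  # assumed single-node rendezvous
    os.environ.setdefault("MASTER_PORT", "12355")
    distributed.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup_nccl() -> None:
    distributed.destroy_process_group()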
Example #12
    def _test_zero_model_parallel(self, parameters_as_bucket_view: bool):
        # Use two processes each with two GPUs
        assert self.rank < 2
        NUM_EPOCHS = 3
        NUM_INPUTS = 5
        LR = 0.01
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        class ModelParallelModel(torch.nn.Module):
            def __init__(self, dev0, dev1):
                super().__init__()
                self.dev0 = dev0
                self.dev1 = dev1
                self.net0 = torch.nn.Linear(10, 10).to(dev0)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5).to(dev1)

            def forward(self, x):
                x = x.to(self.dev0)
                x = self.relu(self.net0(x))
                x = x.to(self.dev1)
                return self.net1(x)

        class LocalModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.net0 = torch.nn.Linear(10, 10)
                self.relu = torch.nn.ReLU()
                self.net1 = torch.nn.Linear(10, 5)

            def forward(self, x):
                return self.net1(self.relu(self.net0(x)))

        dev0 = 2 * self.rank
        dev1 = 2 * self.rank + 1
        mp_model = ModelParallelModel(dev0, dev1)
        ddp_model = DDP(mp_model)
        local_model = LocalModel()
        cpu_device = torch.device("cpu")
        # Ensure the parameters are the same across the two models
        local_model.net0.weight = torch.nn.Parameter(
            mp_model.net0.weight.detach().clone().to(cpu_device))
        local_model.net0.bias = torch.nn.Parameter(
            mp_model.net0.bias.detach().clone().to(cpu_device))
        local_model.net1.weight = torch.nn.Parameter(
            mp_model.net1.weight.detach().clone().to(cpu_device))
        local_model.net1.bias = torch.nn.Parameter(
            mp_model.net1.bias.detach().clone().to(cpu_device))

        # Compare parity between DDP with model parallelism using ZeRO and
        # a local model using a local optimizer
        zero_optim = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            parameters_as_bucket_view=parameters_as_bucket_view,
            lr=LR)
        local_optim = torch.optim.Adam(local_model.parameters(), lr=LR)
        inputs = [torch.randn(20, 10) for _ in range(NUM_INPUTS)]

        for _ in range(NUM_EPOCHS):
            for input in inputs:

                def closure_local():
                    local_optim.zero_grad()
                    local_loss = local_model(input).abs().sum()
                    local_loss.backward()
                    return local_loss

                def closure_ddp():
                    zero_optim.zero_grad()
                    ddp_loss = ddp_model(input).abs().sum()
                    ddp_loss.backward()
                    return ddp_loss

                local_loss = cast(torch.Tensor,
                                  local_optim.step(closure=closure_local))
                ddp_loss = cast(
                    torch.Tensor,
                    zero_optim.step(closure=closure_ddp)).to(cpu_device)

                assert torch.allclose(
                    local_loss, ddp_loss
                ), "Losses differ between local optim and ZeroRedundancyOptimizer"

                for local_p, ddp_p in zip(local_model.parameters(),
                                          ddp_model.parameters()):
                    ddp_p = ddp_p.to(cpu_device)
                    assert torch.allclose(local_p,
                                          ddp_p), "Models differ after a step"
Example #13
    def _test_zero_join(self, device):
        r"""
        Check that the ZeRO join hook allows training with uneven inputs when using the given device.

        Arguments:
            device (torch.device): device used to store parameters and perform
                collective communications.
        """
        NUM_INPUTS = 3
        NUM_EPOCHS = 2
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        world_size = self.world_size
        is_gpu = device.type == "cuda"
        backend = dist.Backend.NCCL if is_gpu else dist.Backend.GLOO
        self.dist_init(rank, world_size, backend)
        if BACKEND == dist.Backend.NCCL and is_gpu:
            torch.cuda.set_device(self.device)

        model = torch.nn.Sequential(
            torch.nn.Linear(2, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        )
        model.to(device)

        # DDP ensures correct gradients in data parallel training, so DDP with
        # local optimizers on uneven inputs should be equivalent to ZeRO on
        # uneven inputs with gradients being manually set
        ddp_model = DDP(model, device_ids=[rank]) if is_gpu else DDP(model)
        local_optim = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
        zero_model = copy.deepcopy(model)
        zero_model.to(device)
        zero_optim = ZeroRedundancyOptimizer(zero_model.parameters(),
                                             torch.optim.Adam,
                                             lr=0.01)
        loss_fn = torch.nn.MSELoss()

        # Use uneven inputs: rank i has i extra inputs
        inputs = [
            torch.randn(20, 2).to(device) for _ in range(NUM_INPUTS + rank)
        ]
        labels = torch.randn(20, 3).to(device)

        # Save the gradients and parameters from DDP as the ground truth; do
        # so on the last-joining rank (in this case, the largest rank)
        grads_at_each_iter = []
        params_at_each_iter = []
        with ddp_model.join():
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    output = ddp_model(input)
                    loss_fn(output, labels).backward()
                    if rank == world_size - 1:
                        grads = []
                        for p in ddp_model.parameters():
                            grads.append(p.grad.detach().clone().to(device))
                    local_optim.step()
                    if rank == world_size - 1:
                        params = []
                        for p in ddp_model.parameters():
                            params.append(p.detach().clone().to(device))
                        grads_at_each_iter.append(grads)
                        params_at_each_iter.append(params)

        # Broadcast the saved gradients and parameters to all of the other
        # ranks (which joined early)
        grads_and_params = [grads_at_each_iter, params_at_each_iter]
        grads_and_params = _broadcast_object(grads_and_params,
                                             src_rank=world_size - 1,
                                             group=dist.group.WORLD,
                                             device=device)
        grads_at_each_iter = grads_and_params[0]
        params_at_each_iter = grads_and_params[1]

        # TODO: Replace this `_broadcast_object` with `broadcast_object_list`
        # once the latter supports loading to the destination device instead
        # of the source device

        # A process must still set the remaining gradients after joining, so we
        # define a join hook to do this before the ZeRO join hook
        class _JoinGradInfo():
            def __init__(self, grads):
                self.grads = grads  # remaining gradients to set (in order)
                self.index = 0

        class _SetGradsJoinHook(JoinHook):
            def __init__(self, zero_optim, grads):
                zero_optim._join_grad_info = _JoinGradInfo(grads)
                self.zero = zero_optim
                super().__init__()

            def main_hook(self):
                grads = self.zero._join_grad_info.grads[
                    self.zero._join_grad_info.index]
                self.zero._join_grad_info.index += 1
                for p, grad in zip(self.zero._all_params, grads):
                    p.grad = grad.detach().clone().to(device)

        class _GradientSetter(Joinable):
            def __init__(self):
                super().__init__()

            def join_hook(self, **kwargs):
                assert "zero_optim" in kwargs
                assert "grads" in kwargs
                zero_optim = kwargs["zero_optim"]
                grads = kwargs["grads"]
                return _SetGradsJoinHook(zero_optim, grads)

            @property
            def join_device(self):
                return device

            @property
            def join_process_group(self):
                return dist.group.WORLD

        num_grads_after_joining = NUM_EPOCHS * (world_size - rank - 1)
        grads = grads_at_each_iter[-num_grads_after_joining:]
        gradient_setter = _GradientSetter()
        iter = 0
        with Join([gradient_setter, zero_optim],
                  zero_optim=zero_optim,
                  grads=grads):
            for _ in range(NUM_EPOCHS):
                for input in inputs:
                    # Notify join context that this process has not joined
                    Join.notify_join_context(gradient_setter)

                    # Set gradients manually
                    for p, grad in zip(zero_model.parameters(),
                                       grads_at_each_iter[iter]):
                        p.grad = grad.detach().clone().to(device)

                    # Perform optimizer step and check parity
                    zero_optim.step()
                    for p, ddp_p in zip(zero_model.parameters(),
                                        params_at_each_iter[iter]):
                        assert torch.allclose(p, ddp_p), \
                            "Parameters differ between using ZeRO and local optimizer"
                    iter += 1
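The test above drives the ZeRO join hook manually so it can compare against saved DDP gradients. In ordinary uneven-input training, both a DDP module and ZeroRedundancyOptimizer are Joinable, so the usual pattern is simply the Join context manager. A sketch, where model stands for a DDP-wrapped module, optimizer for a ZeroRedundancyOptimizer over its parameters, and inputs, labels, loss_fn are placeholders:

from torch.distributed.algorithms.join import Join

with Join([model, optimizer]):
    for input in inputs:  # ranks may hold different numbers of inputs
        optimizer.zero_grad()
        loss = loss_fn(model(input), labels)
        loss.backward()
        optimizer.step()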
Example #14
    def _test_ddp_zero_overlap(self, device, hook_constructor,
                               gradient_as_bucket_view, static_graph):
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if BACKEND == dist.Backend.NCCL and is_gpu:
            torch.cuda.set_device(device)
        models_to_test = [
            (torch.nn.Sequential(torch.nn.Linear(1000, 2000),
                                 torch.nn.Linear(2000, 500)),
             [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)]),
        ]
        if HAS_TORCHVISION:
            models_to_test.append((torchvision.models.resnet50(), [
                torch.randn(1, 3, 3, 1000).to(device)
                for _ in range(NUM_INPUTS)
            ]))
        for (model, inputs) in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(enabled=True,
                                            deterministic=True,
                                            benchmark=False):
                device_ids = [rank] if is_gpu else None
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view)
                if static_graph:
                    ddp_model_overlap._set_static_graph()
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                    _allow_empty_param_list=True)
                ddp_model_overlap.register_comm_hook(
                    None,
                    hook_constructor(allreduce_hook, ddp_model_overlap,
                                     zero_optim))

                # Set up the DDP model with local optimizer
                ddp_model_local = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view)
                if static_graph:
                    ddp_model_local._set_static_graph()
                local_optim = torch.optim.SGD(ddp_model_local.parameters(),
                                              lr=SGD_LR,
                                              momentum=SGD_MOMENTUM,
                                              weight_decay=SGD_WEIGHT_DECAY)

                # Check that the parameters match initially
                for p1, p2 in zip(ddp_model_overlap.parameters(),
                                  ddp_model_local.parameters()):
                    self.assertEqual(p1, p2)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters()))

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO
                # NOTE: Overlapping currently requires 2 or 3 warmup iterations
                # to ensure DDP buckets have been rebuilt (depending on the
                # value of `static_graph`)
                num_warmup_inputs = 2 if not static_graph else 3
                for input in inputs[:num_warmup_inputs]:
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                    zero_optim.step()
                for input in inputs:
                    zero_optim.zero_grad()
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                    zero_optim.step()

                # Run the DDP model with local optimizer
                for input in inputs:
                    local_optim.zero_grad()
                    output = ddp_model_local(input)
                    loss = output.sum()
                    loss.backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for p1, p2 in zip(ddp_model_overlap.parameters(),
                                  ddp_model_local.parameters()):
                    self.assertEqual(p1, p2)

                # Check that the parameters were updated
                self.assertNotEqual(init_params_overlap,
                                    list(ddp_model_overlap.parameters()))

                # Ensure that this test runs independently
                dist.barrier()
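The hook_constructor argument is not defined in this excerpt; in PyTorch the usual candidates are the functional-optimizer ZeRO hooks. The wiring below is a guess at the intended usage, with model standing for a DDP-wrapped module on the current rank:

import torch
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step
from torch.distributed.optim import ZeroRedundancyOptimizer

zero_optim = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.SGD,
    overlap_with_ddp=True,  # defer parameter updates into the DDP communication hook
    lr=0.01,
)
model.register_comm_hook(None, hook_with_zero_step(allreduce_hook, model, zero_optim))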
Example #15
File: ddp.py Project: c3se/alvis-intro
def run_process():
    '''Run process

    This is what is actually run on each process.
    '''
    # Get distributed parameters
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()

    # Initialize data_loader
    context_size = 512
    batch_size = 32
    corpus_length = 1024
    vocab_size = 2**8

    dataset = RandomCorpus(corpus_length, context_size, vocab_size)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
    )

    # Initialize model
    model = GPT(vocab_size, context_size, verbose=True)

    device = torch.device(f"cuda:{local_rank}")
    model.to(device)

    # Prepare for distributed data parallelism
    model = DistributedDataParallel(model,
                                    device_ids=[rank],
                                    output_device=rank)

    # The learning rate is adapted for the total batch_size in tokens
    learning_rate = 6e-4 * (batch_size * world_size * context_size / 5e5)
    # ZeroRedundancyOptimizer reduces the memory footprint of the Optimizer
    opt = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=optim.Adam,
        lr=learning_rate,
    )
    loss_func = nn.CrossEntropyLoss()

    # Initialize logger instance to see performance
    writer = BenchmarkWriter()

    # Actual training
    global_step = 0
    n_epochs = 10
    for epoch in range(n_epochs):
        model.train()
        sampler.set_epoch(epoch)  # for correct shuffling
        for sequence, in data_loader:
            opt.zero_grad()

            # Shift so that prediction is next token for each token
            sequence = sequence.to(device)
            logits = model(sequence[..., :-1].contiguous())
            target = sequence[..., 1:].contiguous()

            # Flatten the tokens when calculating loss
            loss = loss_func(
                logits.flatten(end_dim=-2),
                target.flatten(),
            )
            loss.backward()
            opt.step()

            # This will also log the wall time
            if rank == 0:
                global_step += batch_size * world_size
                writer.add_scalar("Loss", loss.item(), global_step=global_step)

        if rank == 0:
            print("Epoch:", epoch)

    if rank == 0:
        writer.benchmark_results(burn_in_steps=2 * corpus_length,
                                 step_unit="seq")
    writer.close()

    return model
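The entry point of ddp.py is not included in this excerpt. Since the script reads LOCAL_RANK from the environment and calls dist.get_rank(), it is presumably launched with torchrun; a minimal hypothetical main block:

import torch.distributed as dist

if __name__ == "__main__":
    # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE for each worker
    dist.init_process_group(backend="nccl")
    run_process()
    dist.destroy_process_group()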
Example #16
def train(
    rank: int,
    world_size: int,
    lr: float = 5e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 100,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    load_dir: Optional[str] = None,
    load_epoch: Optional[int] = None,
    coefficient_noise: bool = False,
    verbose: bool = False,
    use_zero: bool = False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), \
        "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), \
        "Save must be a non-negative integer between 0 and epochs."

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    if rank == 0: print(f"Loading data from {dataset}...")
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    val_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/validation/')

    noise_dir = Path('/mnt/datahole/daniel/gwosc/O1')
    psd_dir = Path(f"/mnt/datahole/daniel/gravflows/{dataset}/train/PSD/")
    basis_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/basis/')

    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"

    save_dir = Path('gwpe/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'

    static_args_ini = str(data_dir / 'config_files/static_args.ini')

    # validation

    # training data
    # dataset = BasisCoefficientsDataset(
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=waveform_params_ini,
    # )

    # dataset = BasisEncoderDataset(
    #     n=num_basis,
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     intrinsics_ini=waveform_params_ini,
    #     extrinsics_ini=extrinsics_ini,
    #     psd_dir=psd_dir,
    #     ifos=['H1','L1'],
    #     ref_ifo='H1',
    #     downcast=True,
    #     add_noise=True,
    #     coefficient_noise=coefficient_noise,
    # )

    dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # validation data
    val_dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    # val_dataset = BasisCoefficientsDataset(
    #     data_dir=val_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=[waveform_params_ini, extrinsics_ini],
    #     coefficient_noise=coefficient_noise,
    # )

    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=val_sampler,
        pin_memory=True,
        prefetch_factor=4,
        worker_init_fn=val_dataset._worker_init_fn,
        collate_fn=val_dataset._collate_fn,
    )

    # validation data
    if interval != 0:

        # specify indices in validation dataset to validate samples
        min_idx = val_dataset.parameters.distance.argmin()
        max_idx = val_dataset.parameters.distance.argmax()
        median_idx = val_dataset.parameters.loc[
            val_dataset.parameters.distance == val_dataset.parameters.distance.
            quantile(interpolation='nearest')].index[0]

        if rank == 0:
            figure_titles = [
                'GW150914', 'Min Distance', f'Median Distance', f'Max Distance'
            ]

            # validation ground truths for posterior sampling
            val_gts = torch.stack([
                torch.zeros(len(val_dataset.parameters.columns),
                            dtype=torch.float32),  # gw150914 dummy gt
                torch.tensor(val_dataset.parameters.iloc[min_idx].values,
                             dtype=torch.float32),  # rank 1
                torch.tensor(val_dataset.parameters.iloc[median_idx].values,
                             dtype=torch.float32),  # rank 2
                torch.tensor(val_dataset.parameters.iloc[max_idx].values,
                             dtype=torch.float32),  # rank 3
            ])

        with torch.no_grad():
            # load data from file manually (rather than using val_dataset._worker_init_fn)
            val_coefficients = np.load(val_dataset.data_dir /
                                       val_dataset.data_file,
                                       mmap_mode='c')

            # generate coefficients on cpu - we want to send this to tensorboard (rank 0) before sending to gpus
            val_coefficients = torch.cat([
                torch.from_numpy(
                    generate_gw150914_context(num_basis, noise_dir, psd_dir,
                                              basis_dir,
                                              static_args_ini))[None],
                torch.tensor(val_coefficients[[min_idx, median_idx, max_idx]]),
            ],
                                         dim=0).to(dtype=torch.complex64)

            # place one of each stacked tensor onto corresponding gpu rank
            val_context = val_coefficients[
                rank] * val_dataset.standardization[:, :num_basis]
            val_context = val_context.to(device=rank)
            val_context = torch.cat([val_context.real, val_context.imag],
                                    dim=0)
            val_context = val_context.reshape(val_context.shape[0] *
                                              val_context.shape[1])[None]

    else:
        figure_titles = None
        val_gts = None
        val_coefficients = None

    # set torch profiling runs
    # wait = 1  # ignore first batch
    # warmup = 1
    # active = 4
    # repeat = 2

    # tensorboard
    if rank == 0:
        # tb = SummaryWriter(f'gwpe/runs/{log_dir}')
        queue = mp.SimpleQueue()
        tb_process = mp.Process(target=tensorboard_writer,
                                args=(
                                    queue,
                                    f'gwpe/runs/{log_dir}',
                                    val_dataset.generator.parameters,
                                    val_dataset.generator.latex,
                                    static_args_ini,
                                    basis_dir,
                                    num_basis,
                                    val_coefficients,
                                    val_gts,
                                    figure_titles,
                                ))
        tb_process.start()

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time 
        context_dim=4 * num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
        })

    flow = flow.to(rank)
    print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)

    print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        #https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
        # optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=epochs)

    if load_dir is not None and load_epoch is not None:
        print(f'Loading model from {load_dir} at epoch {load_epoch}.')
        flow.module.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/flow_{load_epoch}.pt',
                       map_location=rank))
        optimizer.load_state_dict(
            torch.load(
                f'gwpe/model_weights/{load_dir}/optimizer_{load_epoch}.pt',
                map_location=rank))
        if Path(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt'
                ).is_file():
            scheduler.load_state_dict(
                torch.load(
                    f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt',
                    map_location=rank))

    # run training loop
    flow.train()
    train_loss = torch.zeros((1, ), device=rank, requires_grad=False)
    val_loss = torch.zeros((1, ), device=rank, requires_grad=False)

    disable_pbar = not (verbose and rank == 0)  # tqdm progress bar
    with tqdm(total=len(dataloader) * epochs,
              disable=disable_pbar,
              desc=f'[{log_dir}] Training',
              postfix={'epoch': 0}) as progress:
        # with torch.profiler.profile(
        #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        #     schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler(f'gwpe/runs/{log_dir}'),
        #     record_shapes=True,
        #     with_stack=True
        # ) as profiler:

        for epoch in range(1, 1 + epochs):
            if rank == 0:
                progress.set_postfix({'epoch': epoch})
                progress.set_description(f'[{log_dir}] Training', refresh=True)

            # let all processes sync up before starting with a new epoch of training
            flow.train()
            distributed.barrier()

            iterator = iter(dataloader)
            coefficients, parameters = next(iterator)

            coefficients = coefficients.to(rank, non_blocking=True)
            parameters = parameters.to(rank, non_blocking=True)

            complete = False
            while not complete:
                optimizer.zero_grad()

                # if profile:
                # https://github.com/guyang3532/kineto/blob/readme/tb_plugin/docs/gpu_utilization.md
                ## WARNING: profiler may not handle async pinned memory transfer properly?
                # i.e. may not record CPU vs GPU wall times correctly
                # may be related to reported blocks per SM/achieved occupancy negative bug
                # this was an open issue for pytorch 1.9 as of july 9 - nightly may fix it
                # https://github.com/pytorch/kineto/issues/325#issuecomment-869362218
                # if (step >= (wait + warmup + active) * repeat):
                #     break

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters,
                                             context=coefficients).mean()

                try:
                    # async get data from CPU and move to GPU during model forward
                    coefficients, parameters = next(iterator)

                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                except StopIteration:
                    # exit while loop if iterator is complete
                    complete = True

                loss.backward()

                print_peak_memory(
                    "Max memory allocated before optimizer step()", rank)
                optimizer.step()
                print_peak_memory(
                    "Max memory allocated after optimizer step()", rank)

                # if profile: profiler.step()

                # total loss summed over each sample in batch
                train_loss += loss.detach() * coefficients.shape[0]
                if rank == 0: progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [
                torch.ones_like(train_loss) for _ in range(world_size)
            ]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if (interval != 0) and (epoch % interval == 0):
                # evaluate model on validation dataset
                flow.eval()
                with torch.no_grad():

                    iterator = iter(enumerate(val_loader))
                    step, (coefficients, parameters) = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                    if rank == 0:
                        val_progress = int(100 * step / len(val_loader))
                        progress.set_description(
                            f'[{log_dir}] Validating ({val_progress}%)',
                            refresh=True)

                    complete = False
                    while not complete:

                        # negative log-likelihood conditional on strain over mini-batch
                        loss = -flow.module.log_prob(
                            parameters, context=coefficients).mean()

                        try:
                            # async get data from CPU and move to GPU during model forward
                            step, (coefficients, parameters) = next(iterator)
                            coefficients = coefficients.to(rank,
                                                           non_blocking=True)
                            parameters = parameters.to(rank, non_blocking=True)

                            if rank == 0:
                                val_progress = int(100 * step /
                                                   len(val_loader))
                                progress.set_description(
                                    f'[{log_dir}] Validating ({val_progress}%)',
                                    refresh=True)

                        except StopIteration:
                            # exit while loop if iterator is complete
                            complete = True

                        # total loss summed over each sample in batch
                        val_loss += loss.detach() * coefficients.shape[0]

                    # gather total loss during epoch between each GPU worker as list of tensors
                    world_val_loss = [
                        torch.ones_like(val_loss) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_val_loss, val_loss)
                    val_loss *= 0.0  # reset loss for next epoch

                    # validation posteriors
                    if rank == 0:
                        progress.set_description(
                            f'[{log_dir}] Sampling posteriors', refresh=True)

                    samples = flows.sample_flow(
                        flow.module,
                        n=10000,
                        context=val_context,
                        output_device='cuda',
                        dtype=torch.float32,
                    )[0]

                    # gather samples from all gpus
                    world_samples = [
                        torch.ones_like(samples) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_samples, samples)

            if (rank == 0):
                progress.set_description(f'[{log_dir}] Sending to TensorBoard',
                                         refresh=True)

                scalars = {
                    'loss/train':
                    torch.cat(world_loss).sum().item() /
                    len(dataloader.dataset)
                }

                # every "interval" we generate samples for vis, else None
                corner_samples = None  # reset to None for epochs where there is no corner plot
                if (interval != 0) and (epoch % interval == 0):

                    scalars['loss/validation'] = torch.cat(
                        world_val_loss).sum().item() / len(val_loader.dataset)

                    # convert gw150914 samples to cpu and undo standardization
                    corner_samples = torch.stack(world_samples).cpu()
                    corner_samples *= torch.from_numpy(val_dataset.std)
                    corner_samples += torch.from_numpy(val_dataset.mean)

                # send data to async process to generate matplotlib figures
                queue.put((epoch, scalars, corner_samples))

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(),
                               experiment_dir / f'flow_{epoch}.pt')

                    # if use_zero:
                    #     # needs to be called on all ranks
                    #     optimizer.consolidate_state_dict(to=0)

                    torch.save(optimizer.state_dict(),
                               experiment_dir / f'optimizer_{epoch}.pt')

                    if scheduler is not None:
                        torch.save(scheduler.state_dict(),
                                   experiment_dir / f'scheduler_{epoch}.pt')

    # destroy processes from distributed training
    if rank == 0:
        # to do - graceful way to shutdown workers
        # need to send message back from child process
        sleep_time = 120
        for i in range(sleep_time):
            progress.set_description(
                f'[{log_dir}] Shutting down in {sleep_time - i}s',
                refresh=True)
            time.sleep(1)

        tb_process.terminate()

    cleanup_nccl()

            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups,
                                          ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(),
                                        ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The model should be synchronized in between the ranks at construction time, check that
                check_same_model_params()

                # The models should stay the same in between the ranks
                for i in range(20):
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()
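check_optimizer_equivalence is a nested helper; its driver is not included in this excerpt, but it is presumably invoked once per optimizer class under test, along the lines of:

            # hypothetical driver loop at the same nesting level as the helper
            for optimizer_class in (torch.optim.SGD, torch.optim.Adam):
                check_optimizer_equivalence(optimizer_class)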