Code example #1
    def test_state_dict(self):
        """Check that the ZeroRedundancyOptimizer exposes the expected state dict interface,
        irrespective of the sharding.
        """
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x],
                                    optimizer_class=SGD,
                                    lr=0.1,
                                    momentum=0.9)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        o.zero_grad()
        o.consolidate_state_dict()  # Sync state dict between replicas - even if there are none
        state_dict = o.state_dict()

        # Check that the state dict is pytorch-compliant key wise
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the pulled state is what we expect, and that we have all the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the pulled state and the .param_groups attribute are in sync
        for k in state_dict["param_groups"][0].keys():
            if k != "params":
                self.assertEqual(state_dict["param_groups"][0][k],
                                 o.param_groups[0][k])

        # Check that it's correctly loaded
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o.load_state_dict(state_dict)

        # Check that state is correct and on proper device
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        # We should now be using a lr of 0.1, both within the optimizer
        # and as exposed by the .param_groups attribute
        self.assertEqual(o.param_groups[0]["lr"], 0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.9], device=DEVICE))

        # Check that the exposed param_groups are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)
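
The test above runs inside PyTorch's distributed test harness (`self.dist_init`, `DEVICE`). As a rough standalone sketch of the same consolidate → state_dict → load_state_dict round trip, the snippet below assumes a single-process CPU "gloo" group set up purely for illustration; the process-group bootstrap and the script structure are not part of the original test.

import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim import SGD


def main():
    # single-process "gloo" group on CPU - an assumption for illustration only
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    x = torch.tensor([1.0], requires_grad=True)
    opt = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.1, momentum=0.9)
    x.backward()
    opt.step()

    # gather the (trivially sharded) state onto rank 0, then snapshot it
    opt.consolidate_state_dict(to=0)
    state = opt.state_dict()

    # a fresh optimizer built with lr=0.01 picks lr=0.1 back up from the state dict
    opt2 = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
    opt2.load_state_dict(state)
    assert opt2.param_groups[0]["lr"] == 0.1

    dist.destroy_process_group()


if __name__ == "__main__":
    main()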
Code example #2
    def test_collect_shards(self):
        """ Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
        self.dist_init(self.rank)
        RECIPIENT_RANK = 0

        # Run a dummy step so that the optimizer state dict exists
        batch, input_width, hidden, target_width = 3, 20, 10, 5
        target = torch.rand((batch, target_width), device=self.device)
        inputs = torch.rand((batch, input_width), device=self.device)

        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden),
                                    torch.nn.Linear(hidden, target_width))
        model.to(self.device)

        loss_fn = torch.nn.L1Loss()
        loss_fn.to(self.device)

        # With SGD, Momentum is required to get a state to shard
        optimizer = ZeroRedundancyOptimizer(model.parameters(),
                                            optimizer_class=SGD,
                                            lr=0.1,
                                            momentum=0.99)

        def closure():
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, target)
            loss.backward()
            return loss

        _ = optimizer.step(closure=closure)

        # Update the optimizer state on the reference rank
        optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

        # Fetch the state on the reference rank
        # - check that it has the correct size
        # - load it again
        if self.rank == RECIPIENT_RANK:
            optimizer_state_dict = optimizer.state_dict()
            self.assertEqual(len(optimizer_state_dict["state"]),
                             len(list(model.parameters())))
        else:
            optimizer_state_dict = {}

        optimizer_state_dict = _broadcast_object(
            optimizer_state_dict,
            src_rank=RECIPIENT_RANK,
            group=dist.group.WORLD,
            device=self.device,
        )

        # Load the optimizer state dict, check that no exception is raised
        optimizer.load_state_dict(optimizer_state_dict)
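
`_broadcast_object` above is a private helper from the test file; the later examples (#4, #5, #7) do the same hand-off with the public `dist.broadcast_object_list`. A minimal sketch of that pattern follows; the wrapper name `broadcast_state_dict` is ours, and an already-initialized default process group is assumed.

import torch.distributed as dist


def broadcast_state_dict(state_dict, src_rank):
    # every rank passes something (non-source ranks can pass an empty dict);
    # after the collective, all ranks hold the object from src_rank
    obj_list = [state_dict]
    dist.broadcast_object_list(obj_list, src=src_rank)
    return obj_list[0]


# usage inside the test body would look like:
#     optimizer_state_dict = broadcast_state_dict(optimizer_state_dict, RECIPIENT_RANK)
#     optimizer.load_state_dict(optimizer_state_dict)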
Code example #3
    def test_implicit_local_state_dict(self):
        """Check that it's possible to pull a local state dict
        .. warning: probably deprecated in the near future
        """
        self.dist_init(self.rank)

        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.1)
        local_state_dict = o.state_dict()
        o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.01)
        o.load_state_dict(local_state_dict)
        # We should now be using a lr of 0.1.
        self.assertEqual(o.optim.param_groups[0]["lr"], 0.1)
        self.assertEqual(o.param_groups[0]["lr"], 0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
Code example #4
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(),
                    optimizer_class=optimizer,
                    lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True,
                                        find_unused_parameters=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True,
                                find_unused_parameters=True)

                # The models should be synchronized across ranks at construction time; check that they are
                check_same_model_params(sharded_ddp_model, ddp_model,
                                        "Models differ from the start")

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params(sharded_ddp_model, ddp_model,
                                            "Models differ after a step")

                # The models should stay the same across ranks
                for i in range(BATCHS):
                    check_step()

                    # Change the models' trainability and check that parity is maintained;
                    # only start after a couple of constant batches so both regimes are exercised
                    if i > BATCHS // 2:
                        next(ddp_model.parameters()).requires_grad = bool(i % 2)
                        next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(to=reference_rank)
                sharded_optim_state_dict = [
                    sharded_optimizer.state_dict()
                    if self.rank == reference_rank else {}
                ]
                dist.broadcast_object_list(sharded_optim_state_dict,
                                           src=reference_rank,
                                           group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(
                    ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(
                    sharded_optim_state_dict)  # mixup on purpose !
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                #  - self load, rewind, check no problem
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()
Code example #5
            def check_optimizer_equivalence(
                    optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer",
                                      torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(
                    params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model,
                                        device_ids=[self.rank],
                                        broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(),
                                          lr=1e-3)
                ddp_model = DDP(ddp_model_single,
                                device_ids=[self.rank],
                                broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups,
                                          ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(),
                                        ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that they are
                check_same_model_params()

                # The models should stay the same across multiple steps, losses should stay the same
                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(
                            input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor,
                                    ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(
                        torch.Tensor,
                        sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                for i in range(20):
                    check_step()

                # Test state dict save/load/equivalence with pytorch
                # - save state for both
                sharded_optimizer.consolidate_state_dict()
                sharded_optimizer_state_dict = (sharded_optimizer.state_dict()
                                                if self.rank == RECIPIENT_RANK
                                                else torch.zeros(1))
                ddp_state_dict = ddp_optimizer.state_dict()

                #  - sync the saved state with all the ranks
                exchange_list = [sharded_optimizer_state_dict]
                dist.broadcast_object_list(
                    exchange_list,
                    src=RECIPIENT_RANK,
                    group=dist.group.WORLD,
                )
                sharded_optimizer_state_dict = exchange_list[0]

                # - cross load the states
                ddp_optimizer.load_state_dict(sharded_optimizer_state_dict)
                sharded_optimizer.load_state_dict(ddp_state_dict)

                # - run one step, and check that the models are still the same
                check_step()
Code example #6
def train(
    rank: int,
    world_size: int,
    lr: float = 5e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 100,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    load_dir: Optional[str] = None,
    load_epoch: Optional[int] = None,
    coefficient_noise: bool = False,
    verbose: bool = False,
    use_zero: bool = False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert 0 <= interval <= epochs, "interval must be a non-negative integer between 0 and epochs."
    assert 0 <= save <= epochs, "save must be a non-negative integer between 0 and epochs."

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    if rank == 0: print(f"Loading data from {dataset}...")
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    val_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/validation/')

    noise_dir = Path('/mnt/datahole/daniel/gwosc/O1')
    psd_dir = Path(f"/mnt/datahole/daniel/gravflows/{dataset}/train/PSD/")
    basis_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/basis/')

    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"

    save_dir = Path('gwpe/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'

    static_args_ini = str(data_dir / 'config_files/static_args.ini')

    # validation

    # training data
    # dataset = BasisCoefficientsDataset(
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=waveform_params_ini,
    # )

    # dataset = BasisEncoderDataset(
    #     n=num_basis,
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     intrinsics_ini=waveform_params_ini,
    #     extrinsics_ini=extrinsics_ini,
    #     psd_dir=psd_dir,
    #     ifos=['H1','L1'],
    #     ref_ifo='H1',
    #     downcast=True,
    #     add_noise=True,
    #     coefficient_noise=coefficient_noise,
    # )

    dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # validation data
    val_dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    # val_dataset = BasisCoefficientsDataset(
    #     data_dir=val_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=[waveform_params_ini, extrinsics_ini],
    #     coefficient_noise=coefficient_noise,
    # )

    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=val_sampler,
        pin_memory=True,
        prefetch_factor=4,
        worker_init_fn=val_dataset._worker_init_fn,
        collate_fn=val_dataset._collate_fn,
    )

    # validation data
    if interval != 0:

        # specify indices in validation dataset to validate samples
        min_idx = val_dataset.parameters.distance.argmin()
        max_idx = val_dataset.parameters.distance.argmax()
        median_idx = val_dataset.parameters.loc[
            val_dataset.parameters.distance ==
            val_dataset.parameters.distance.quantile(interpolation='nearest')
        ].index[0]

        if rank == 0:
            figure_titles = [
                'GW150914', 'Min Distance', 'Median Distance', 'Max Distance'
            ]

            # validation ground truths for posterior sampling
            val_gts = torch.stack([
                torch.zeros(len(val_dataset.parameters.columns),
                            dtype=torch.float32),  # gw150914 dummy gt
                torch.tensor(val_dataset.parameters.iloc[min_idx].values,
                             dtype=torch.float32),  # rank 1
                torch.tensor(val_dataset.parameters.iloc[median_idx].values,
                             dtype=torch.float32),  # rank 2
                torch.tensor(val_dataset.parameters.iloc[max_idx].values,
                             dtype=torch.float32),  # rank 3
            ])

        with torch.no_grad():
            # load data from file manually (rather than using val_dataset._worker_init_fn)
            val_coefficients = np.load(val_dataset.data_dir /
                                       val_dataset.data_file,
                                       mmap_mode='c')

            # generate coefficients on cpu - we want to send this to tensorboard (rank 0) before sending to gpus
            val_coefficients = torch.cat([
                torch.from_numpy(
                    generate_gw150914_context(num_basis, noise_dir, psd_dir,
                                              basis_dir,
                                              static_args_ini))[None],
                torch.tensor(val_coefficients[[min_idx, median_idx, max_idx]]),
            ],
                                         dim=0).to(dtype=torch.complex64)

            # place one of each stacked tensor onto corresponding gpu rank
            val_context = (val_coefficients[rank] *
                           val_dataset.standardization[:, :num_basis])
            val_context = val_context.to(device=rank)
            val_context = torch.cat([val_context.real, val_context.imag],
                                    dim=0)
            val_context = val_context.reshape(val_context.shape[0] *
                                              val_context.shape[1])[None]

    else:
        figure_titles = None
        val_gts = None
        val_coefficients = None

    # set torch profiling runs
    # wait = 1  # ignore first batch
    # warmup = 1
    # active = 4
    # repeat = 2

    # tensorboard
    if rank == 0:
        # tb = SummaryWriter(f'gwpe/runs/{log_dir}')
        queue = mp.SimpleQueue()
        tb_process = mp.Process(target=tensorboard_writer,
                                args=(
                                    queue,
                                    f'gwpe/runs/{log_dir}',
                                    val_dataset.generator.parameters,
                                    val_dataset.generator.latex,
                                    static_args_ini,
                                    basis_dir,
                                    num_basis,
                                    val_coefficients,
                                    val_gts,
                                    figure_titles,
                                ))
        tb_process.start()

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time 
        context_dim=4 * num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
        })

    flow = flow.to(rank)
    print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)

    print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        #https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
        # optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=epochs)

    if load_dir is not None and load_epoch is not None:
        print(f'Loading model from {load_dir} at epoch {load_epoch}.')
        flow.module.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/flow_{load_epoch}.pt',
                       map_location=rank))
        optimizer.load_state_dict(
            torch.load(
                f'gwpe/model_weights/{load_dir}/optimizer_{load_epoch}.pt',
                map_location=rank))
        if Path(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt'
                ).is_file():
            scheduler.load_state_dict(
                torch.load(
                    f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt',
                    map_location=rank))

    # run training loop
    flow.train()
    train_loss = torch.zeros((1, ), device=rank, requires_grad=False)
    val_loss = torch.zeros((1, ), device=rank, requires_grad=False)

    disable_pbar = not (verbose and rank == 0)  # tqdm progress bar
    with tqdm(total=len(dataloader) * epochs,
              disable=disable_pbar,
              desc=f'[{log_dir}] Training',
              postfix={'epoch': 0}) as progress:
        # with torch.profiler.profile(
        #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        #     schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler(f'gwpe/runs/{log_dir}'),
        #     record_shapes=True,
        #     with_stack=True
        # ) as profiler:

        for epoch in range(1, 1 + epochs):
            if rank == 0:
                progress.set_postfix({'epoch': epoch})
                progress.set_description(f'[{log_dir}] Training', refresh=True)

            # let all processes sync up before starting with a new epoch of training
            flow.train()
            distributed.barrier()

            iterator = iter(dataloader)
            coefficients, parameters = next(iterator)

            coefficients = coefficients.to(rank, non_blocking=True)
            parameters = parameters.to(rank, non_blocking=True)

            complete = False
            while not complete:
                optimizer.zero_grad()

                # if profile:
                # https://github.com/guyang3532/kineto/blob/readme/tb_plugin/docs/gpu_utilization.md
                ## WARNING: profiler may not handle async pinned memory transfer properly?
                # i.e. may not record CPU vs GPU wall times correctly
                # may be related to reported blocks per SM/achieved occupancy negative bug
                # this was an open issue for pytorch 1.9 as of july 9 - nightly may fix it
                # https://github.com/pytorch/kineto/issues/325#issuecomment-869362218
                # if (step >= (wait + warmup + active) * repeat):
                #     break

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters,
                                             context=coefficients).mean()

                try:
                    # async get data from CPU and move to GPU during model forward
                    coefficients, parameters = next(iterator)

                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                except StopIteration:
                    # exit while loop if iterator is complete
                    complete = True

                loss.backward()

                print_peak_memory(
                    "Max memory allocated before optimizer step()", rank)
                optimizer.step()
                print_peak_memory(
                    "Max memory allocated after optimizer step()", rank)

                # if profile: profiler.step()

                # total loss summed over each sample in batch
                train_loss += loss.detach() * coefficients.shape[0]
                if rank == 0: progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [
                torch.ones_like(train_loss) for _ in range(world_size)
            ]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if (interval != 0) and (epoch % interval == 0):
                # evaluate model on validation dataset
                flow.eval()
                with torch.no_grad():

                    iterator = iter(enumerate(val_loader))
                    step, (coefficients, parameters) = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                    if rank == 0:
                        val_progress = int(100 * step / len(val_loader))
                        progress.set_description(
                            f'[{log_dir}] Validating ({val_progress}%)',
                            refresh=True)

                    complete = False
                    while not complete:

                        # negative log-likelihood conditional on strain over mini-batch
                        loss = -flow.module.log_prob(
                            parameters, context=coefficients).mean()

                        try:
                            # async get data from CPU and move to GPU during model forward
                            step, (coefficients, parameters) = next(iterator)
                            coefficients = coefficients.to(rank,
                                                           non_blocking=True)
                            parameters = parameters.to(rank, non_blocking=True)

                            if rank == 0:
                                val_progress = int(100 * step /
                                                   len(val_loader))
                                progress.set_description(
                                    f'[{log_dir}] Validating ({val_progress}%)',
                                    refresh=True)

                        except StopIteration:
                            # exit while loop if iterator is complete
                            complete = True

                        # total loss summed over each sample in batch
                        val_loss += loss.detach() * coefficients.shape[0]

                    # gather total loss during epoch between each GPU worker as list of tensors
                    world_val_loss = [
                        torch.ones_like(val_loss) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_val_loss, val_loss)
                    val_loss *= 0.0  # reset loss for next epoch

                    # validation posteriors
                    if rank == 0:
                        progress.set_description(
                            f'[{log_dir}] Sampling posteriors', refresh=True)

                    samples = flows.sample_flow(
                        flow.module,
                        n=10000,
                        context=val_context,
                        output_device='cuda',
                        dtype=torch.float32,
                    )[0]

                    # gather samples from all gpus
                    world_samples = [
                        torch.ones_like(samples) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_samples, samples)

            if (rank == 0):
                progress.set_description(f'[{log_dir}] Sending to TensorBoard',
                                         refresh=True)

                scalars = {
                    'loss/train':
                    torch.cat(world_loss).sum().item() /
                    len(dataloader.dataset)
                }

                # every "interval" we generate samples for vis, else None
                corner_samples = None  # reset to None for epochs where there is no corner plot
                if (interval != 0) and (epoch % interval == 0):

                    scalars['loss/validation'] = torch.cat(
                        world_val_loss).sum().item() / len(val_loader.dataset)

                    # convert gw150914 samples to cpu and undo standardization
                    corner_samples = torch.stack(world_samples).cpu()
                    corner_samples *= torch.from_numpy(val_dataset.std)
                    corner_samples += torch.from_numpy(val_dataset.mean)

                # send data to async process to generate matplotlib figures
                queue.put((epoch, scalars, corner_samples))

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(),
                               experiment_dir / f'flow_{epoch}.pt')

                    # if use_zero:
                    #     # needs to be called on all ranks
                    #     optimizer.consolidate_state_dict(to=0)

                    torch.save(optimizer.state_dict(),
                               experiment_dir / f'optimizer_{epoch}.pt')

                    if scheduler is not None:
                        torch.save(scheduler.state_dict(),
                                   experiment_dir / f'scheduler_{epoch}.pt')

    # destroy processes from distributed training
    if rank == 0:
        # to do - graceful way to shutdown workers
        # need to send message back from child process
        sleep_time = 120
        for i in range(sleep_time):
            progress.set_description(
                f'[{log_dir}] Shutting down in {sleep_time - i}s',
                refresh=True)
            time.sleep(1)

        tb_process.terminate()

    cleanup_nccl()
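
One caveat about the checkpointing block in this script: the `consolidate_state_dict` call is commented out, yet ZeroRedundancyOptimizer only exposes the full state dict after the shards have been consolidated (as the test examples above do), and the consolidation is a collective that every rank must enter, not just rank 0. A rough sketch of how the save step could be rearranged when `use_zero` is set, reusing the names from the training loop above:

# run the collective on every rank, outside the rank-0-only branch
if use_zero and save != 0 and epoch % save == 0:
    optimizer.consolidate_state_dict(to=0)  # gathers all shards onto rank 0

if rank == 0 and save != 0 and epoch % save == 0:
    # rank 0 now holds the full optimizer state (for plain Adam this is unchanged)
    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')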
Code example #7
            def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
                # Any model works. Add one different buffer per rank
                model = torch.nn.Sequential(
                    torch.nn.Linear(2, 3),
                    torch.nn.Linear(3, 3),
                    torch.nn.Linear(3, 3),
                )
                model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
                model.to(self.device)

                sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
                sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

                ddp_model_single = copy.deepcopy(model)
                ddp_model_single.to(self.device)

                ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
                ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

                def check_same_model_params():
                    for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
                        for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                            assert torch.allclose(
                                p, ddp_p, atol=1e-3
                            ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

                    for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                        assert torch.allclose(
                            b, ddp_b
                        ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

                # The models should be synchronized across ranks at construction time; check that they are
                check_same_model_params()

                def check_step():
                    input_tensor = torch.rand((64, 2))

                    def closure_ddp(input_tensor=input_tensor):
                        ddp_optimizer.zero_grad()
                        ddp_loss = ddp_model(input_tensor).abs().sum()
                        ddp_loss.backward()
                        return ddp_loss

                    def closure_sharded(input_tensor=input_tensor):
                        sharded_optimizer.zero_grad()
                        sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                        sharded_loss.backward()
                        return sharded_loss

                    loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
                    loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

                    assert torch.allclose(
                        loss_ddp, loss_sharded_optim
                    ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

                    check_same_model_params()

                # The models should stay the same across ranks
                for i in range(20):
                    check_step()

                # Check that the checkpoints are compatible
                reference_rank = 0
                # - get states
                ddp_state_dict = ddp_optimizer.state_dict()
                sharded_optimizer.consolidate_state_dict(recipient_rank=reference_rank)
                sharded_optim_state_dict = [sharded_optimizer.state_dict() if self.rank == reference_rank else {}]
                dist.broadcast_object_list(sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD)
                sharded_optim_state_dict = sharded_optim_state_dict[0]

                # - cross load the states
                # run one step and check that the models are still the same
                ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
                ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
                sharded_optimizer.load_state_dict(ddp_state_dict)
                check_step()

                #  - self load, rewind, check no problem
                # run one step and check that the models are still the same
                ddp_optimizer.load_state_dict(ddp_state_dict_ref)
                sharded_optimizer.load_state_dict(sharded_optim_state_dict)
                check_step()