Example #1
    def test_state_dict(self):
        """Check that the ZeroRedundancyOptimizer exposes the expected state dict interface,
        irrespective of the sharding.
        """
        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x],
                                    optimizer_class=SGD,
                                    lr=0.1,
                                    momentum=0.9)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        o.zero_grad()
        # Sync the state dict between replicas - even if there are none
        o.consolidate_state_dict()
        state_dict = o.state_dict()

        # Check that the state dict is pytorch-compliant key wise
        self.assertIn("param_groups", state_dict.keys())
        self.assertIn("state", state_dict.keys())

        # Check that the pulled state is what we expect, and that we have all the expected keys
        self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
        self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
        self.assertFalse(state_dict["param_groups"][0]["nesterov"])
        self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
        self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

        # Check that the pulled state and the .param_groups attribute are in sync
        for k in state_dict["param_groups"][0].keys():
            if k != "params":
                self.assertEqual(state_dict["param_groups"][0][k],
                                 o.param_groups[0][k])

        # Check that it's correctly loaded
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o.load_state_dict(state_dict)

        # Check that state is correct and on proper device
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.0], device=DEVICE))

        # We should now be using a lr of 0.1, both within the optimizer
        # and as exposed by the .param_groups attribute
        self.assertEqual(o.param_groups[0]["lr"], 0.1)
        x.backward()
        o.step()
        self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
        self.assertEqual(o.optim.state[x]["momentum_buffer"],
                         torch.tensor([1.9], device=DEVICE))

        # Check that the exposed param_groups are on the proper device
        self.assertEqual(o.param_groups[0]["params"][0].device, x.device)
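
The call pattern above is also the one to use when checkpointing: consolidate_state_dict() is a collective, so every rank must call it before the target rank reads state_dict(). A minimal sketch of that pattern in a training script follows; the model, opt, and path names are hypothetical.

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer

def save_zero_checkpoint(model, opt: ZeroRedundancyOptimizer, path):
    """Hypothetical helper: gather the sharded optimizer state onto rank 0 and save it."""
    opt.consolidate_state_dict(to=0)  # collective call - must run on every rank
    if dist.get_rank() == 0:
        # state_dict() only returns the full, consolidated state on the rank passed as `to`
        torch.save({"model": model.state_dict(), "optim": opt.state_dict()}, path)
    dist.barrier()  # keep the other ranks from racing ahead of the save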
Example #2
    def test_zero_grad(self):
        """Check that the zero_grad attribute is properly handled"""
        self.dist_init(self.rank)
        x = torch.rand(1)
        m = torch.nn.Linear(1, 1)
        o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=0.1)
        y = m(x)
        y.backward(x)
        self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
        self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
        o.zero_grad()
        self.assertFalse(m.weight.grad)
        self.assertFalse(m.bias.grad)
Example #3
    def test_lr_scheduler(self):
        """ Check that a normal torch lr_scheduler is usable with ZeroRedundancyOptimizer"""

        self.dist_init(self.rank)
        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
        o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
        o2 = torch.optim.SGD([x2], lr=0.01)
        s = torch.optim.lr_scheduler.StepLR(o, 1)
        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
        for _ in range(5):
            x.backward()
            o.zero_grad()
            o.step()
            s.step()
            x2.backward()
            o2.zero_grad()
            o2.step()
            s2.step()
            self.assertEqual(x, x2)
Example #4
    def _test_ddp_zero_overlap(self, device, hook_constructor,
                               gradient_as_bucket_view, static_graph):
        SGD_LR = 0.01
        SGD_MOMENTUM = 0.9
        SGD_WEIGHT_DECAY = 0.001
        NUM_INPUTS = 5
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

        rank = self.rank
        is_gpu = device.type == "cuda"
        if BACKEND == dist.Backend.NCCL and is_gpu:
            torch.cuda.set_device(device)
        models_to_test = [
            (torch.nn.Sequential(torch.nn.Linear(1000, 2000),
                                 torch.nn.Linear(2000, 500)),
             [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)]),
        ]
        if HAS_TORCHVISION:
            models_to_test.append((torchvision.models.resnet50(), [
                torch.randn(1, 3, 3, 1000).to(device)
                for _ in range(NUM_INPUTS)
            ]))
        for (model, inputs) in models_to_test:
            # Enable determinism in cudnn operators
            with torch.backends.cudnn.flags(enabled=True,
                                            deterministic=True,
                                            benchmark=False):
                device_ids = [rank] if is_gpu else None
                # Set up the DDP model overlapping with ZeRO
                ddp_model_overlap = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view)
                if static_graph:
                    ddp_model_overlap._set_static_graph()
                zero_optim = ZeroRedundancyOptimizer(
                    ddp_model_overlap.parameters(),
                    optimizer_class=torch.optim.SGD,
                    overlap_with_ddp=True,
                    lr=SGD_LR,
                    momentum=SGD_MOMENTUM,
                    weight_decay=SGD_WEIGHT_DECAY,
                    _allow_empty_param_list=True)
                ddp_model_overlap.register_comm_hook(
                    None,
                    hook_constructor(allreduce_hook, ddp_model_overlap,
                                     zero_optim))

                # Set up the DDP model with local optimizer
                ddp_model_local = DDP(
                    copy.deepcopy(model).to(device),
                    device_ids=device_ids,
                    gradient_as_bucket_view=gradient_as_bucket_view)
                if static_graph:
                    ddp_model_local._set_static_graph()
                local_optim = torch.optim.SGD(ddp_model_local.parameters(),
                                              lr=SGD_LR,
                                              momentum=SGD_MOMENTUM,
                                              weight_decay=SGD_WEIGHT_DECAY)

                # Check that the parameters match initially
                for p1, p2 in zip(ddp_model_overlap.parameters(),
                                  ddp_model_local.parameters()):
                    self.assertEqual(p1, p2)

                # Save the parameters to ensure they were updated
                init_params_overlap = copy.deepcopy(
                    list(ddp_model_overlap.parameters()))

                # Ensure that this test runs independently
                dist.barrier()

                # Run the DDP model overlapping with ZeRO
                # NOTE: Overlapping currently requires 2 or 3 warmup iterations
                # to ensure DDP buckets have been rebuilt (depending on the
                # value of `static_graph`)
                num_warmup_inputs = 2 if not static_graph else 3
                for input in inputs[:num_warmup_inputs]:
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                    zero_optim.step()
                for input in inputs:
                    zero_optim.zero_grad()
                    output = ddp_model_overlap(input)
                    loss = output.sum()
                    loss.backward()
                    zero_optim.step()

                # Run the DDP model with local optimizer
                for input in inputs:
                    local_optim.zero_grad()
                    output = ddp_model_local(input)
                    loss = output.sum()
                    loss.backward()
                    local_optim.step()
                dist.barrier()

                # Check that the parameters are equal
                for p1, p2 in zip(ddp_model_overlap.parameters(),
                                  ddp_model_local.parameters()):
                    self.assertEqual(p1, p2)

                # Check that the parameters were updated
                self.assertNotEqual(init_params_overlap,
                                    list(ddp_model_overlap.parameters()))

                # Ensure that this test runs independently
                dist.barrier()
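
This helper exercises the experimental DDP/ZeRO overlap path: the optimizer is built with overlap_with_ddp=True and its step() is driven from a DDP communication hook, which is why the first two or three iterations are warmup-only. Below is a condensed sketch of that setup, assuming the hook helpers live in torch.distributed.algorithms.ddp_comm_hooks (ddp_zero_hook and default_hooks), as in the surrounding test suite; net, rank, and inputs are placeholders and a process group is assumed to be initialized.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step

ddp_model = DDP(net.to(rank), device_ids=[rank])
zero_optim = ZeroRedundancyOptimizer(ddp_model.parameters(),
                                     optimizer_class=torch.optim.SGD,
                                     overlap_with_ddp=True,  # step() is fused into the comm hook
                                     lr=0.01,
                                     _allow_empty_param_list=True)  # as in the test above
ddp_model.register_comm_hook(None, hook_with_zero_step(allreduce_hook, ddp_model, zero_optim))

for inp in inputs:  # the first iterations act as warmup until DDP rebuilds its buckets
    ddp_model(inp).sum().backward()
    zero_optim.step()  # effectively a no-op until overlap is initialized, then driven by the hook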
Example #5
File: ddp.py  Project: c3se/alvis-intro
def run_process():
    '''Run process

    This is what is actually run on each process.
    '''
    # Get distributed parameters
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = dist.get_world_size()

    # Initialize data_loader
    context_size = 512
    batch_size = 32
    corpus_length = 1024
    vocab_size = 2**8

    dataset = RandomCorpus(corpus_length, context_size, vocab_size)
    sampler = DistributedSampler(dataset, shuffle=True, drop_last=False)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
    )

    # Initialize model
    model = GPT(vocab_size, context_size, verbose=True)

    device = torch.device(f"cuda:{local_rank}")
    model.to(device)

    # Prepare for distributed data parallelism
    model = DistributedDataParallel(model,
                                    device_ids=[local_rank],  # must match the local device the model was moved to
                                    output_device=local_rank)

    # The learning rate is adapted for the total batch_size in tokens
    learning_rate = 6e-4 * (batch_size * world_size * context_size / 5e5)
    # ZeroRedundancyOptimizer reduces the memory footprint of the Optimizer
    opt = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=optim.Adam,
        lr=learning_rate,
    )
    loss_func = nn.CrossEntropyLoss()

    # Initialize logger instance to see performance
    writer = BenchmarkWriter()

    # Actual training
    global_step = 0
    n_epochs = 10
    for epoch in range(n_epochs):
        model.train()
        sampler.set_epoch(epoch)  # for correct shuffling
        for sequence, in data_loader:
            opt.zero_grad()

            # Shift so that prediction is next token for each token
            sequence = sequence.to(device)
            logits = model(sequence[..., :-1].contiguous())
            target = sequence[..., 1:].contiguous()

            # Flatten the tokens when calculating loss
            loss = loss_func(
                logits.flatten(end_dim=-2),
                target.flatten(),
            )
            loss.backward()
            opt.step()

            # This will also log the wall time
            if rank == 0:
                global_step += batch_size * world_size
                writer.add_scalar("Loss", loss.item(), global_step=global_step)

        if rank == 0:
            print("Epoch:", epoch)

    if rank == 0:
        writer.benchmark_results(burn_in_steps=2 * corpus_length,
                                 step_unit="seq")
    writer.close()

    return model
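
run_process() assumes the default process group is already initialized and that the launcher (e.g. torchrun) has exported RANK, LOCAL_RANK and WORLD_SIZE. A minimal, hypothetical entry point for this file could look like the following sketch.

import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Assumes a launch such as: torchrun --nproc-per-node=<gpus> ddp.py
    dist.init_process_group(backend="nccl")  # rendezvous via the env:// variables set by torchrun
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    run_process()
    dist.destroy_process_group()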
Example #6
def train(
    rank: int,
    world_size: int,
    lr: float = 5e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 100,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    load_dir: Optional[str] = None,
    load_epoch: Optional[int] = None,
    coefficient_noise: bool = False,
    verbose: bool = False,
    use_zero: bool = False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (
        interval <= epochs
    ), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (
        save <=
        epochs), "Save must be a non-negative integer between 0 and epochs."

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    if rank == 0: print(f"Loading data from {dataset}...")
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    val_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/validation/')

    noise_dir = Path('/mnt/datahole/daniel/gwosc/O1')
    psd_dir = Path(f"/mnt/datahole/daniel/gravflows/{dataset}/train/PSD/")
    basis_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/basis/')

    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"

    save_dir = Path('gwpe/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'

    static_args_ini = str(data_dir / 'config_files/static_args.ini')

    # validation

    # training data
    # dataset = BasisCoefficientsDataset(
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=waveform_params_ini,
    # )

    # dataset = BasisEncoderDataset(
    #     n=num_basis,
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     intrinsics_ini=waveform_params_ini,
    #     extrinsics_ini=extrinsics_ini,
    #     psd_dir=psd_dir,
    #     ifos=['H1','L1'],
    #     ref_ifo='H1',
    #     downcast=True,
    #     add_noise=True,
    #     coefficient_noise=coefficient_noise,
    # )

    dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # validation data
    val_dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    # val_dataset = BasisCoefficientsDataset(
    #     data_dir=val_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=[waveform_params_ini, extrinsics_ini],
    #     coefficient_noise=coefficient_noise,
    # )

    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=val_sampler,
        pin_memory=True,
        prefetch_factor=4,
        worker_init_fn=val_dataset._worker_init_fn,
        collate_fn=val_dataset._collate_fn,
    )

    # validation data
    if interval != 0:

        # specify indices in validation dataset to validate samples
        min_idx = val_dataset.parameters.distance.argmin()
        max_idx = val_dataset.parameters.distance.argmax()
        median_distance = val_dataset.parameters.distance.quantile(
            interpolation='nearest')
        median_idx = val_dataset.parameters.loc[
            val_dataset.parameters.distance == median_distance].index[0]

        if rank == 0:
            figure_titles = [
                'GW150914', 'Min Distance', 'Median Distance', 'Max Distance'
            ]

            # validation ground truths for posterior sampling
            val_gts = torch.stack([
                torch.zeros(len(val_dataset.parameters.columns),
                            dtype=torch.float32),  # gw150914 dummy gt
                torch.tensor(val_dataset.parameters.iloc[min_idx].values,
                             dtype=torch.float32),  # rank 1
                torch.tensor(val_dataset.parameters.iloc[median_idx].values,
                             dtype=torch.float32),  # rank 2
                torch.tensor(val_dataset.parameters.iloc[max_idx].values,
                             dtype=torch.float32),  # rank 3
            ])

        with torch.no_grad():
            # load data from file manually (rather than using val_dataset._worker_init_fn)
            val_coefficients = np.load(val_dataset.data_dir /
                                       val_dataset.data_file,
                                       mmap_mode='c')

            # generate coefficients on cpu - we want to send this to tensorboard (rank 0) before sending to gpus
            val_coefficients = torch.cat([
                torch.from_numpy(
                    generate_gw150914_context(num_basis, noise_dir, psd_dir,
                                              basis_dir,
                                              static_args_ini))[None],
                torch.tensor(val_coefficients[[min_idx, median_idx, max_idx]]),
            ],
                                         dim=0).to(dtype=torch.complex64)

            # place one of each stacked tensor onto corresponding gpu rank
            val_context = val_coefficients[
                rank] * val_dataset.standardization[:, :num_basis]
            val_context = val_context.to(device=rank)
            val_context = torch.cat([val_context.real, val_context.imag],
                                    dim=0)
            val_context = val_context.reshape(val_context.shape[0] *
                                              val_context.shape[1])[None]

    else:
        figure_titles = None
        val_gts = None
        val_coefficients = None

    # set torch profiling runs
    # wait = 1  # ignore first batch
    # warmup = 1
    # active = 4
    # repeat = 2

    # tensorboard
    if rank == 0:
        # tb = SummaryWriter(f'gwpe/runs/{log_dir}')
        queue = mp.SimpleQueue()
        tb_process = mp.Process(target=tensorboard_writer,
                                args=(
                                    queue,
                                    f'gwpe/runs/{log_dir}',
                                    val_dataset.generator.parameters,
                                    val_dataset.generator.latex,
                                    static_args_ini,
                                    basis_dir,
                                    num_basis,
                                    val_coefficients,
                                    val_gts,
                                    figure_titles,
                                ))
        tb_process.start()

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time 
        context_dim=4 * num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
        })

    flow = flow.to(rank)
    print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)

    print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        #https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
        # optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=epochs)

    if load_dir is not None and load_epoch is not None:
        print(f'Loading model from {load_dir} at epoch {load_epoch}.')
        flow.module.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/flow_{load_epoch}.pt',
                       map_location=rank))
        optimizer.load_state_dict(
            torch.load(
                f'gwpe/model_weights/{load_dir}/optimizer_{load_epoch}.pt',
                map_location=rank))
        if Path(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt'
                ).is_file():
            scheduler.load_state_dict(
                torch.load(
                    f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt',
                    map_location=rank))

    # run training loop
    flow.train()
    train_loss = torch.zeros((1, ), device=rank, requires_grad=False)
    val_loss = torch.zeros((1, ), device=rank, requires_grad=False)

    disable_pbar = not (verbose and rank == 0)  # tqdm progress bar
    with tqdm(total=len(dataloader) * epochs,
              disable=disable_pbar,
              desc=f'[{log_dir}] Training',
              postfix={'epoch': 0}) as progress:
        # with torch.profiler.profile(
        #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        #     schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler(f'gwpe/runs/{log_dir}'),
        #     record_shapes=True,
        #     with_stack=True
        # ) as profiler:

        for epoch in range(1, 1 + epochs):
            if rank == 0:
                progress.set_postfix({'epoch': epoch})
                progress.set_description(f'[{log_dir}] Training', refresh=True)

            # let all processes sync up before starting with a new epoch of training
            flow.train()
            distributed.barrier()

            iterator = iter(dataloader)
            coefficients, parameters = next(iterator)

            coefficients = coefficients.to(rank, non_blocking=True)
            parameters = parameters.to(rank, non_blocking=True)

            complete = False
            while not complete:
                optimizer.zero_grad()

                # if profile:
                # https://github.com/guyang3532/kineto/blob/readme/tb_plugin/docs/gpu_utilization.md
                ## WARNING: profiler may not handle async pinned memory transfer properly?
                # i.e. may not record CPU vs GPU wall times correctly
                # may be related to reported blocks per SM/achieved occupancy negative bug
                # this was an open issue for pytorch 1.9 as of july 9 - nightly may fix it
                # https://github.com/pytorch/kineto/issues/325#issuecomment-869362218
                # if (step >= (wait + warmup + active) * repeat):
                #     break

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters,
                                             context=coefficients).mean()

                try:
                    # async get data from CPU and move to GPU during model forward
                    coefficients, parameters = next(iterator)

                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                except StopIteration:
                    # exit while loop if iterator is complete
                    complete = True

                loss.backward()

                print_peak_memory(
                    "Max memory allocated before optimizer step()", rank)
                optimizer.step()
                print_peak_memory(
                    "Max memory allocated after optimizer step()", rank)

                # if profile: profiler.step()

                # total loss summed over each sample in batch
                train_loss += loss.detach() * coefficients.shape[0]
                if rank == 0: progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [
                torch.ones_like(train_loss) for _ in range(world_size)
            ]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if (interval != 0) and (epoch % interval == 0):
                # evaluate model on validation dataset
                flow.eval()
                with torch.no_grad():

                    iterator = iter(enumerate(val_loader))
                    step, (coefficients, parameters) = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                    if rank == 0:
                        val_progress = int(100 * step / len(val_loader))
                        progress.set_description(
                            f'[{log_dir}] Validating ({val_progress}%)',
                            refresh=True)

                    complete = False
                    while not complete:

                        # negative log-likelihood conditional on strain over mini-batch
                        loss = -flow.module.log_prob(
                            parameters, context=coefficients).mean()

                        try:
                            # async get data from CPU and move to GPU during model forward
                            step, (coefficients, parameters) = next(iterator)
                            coefficients = coefficients.to(rank,
                                                           non_blocking=True)
                            parameters = parameters.to(rank, non_blocking=True)

                            if rank == 0:
                                val_progress = int(100 * step /
                                                   len(val_loader))
                                progress.set_description(
                                    f'[{log_dir}] Validating ({val_progress}%)',
                                    refresh=True)

                        except StopIteration:
                            # exit while loop if iterator is complete
                            complete = True

                        # total loss summed over each sample in batch
                        val_loss += loss.detach() * coefficients.shape[0]

                    # gather total loss during epoch between each GPU worker as list of tensors
                    world_val_loss = [
                        torch.ones_like(val_loss) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_val_loss, val_loss)
                    val_loss *= 0.0  # reset loss for next epoch

                    # validation posteriors
                    if rank == 0:
                        progress.set_description(
                            f'[{log_dir}] Sampling posteriors', refresh=True)

                    samples = flows.sample_flow(
                        flow.module,
                        n=10000,
                        context=val_context,
                        output_device='cuda',
                        dtype=torch.float32,
                    )[0]

                    # gather samples from all gpus
                    world_samples = [
                        torch.ones_like(samples) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_samples, samples)

            if (rank == 0):
                progress.set_description(f'[{log_dir}] Sending to TensorBoard',
                                         refresh=True)

                scalars = {
                    'loss/train':
                    torch.cat(world_loss).sum().item() /
                    len(dataloader.dataset)
                }

                # every "interval" we generate samples for vis, else None
                corner_samples = None  # reset to None for epochs where there is no corner plot
                if (interval != 0) and (epoch % interval == 0):

                    scalars['loss/validation'] = torch.cat(
                        world_val_loss).sum().item() / len(val_loader.dataset)

                    # convert gw150914 samples to cpu and undo standardization
                    corner_samples = torch.stack(world_samples).cpu()
                    corner_samples *= torch.from_numpy(val_dataset.std)
                    corner_samples += torch.from_numpy(val_dataset.mean)

                # send data to async process to generate matplotlib figures
                queue.put((epoch, scalars, corner_samples))

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(),
                               experiment_dir / f'flow_{epoch}.pt')

                    # if use_zero:
                    #     # needs to be called on all ranks
                    #     optimizer.consolidate_state_dict(to=0)

                    torch.save(optimizer.state_dict(),
                               experiment_dir / f'optimizer_{epoch}.pt')

                    if scheduler is not None:
                        torch.save(scheduler.state_dict(),
                                   experiment_dir / f'scheduler_{epoch}.pt')

    # destroy processes from distributed training
    if rank == 0:
        # to do - graceful way to shutdown workers
        # need to send message back from child process
        sleep_time = 120
        for i in range(sleep_time):
            progress.set_description(
                f'[{log_dir}] Shutting down in {sleep_time - i}s',
                refresh=True)
            time.sleep(1)

        tb_process.terminate()

    cleanup_nccl()
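
Note the commented-out consolidate_state_dict(to=0) call above: with use_zero=True each rank's optimizer.state_dict() only covers its own shard, so the checkpoint written by rank 0 is incomplete unless every rank consolidates first. A sketch of the save branch with that collective call restored (same names as in train() above):

            if (save != 0) and (epoch % save == 0):
                if use_zero:
                    # collective call - must run on every rank, gathering all shards onto rank 0
                    optimizer.consolidate_state_dict(to=0)
                if rank == 0:
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
                    # the full optimizer state is only present on rank 0 after consolidation
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
                    if scheduler is not None:
                        torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')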
Example #7
def train_distributed(
    rank: int,
    world_size: int,
    lr: float=2e-4,
    batch_size: int=1000,
    epochs: int=500,
    interval: int=10,
    save: int=0,
    num_workers: int=4,
    num_basis: int=100,
    dataset: str='datasets',
    coefficient_noise: bool=False,
    use_zero: bool=False,
    verbose: bool=False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), "Save must be a non-negative integer between 0 and epochs."


    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"
    save_dir = Path('lfigw/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = 'gwpe/config_files/static_args.ini'

    # tensorboard
    if rank == 0:
        tb = SummaryWriter(f'lfigw/runs/{log_dir}')

    # training data
    dataset = lfigwWaveformDataset(
        n=num_basis,
        data_dir='lfigw/data/train',
        basis_dir='lfigw/data/basis/',
        psd_dir='data/events/GW150914',
        data_file='coefficients.npy',
        static_args_ini=static_args_ini,
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        ifos=['H1','L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        persistent_workers=True,
        sampler=sampler,
        num_workers=num_workers,
        prefetch_factor=4,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time 
        context_dim=4*num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
            'hidden_dim': 512,  # default
            'dropout_probability': 0.0,  # default
            'num_bins': 8,  # default
            'tail_bound': 1.0,  # default
            'apply_unconditional_transform': False,  # default
        }
    )

    flow = flow.to(rank)
    # print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)

    # print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        #https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)


    disable_pbar = not (verbose and rank == 0)  # tqdm progress bar
    with tqdm(
        total=len(dataloader)*epochs,
        disable=disable_pbar,
        ncols=150,
        postfix={'epoch': 0},
        desc=f'[{log_dir}] Training'
    ) as progress:

        # run training loop
        train_loss = torch.zeros((1,), device=rank, requires_grad=False)

        for epoch in range(1, 1+epochs):
            flow.train()
            distributed.barrier()

            for coefficients, parameters in dataloader:
                if rank == 0:
                    progress.set_postfix({'epoch': epoch})
                    progress.set_description(f'[{log_dir}] Training', refresh=True)

                optimizer.zero_grad()
                
                coefficients = coefficients.to(rank, non_blocking=True)
                parameters = parameters.to(rank, non_blocking=True)

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()
                
                loss.backward()

                # print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                # print_peak_memory("Max memory allocated after optimizer step()", rank)

                # total loss summed over each sample in batch (scaled to lfigw)
                train_loss += loss.detach() * coefficients.shape[0] * (15/14)
                progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [torch.ones_like(train_loss) for _ in range(world_size)]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if rank == 0:
                epoch_loss = torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                tb.add_scalar('loss/train', epoch_loss, epoch)
                # if (interval != 0) and (epoch % interval == 0):

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
                    torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')


    cleanup_nccl()