def test_state_dict(self):
    """Check that the ZeroRedundancyOptimizer exposes the expected state dict
    interface, irrespective of the sharding.
    """
    self.dist_init(self.rank)
    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.1, momentum=0.9)
    x.backward()
    o.step()
    self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    o.zero_grad()
    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
    state_dict = o.state_dict()

    # Check that the state dict is pytorch-compliant key wise
    self.assertIn("param_groups", state_dict.keys())
    self.assertIn("state", state_dict.keys())

    # Check that the pulled state is what we expect, and that we have all the expected keys
    self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
    self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
    self.assertFalse(state_dict["param_groups"][0]["nesterov"])
    self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
    self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

    # Check that the pulled state and the .param_groups attribute are in sync
    for k in state_dict["param_groups"][0].keys():
        if k != "params":
            self.assertEqual(state_dict["param_groups"][0][k], o.param_groups[0][k])

    # Check that it's correctly loaded
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
    o.load_state_dict(state_dict)

    # Check that state is correct and on proper device
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    # We should now be using a lr of 0.1, both within the optimizer
    # and as exposed by the .param_groups attribute
    assert o.param_groups[0]["lr"] == 0.1
    x.backward()
    o.step()
    self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.9], device=DEVICE))

    # Check that the exposed param_groups are on the proper device
    self.assertEqual(o.param_groups[0]["params"][0].device, x.device)
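The literal values asserted above follow directly from the plain SGD-with-momentum update rule, since the gradient of the single-element parameter is 1.0 on every backward call. The standalone reproduction below is illustrative only (no sharding, single process) and is not part of the test suite; it produces the same numbers the test expects after the state dict round trip.

import torch
from torch.optim import SGD

x = torch.tensor([1.0], requires_grad=True)
opt = SGD([x], lr=0.1, momentum=0.9)

x.backward()   # d(x)/d(x) == 1.0
opt.step()     # buf = 1.0, x = 1.0 - 0.1 * 1.0 = 0.9
opt.zero_grad()

x.backward()   # gradient is 1.0 again
opt.step()     # buf = 0.9 * 1.0 + 1.0 = 1.9, x = 0.9 - 0.1 * 1.9 = 0.71

print(x)                               # tensor([0.7100], ...)
print(opt.state[x]["momentum_buffer"]) # tensor([1.9000])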
def test_collect_shards(self):
    """Check the state consolidation mechanism, and the state dict exposed by ZeroRedundancyOptimizer"""
    self.dist_init(self.rank)
    RECIPIENT_RANK = 0

    # Run a dummy step so that the optimizer state dict exists
    batch, input_width, hidden, target_width = 3, 20, 10, 5
    target = torch.rand((batch, target_width), device=self.device)
    inputs = torch.rand((batch, input_width), device=self.device)

    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
    model.to(self.device)

    loss_fn = torch.nn.L1Loss()
    loss_fn.to(self.device)

    # With SGD, Momentum is required to get a state to shard
    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=SGD, lr=0.1, momentum=0.99)

    def closure():
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, target)
        loss.backward()
        return loss

    _ = optimizer.step(closure=closure)

    # Update the optimizer state on the reference rank
    optimizer.consolidate_state_dict(to=RECIPIENT_RANK)

    # Fetch the state on the reference rank
    # - check that it has the correct size
    # - load it again
    if self.rank == RECIPIENT_RANK:
        optimizer_state_dict = optimizer.state_dict()
        self.assertEqual(len(optimizer_state_dict["state"]), len(list(model.parameters())))
    else:
        optimizer_state_dict = {}

    optimizer_state_dict = _broadcast_object(
        optimizer_state_dict,
        src_rank=RECIPIENT_RANK,
        group=dist.group.WORLD,
        device=self.device,
    )

    # Load the optimizer state dict, check that no exception is raised
    optimizer.load_state_dict(optimizer_state_dict)
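The test above relies on a `_broadcast_object` helper that is not shown here. The sketch below is a minimal stand-in with the same calling convention, built on `torch.distributed.broadcast_object_list`; it is an illustrative assumption, and the real helper used by the test suite may serialize and ship the object differently.

import torch.distributed as dist

# Illustrative stand-in for the `_broadcast_object` helper referenced above.
def _broadcast_object_sketch(obj, src_rank, group=None, device=None):
    # broadcast_object_list fills the list in place on non-source ranks
    object_list = [obj]
    dist.broadcast_object_list(object_list, src=src_rank, group=group, device=device)
    return object_list[0]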
def test_implicit_local_state_dict(self):
    """Check that it's possible to pull a local state dict

    .. warning: probably deprecated in the near future
    """
    self.dist_init(self.rank)

    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.1)
    local_state_dict = o.state_dict()
    o = ZeroRedundancyOptimizer([x], optim=SGD, lr=0.01)
    o.load_state_dict(local_state_dict)

    # We should now be using a lr of 0.1.
    self.assertEqual(o.optim.param_groups[0]["lr"], 0.1)
    self.assertEqual(o.param_groups[0]["lr"], 0.1)
    x.backward()
    o.step()
    self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
def state_dict(self):
    self.consolidate_state_dict()
    if dist.get_rank() == 0:
        return ZeroRedundancyOptimizer.state_dict(self)
    return None
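A brief usage sketch for the override above, assuming it lives on a ZeroRedundancyOptimizer subclass whose instance is bound to `optimizer` and that the default process group is initialized. Because consolidate_state_dict() is a collective, every rank must call state_dict(); only rank 0 receives a non-None result to save. The checkpoint filename is illustrative.

# Usage sketch only: `optimizer` carries the state_dict() override above.
state = optimizer.state_dict()   # every rank enters the collective
if dist.get_rank() == 0:         # only rank 0 holds the consolidated dict
    torch.save(state, "optimizer_checkpoint.pt")  # path is illustrative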
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(
        params=model.parameters(), optimizer_class=optimizer, lr=1e-3)
    sharded_ddp_model = DDP(module=model,
                            device_ids=[self.rank],
                            broadcast_buffers=True,
                            find_unused_parameters=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[self.rank],
                    broadcast_buffers=True,
                    find_unused_parameters=True)

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model, "Models differ from the start")

    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params(sharded_ddp_model, ddp_model, "Models differ after a step")

    # The models should stay the same in between the ranks
    for i in range(BATCHS):
        check_step()

        # Change the models trainability, check that parity is maintained
        # only check after a couple of constant batchs to go through both regimes
        if i > BATCHS // 2:
            next(ddp_model.parameters()).requires_grad = bool(i % 2)
            next(sharded_ddp_model.parameters()).requires_grad = bool(i % 2)

    # Check that the checkpoints are compatible
    reference_rank = 0
    # - get states
    ddp_state_dict = ddp_optimizer.state_dict()
    sharded_optimizer.consolidate_state_dict(to=reference_rank)
    sharded_optim_state_dict = [
        sharded_optimizer.state_dict() if self.rank == reference_rank else {}
    ]
    dist.broadcast_object_list(sharded_optim_state_dict,
                               src=reference_rank,
                               group=dist.group.WORLD)
    sharded_optim_state_dict = sharded_optim_state_dict[0]

    # - cross load the states
    #   run one step and check that the models are still the same
    ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
    ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
    sharded_optimizer.load_state_dict(ddp_state_dict)
    check_step()

    # - self load, rewind, check no problem
    #   run one step and check that the models are still the same
    ddp_optimizer.load_state_dict(ddp_state_dict_ref)
    sharded_optimizer.load_state_dict(sharded_optim_state_dict)
    check_step()
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(
        params=model.parameters(), optim=optimizer, lr=1e-3)
    sharded_ddp_model = DDP(module=model,
                            device_ids=[self.rank],
                            broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(ddp_model_single,
                    device_ids=[self.rank],
                    broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(
                b, ddp_b
            ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    # The models should stay the same across multiple steps, losses should stay the same
    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params()

    for i in range(20):
        check_step()

    # Test state dict save/load/equivalence with pytorch
    # - save state for both
    sharded_optimizer.consolidate_state_dict()
    sharded_optimizer_state_dict = (sharded_optimizer.state_dict()
                                    if self.rank == RECIPIENT_RANK else torch.zeros(1))
    ddp_state_dict = ddp_optimizer.state_dict()

    # - sync the saved state with all the ranks
    exchange_list = [sharded_optimizer_state_dict]
    dist.broadcast_object_list(
        exchange_list,
        src=RECIPIENT_RANK,
        group=dist.group.WORLD,
    )
    sharded_optimizer_state_dict = exchange_list[0]

    # - cross load the states
    ddp_optimizer.load_state_dict(sharded_optimizer_state_dict)
    sharded_optimizer.load_state_dict(ddp_state_dict)

    # - run one step, and check that the models are still the same
    check_step()
def train(
    rank: int,
    world_size: int,
    lr: float = 5e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 100,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    load_dir: Optional[str] = None,
    load_epoch: Optional[int] = None,
    coefficient_noise: bool = False,
    verbose: bool = False,
    use_zero: bool = False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), "Save must be a non-negative integer between 0 and epochs."

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    if rank == 0:
        print(f"Loading data from {dataset}...")

    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    val_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/validation/')
    noise_dir = Path('/mnt/datahole/daniel/gwosc/O1')
    psd_dir = Path(f"/mnt/datahole/daniel/gravflows/{dataset}/train/PSD/")
    basis_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/basis/')

    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"
    save_dir = Path('gwpe/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = str(data_dir / 'config_files/static_args.ini')

    # training data
    # dataset = BasisCoefficientsDataset(
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=waveform_params_ini,
    # )

    # dataset = BasisEncoderDataset(
    #     n=num_basis,
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     intrinsics_ini=waveform_params_ini,
    #     extrinsics_ini=extrinsics_ini,
    #     psd_dir=psd_dir,
    #     ifos=['H1','L1'],
    #     ref_ifo='H1',
    #     downcast=True,
    #     add_noise=True,
    #     coefficient_noise=coefficient_noise,
    # )

    dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # validation data
    val_dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    # val_dataset = BasisCoefficientsDataset(
    #     data_dir=val_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=[waveform_params_ini, extrinsics_ini],
    #     coefficient_noise=coefficient_noise,
    # )

    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=val_sampler,
        pin_memory=True,
        prefetch_factor=4,
        worker_init_fn=val_dataset._worker_init_fn,
        collate_fn=val_dataset._collate_fn,
    )

    # validation data
    if interval != 0:
        # specify indices in validation dataset to validate samples
        min_idx = val_dataset.parameters.distance.argmin()
        max_idx = val_dataset.parameters.distance.argmax()
        median_idx = val_dataset.parameters.loc[
            val_dataset.parameters.distance == val_dataset.parameters.distance.quantile(
                interpolation='nearest')].index[0]

        if rank == 0:
            figure_titles = ['GW150914', 'Min Distance', 'Median Distance', 'Max Distance']

            # validation ground truths for posterior sampling
            val_gts = torch.stack([
                torch.zeros(len(val_dataset.parameters.columns), dtype=torch.float32),  # gw150914 dummy gt
                torch.tensor(val_dataset.parameters.iloc[min_idx].values, dtype=torch.float32),  # rank 1
                torch.tensor(val_dataset.parameters.iloc[median_idx].values, dtype=torch.float32),  # rank 2
                torch.tensor(val_dataset.parameters.iloc[max_idx].values, dtype=torch.float32),  # rank 3
            ])

        with torch.no_grad():
            # load data from file manually (rather than using val_dataset._worker_init_fn)
            val_coefficients = np.load(val_dataset.data_dir / val_dataset.data_file, mmap_mode='c')

            # generate coefficients on cpu - we want to send this to tensorboard (rank 0) before sending to gpus
            val_coefficients = torch.cat([
                torch.from_numpy(
                    generate_gw150914_context(num_basis, noise_dir, psd_dir, basis_dir, static_args_ini))[None],
                torch.tensor(val_coefficients[[min_idx, median_idx, max_idx]]),
            ], dim=0).to(dtype=torch.complex64)

            # place one of each stacked tensor onto corresponding gpu rank
            val_context = val_coefficients[rank] * val_dataset.standardization[:, :num_basis]
            val_context = val_context.to(device=rank)
            val_context = torch.cat([val_context.real, val_context.imag], dim=0)
            val_context = val_context.reshape(val_context.shape[0] * val_context.shape[1])[None]

    else:
        figure_titles = None
        val_gts = None
        val_coefficients = None

    # set torch profiling runs
    # wait = 1  # ignore first batch
    # warmup = 1
    # active = 4
    # repeat = 2

    # tensorboard
    if rank == 0:
        # tb = SummaryWriter(f'gwpe/runs/{log_dir}')
        queue = mp.SimpleQueue()
        tb_process = mp.Process(
            target=tensorboard_writer,
            args=(
                queue,
                f'gwpe/runs/{log_dir}',
                val_dataset.generator.parameters,
                val_dataset.generator.latex,
                static_args_ini,
                basis_dir,
                num_basis,
                val_coefficients,
                val_gts,
                figure_titles,
            ))
        tb_process.start()

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time
        context_dim=4 * num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
        })

    flow = flow.to(rank)
    print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)
    print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        # https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
        # optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    if load_dir is not None and load_epoch is not None:
        print(f'Loading model from {load_dir} at epoch {load_epoch}.')
        flow.module.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/flow_{load_epoch}.pt', map_location=rank))

        optimizer.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/optimizer_{load_epoch}.pt', map_location=rank))

        if Path(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt').is_file():
            scheduler.load_state_dict(
                torch.load(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt', map_location=rank))

    # run training loop
    flow.train()
    train_loss = torch.zeros((1,), device=rank, requires_grad=False)
    val_loss = torch.zeros((1,), device=rank, requires_grad=False)

    disable_pbar = False if verbose and (rank == 0) else True

    # tqdm progress bar
    with tqdm(total=len(dataloader) * epochs,
              disable=disable_pbar,
              desc=f'[{log_dir}] Training',
              postfix={'epoch': 0}) as progress:

        # with torch.profiler.profile(
        #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        #     schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler(f'gwpe/runs/{log_dir}'),
        #     record_shapes=True,
        #     with_stack=True
        # ) as profiler:

        for epoch in range(1, 1 + epochs):
            if rank == 0:
                progress.set_postfix({'epoch': epoch})
                progress.set_description(f'[{log_dir}] Training', refresh=True)

            # let all processes sync up before starting with a new epoch of training
            flow.train()
            distributed.barrier()

            iterator = iter(dataloader)
            coefficients, parameters = next(iterator)
            coefficients = coefficients.to(rank, non_blocking=True)
            parameters = parameters.to(rank, non_blocking=True)

            complete = False
            while not complete:
                optimizer.zero_grad()

                # if profile:
                #     https://github.com/guyang3532/kineto/blob/readme/tb_plugin/docs/gpu_utilization.md
                #     ## WARNING: profiler may not handle async pinned memory transfer properly?
                #     # i.e. may not record CPU vs GPU wall times correctly
                #     # may be related to reported blocks per SM/achieved occupancy negative bug
                #     # this was an open issue for pytorch 1.9 as of july 9 - nightly may fix it
                #     # https://github.com/pytorch/kineto/issues/325#issuecomment-869362218
                #     if (step >= (wait + warmup + active) * repeat):
                #         break

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()

                try:
                    # async get data from CPU and move to GPU during model forward
                    coefficients, parameters = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)
                except StopIteration:
                    # exit while loop if iterator is complete
                    complete = True

                loss.backward()

                print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                print_peak_memory("Max memory allocated after optimizer step()", rank)

                # if profile: profiler.step()

                # total loss summed over each sample in batch
                train_loss += loss.detach() * coefficients.shape[0]

                if rank == 0:
                    progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [torch.ones_like(train_loss) for _ in range(world_size)]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if (interval != 0) and (epoch % interval == 0):
                # evaluate model on validation dataset
                flow.eval()
                with torch.no_grad():
                    iterator = iter(enumerate(val_loader))
                    step, (coefficients, parameters) = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)

                    if rank == 0:
                        val_progress = int(100 * step / len(val_loader))
                        progress.set_description(f'[{log_dir}] Validating ({val_progress}%)', refresh=True)

                    complete = False
                    while not complete:
                        # negative log-likelihood conditional on strain over mini-batch
                        loss = -flow.module.log_prob(parameters, context=coefficients).mean()

                        try:
                            # async get data from CPU and move to GPU during model forward
                            step, (coefficients, parameters) = next(iterator)
                            coefficients = coefficients.to(rank, non_blocking=True)
                            parameters = parameters.to(rank, non_blocking=True)

                            if rank == 0:
                                val_progress = int(100 * step / len(val_loader))
                                progress.set_description(f'[{log_dir}] Validating ({val_progress}%)', refresh=True)
                        except StopIteration:
                            # exit while loop if iterator is complete
                            complete = True

                        # total loss summed over each sample in batch
                        val_loss += loss.detach() * coefficients.shape[0]

                # gather total loss during epoch between each GPU worker as list of tensors
                world_val_loss = [torch.ones_like(val_loss) for _ in range(world_size)]
                distributed.all_gather(world_val_loss, val_loss)
                val_loss *= 0.0  # reset loss for next epoch

                # validation posteriors
                if rank == 0:
                    progress.set_description(f'[{log_dir}] Sampling posteriors', refresh=True)

                samples = flows.sample_flow(
                    flow.module,
                    n=10000,
                    context=val_context,
                    output_device='cuda',
                    dtype=torch.float32,
                )[0]

                # gather samples from all gpus
                world_samples = [torch.ones_like(samples) for _ in range(world_size)]
                distributed.all_gather(world_samples, samples)

            if (rank == 0):
                progress.set_description(f'[{log_dir}] Sending to TensorBoard', refresh=True)

                scalars = {
                    'loss/train': torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                }

                # every "interval" we generate samples for vis, else None
                corner_samples = None  # reset to None for epochs where there is no corner plot
                if (interval != 0) and (epoch % interval == 0):
                    scalars['loss/validation'] = torch.cat(world_val_loss).sum().item() / len(val_loader.dataset)

                    # convert gw150914 samples to cpu and undo standardization
                    corner_samples = torch.stack(world_samples).cpu()
                    corner_samples *= torch.from_numpy(val_dataset.std)
                    corner_samples += torch.from_numpy(val_dataset.mean)

                # send data to async process to generate matplotlib figures
                queue.put((epoch, scalars, corner_samples))

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')

                    # if use_zero:
                    #     # needs to be called on all ranks
                    #     optimizer.consolidate_state_dict(to=0)
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')

                    if scheduler is not None:
                        torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')

        # destroy processes from distributed training
        if rank == 0:
            # to do - graceful way to shutdown workers
            # need to send message back from child process
            sleep_time = 120
            for i in range(sleep_time):
                progress.set_description(f'[{log_dir}] Shutting down in {sleep_time - i}s', refresh=True)
                time.sleep(1)

            tb_process.terminate()

    cleanup_nccl()
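The commented-out block in the checkpoint branch above points at the ZeRO-specific wrinkle: consolidate_state_dict() is a collective that every rank must enter, so it cannot live inside the `if rank == 0:` guard where the torch.save calls sit. The sketch below shows one possible restructuring of that branch when use_zero is enabled; the variable names come from the training loop above, but the restructuring itself is a suggestion rather than the author's code.

# Sketch of a ZeRO-aware checkpoint step (assumes the same variables as the
# training loop above: rank, save, epoch, use_zero, flow, optimizer, scheduler,
# experiment_dir).
if (save != 0) and (epoch % save == 0):
    if use_zero:
        # collective call -- every rank must participate before rank 0 saves
        optimizer.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
        torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
        if scheduler is not None:
            torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')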
def train_distributed(
    rank: int,
    world_size: int,
    lr: float = 2e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 0,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    coefficient_noise: bool = False,
    use_zero: bool = False,
    verbose: bool = False,
):
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), "Save must be a non-negative integer between 0 and epochs."

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"

    save_dir = Path('lfigw/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = 'gwpe/config_files/static_args.ini'

    # tensorboard
    if rank == 0:
        tb = SummaryWriter(f'lfigw/runs/{log_dir}')

    # training data
    dataset = lfigwWaveformDataset(
        n=num_basis,
        data_dir='lfigw/data/train',
        basis_dir='lfigw/data/basis/',
        psd_dir='data/events/GW150914',
        data_file='coefficients.npy',
        static_args_ini=static_args_ini,
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        persistent_workers=True,
        sampler=sampler,
        num_workers=num_workers,
        prefetch_factor=4,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time
        context_dim=4 * 100,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
            'hidden_dim': 512,  # default
            'dropout_probability': 0.0,  # default
            'num_bins': 8,  # default
            'tail_bound': 1.0,  # default
            'apply_unconditional_transform': False,  # default
        })

    flow = flow.to(rank)
    # print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)
    # print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        # https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    disable_pbar = False if verbose and (rank == 0) else True

    # tqdm progress bar
    with tqdm(
        total=len(dataloader) * epochs,
        disable=disable_pbar,
        ncols=150,
        postfix={'epoch': 0},
        desc=f'[{log_dir}] Training'
    ) as progress:
        # run training loop
        train_loss = torch.zeros((1,), device=rank, requires_grad=False)

        for epoch in range(1, 1 + epochs):
            flow.train()
            distributed.barrier()

            for coefficients, parameters in dataloader:
                if rank == 0:
                    progress.set_postfix({'epoch': epoch})
                    progress.set_description(f'[{log_dir}] Training', refresh=True)

                optimizer.zero_grad()

                coefficients = coefficients.to(rank, non_blocking=True)
                parameters = parameters.to(rank, non_blocking=True)

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()
                loss.backward()

                # print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                # print_peak_memory("Max memory allocated after optimizer step()", rank)

                # total loss summed over each sample in batch (scaled to lfigw)
                train_loss += loss.detach() * coefficients.shape[0] * (15 / 14)
                progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [torch.ones_like(train_loss) for _ in range(world_size)]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if rank == 0:
                epoch_loss = torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                tb.add_scalar('loss/train', epoch_loss, epoch)

                # if (interval != 0) and (epoch % interval == 0):

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
                    torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')

    cleanup_nccl()
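Both train() and train_distributed() take (rank, world_size) as their leading arguments, which matches the calling convention of torch.multiprocessing.spawn. A minimal launcher sketch follows, with the one-process-per-GPU assumption made explicit; the original project may start its workers differently.

import torch
import torch.multiprocessing as mp

# Illustrative launcher only; the surrounding project may launch workers differently.
if __name__ == '__main__':
    world_size = torch.cuda.device_count()  # one process per visible GPU
    mp.spawn(
        train_distributed,    # called as train_distributed(rank, world_size)
        args=(world_size,),   # remaining keyword defaults are used as-is
        nprocs=world_size,
        join=True,
    )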
def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
    # Any model works. Add one different buffer per rank
    model = torch.nn.Sequential(
        torch.nn.Linear(2, 3),
        torch.nn.Linear(3, 3),
        torch.nn.Linear(3, 3),
    )
    model.register_buffer("test_buffer", torch.ones((1)) * self.rank)
    model.to(self.device)

    sharded_optimizer = ZeroRedundancyOptimizer(params=model.parameters(), optim=optimizer, lr=1e-3)
    sharded_ddp_model = DDP(module=model, device_ids=[self.rank], broadcast_buffers=True)

    ddp_model_single = copy.deepcopy(model)
    ddp_model_single.to(self.device)

    ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
    ddp_model = DDP(ddp_model_single, device_ids=[self.rank], broadcast_buffers=True)

    def check_same_model_params():
        for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
            for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                assert torch.allclose(
                    p, ddp_p, atol=1e-3
                ), f"Model parameters differ in between Pytorch optim and ZeroRedundancyOptimizer \n{p} {ddp_p}"

        for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
            assert torch.allclose(
                b, ddp_b
            ), "Model buffers differ in between Pytorch optim and ZeroRedundancyOptimizer"

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params()

    def check_step():
        input_tensor = torch.rand((64, 2))

        def closure_ddp(input_tensor=input_tensor):
            ddp_optimizer.zero_grad()
            ddp_loss = ddp_model(input_tensor).abs().sum()
            ddp_loss.backward()
            return ddp_loss

        def closure_sharded(input_tensor=input_tensor):
            sharded_optimizer.zero_grad()
            sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
            sharded_loss.backward()
            return sharded_loss

        loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
        loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

        assert torch.allclose(
            loss_ddp, loss_sharded_optim
        ), "Losses differ in between Pytorch optim and ZeroRedundancyOptimizer"

        check_same_model_params()

    # The models should stay the same in between the ranks
    for i in range(20):
        check_step()

    # Check that the checkpoints are compatible
    reference_rank = 0
    # - get states
    ddp_state_dict = ddp_optimizer.state_dict()
    sharded_optimizer.consolidate_state_dict(recipient_rank=reference_rank)
    sharded_optim_state_dict = [sharded_optimizer.state_dict() if self.rank == reference_rank else {}]
    dist.broadcast_object_list(sharded_optim_state_dict, src=reference_rank, group=dist.group.WORLD)
    sharded_optim_state_dict = sharded_optim_state_dict[0]

    # - cross load the states
    #   run one step and check that the models are still the same
    ddp_state_dict_ref = copy.deepcopy(ddp_state_dict)  # OSS will remove some states
    ddp_optimizer.load_state_dict(sharded_optim_state_dict)  # mixup on purpose !
    sharded_optimizer.load_state_dict(ddp_state_dict)
    check_step()

    # - self load, rewind, check no problem
    #   run one step and check that the models are still the same
    ddp_optimizer.load_state_dict(ddp_state_dict_ref)
    sharded_optimizer.load_state_dict(sharded_optim_state_dict)
    check_step()