def test_state_dict(self):
    """Check that the ZeroRedundancyOptimizer exposes the expected
    state dict interface, irrespective of the sharding.

    Runs one SGD+momentum step on a single scalar parameter, consolidates
    and pulls the state dict, then reloads it into a fresh optimizer and
    verifies that hyperparameters, momentum buffers, and devices all
    round-trip correctly.
    """
    self.dist_init(self.rank)
    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.1, momentum=0.9)
    x.backward()
    o.step()
    # After one step: x = 1.0 - 0.1 * 1.0 = 0.9, momentum buffer holds the grad
    self.assertEqual(x, torch.tensor([0.9], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    o.zero_grad()
    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
    state_dict = o.state_dict()

    # Check that the state dict is pytorch-compliant key wise
    self.assertIn("param_groups", state_dict.keys())
    self.assertIn("state", state_dict.keys())

    # Check that the pulled state is what we expect, and that we have all the expected keys
    self.assertEqual(state_dict["param_groups"][0]["lr"], 0.1)
    self.assertEqual(state_dict["param_groups"][0]["momentum"], 0.9)
    self.assertFalse(state_dict["param_groups"][0]["nesterov"])
    self.assertEqual(state_dict["param_groups"][0]["weight_decay"], 0.0)
    self.assertEqual(state_dict["param_groups"][0]["dampening"], 0.0)

    # Check that the pulled state and the .param_groups attribute are in sync
    for k in state_dict["param_groups"][0].keys():
        if k != "params":
            self.assertEqual(state_dict["param_groups"][0][k], o.param_groups[0][k])

    # Check that it's correctly loaded: the new optimizer starts with a
    # different lr (0.01), which load_state_dict must overwrite with 0.1
    o = ZeroRedundancyOptimizer([x], optimizer_class=SGD, lr=0.01)
    o.load_state_dict(state_dict)

    # Check that state is correct and on proper device
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.0], device=DEVICE))

    # We should now be using a lr of 0.1, both within the optimizer
    # and as exposed by the .param_groups attribute
    assert o.param_groups[0]["lr"] == 0.1
    x.backward()
    o.step()
    # Second step with momentum: buffer = 0.9 * 1.0 + 1.0 = 1.9,
    # x = 0.9 - 0.1 * 1.9 = 0.71
    self.assertEqual(x, torch.tensor([0.71], device=DEVICE))
    self.assertEqual(o.optim.state[x]["momentum_buffer"], torch.tensor([1.9], device=DEVICE))

    # Check that the exposed param_groups are on the proper device
    self.assertEqual(o.param_groups[0]["params"][0].device, x.device)
def test_zero_grad(self):
    """Check that the zero_grad attribute is properly handled.

    Runs a backward pass through a tiny linear model wrapped in a
    ZeroRedundancyOptimizer, verifies both parameters received non-zero
    gradients, then checks that ``zero_grad()`` clears them.
    """
    self.dist_init(self.rank)
    x = torch.rand(1)
    m = torch.nn.Linear(1, 1)
    o = ZeroRedundancyOptimizer(m.parameters(), optimizer_class=SGD, lr=0.1)
    y = m(x)
    y.backward(x)
    # Both parameters should have non-zero gradients after backward.
    # BUGFIX: the original asserted on ``m.weight.grad`` twice, leaving
    # ``m.bias.grad`` unchecked before the zero_grad() call.
    self.assertNotEqual(m.weight.grad, torch.zeros_like(m.weight))
    self.assertNotEqual(m.bias.grad, torch.zeros_like(m.bias))
    o.zero_grad()
    # zero_grad() resets .grad to None (falsy), not to a zero tensor
    self.assertFalse(m.weight.grad)
    self.assertFalse(m.bias.grad)
def test_lr_scheduler(self):
    """Check that a normal torch lr_scheduler is usable with ZeroRedundancyOptimizer.

    Drives a sharded optimizer and a plain torch.optim.SGD in lockstep,
    each under its own StepLR schedule, and asserts the two parameters
    stay equal after every iteration.
    """
    self.dist_init(self.rank)
    sharded_param = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    local_param = torch.tensor([1.0], device=DEVICE, requires_grad=True)
    sharded_optim = ZeroRedundancyOptimizer([sharded_param], optimizer_class=SGD, lr=0.01)
    local_optim = torch.optim.SGD([local_param], lr=0.01)
    sharded_sched = torch.optim.lr_scheduler.StepLR(sharded_optim, 1)
    local_sched = torch.optim.lr_scheduler.StepLR(local_optim, 1)
    for _ in range(5):
        # Identical update sequence on both sides, then compare
        sharded_param.backward()
        sharded_optim.zero_grad()
        sharded_optim.step()
        sharded_sched.step()
        local_param.backward()
        local_optim.zero_grad()
        local_optim.step()
        local_sched.step()
        self.assertEqual(sharded_param, local_param)
def _test_ddp_zero_overlap(self, device, hook_constructor, gradient_as_bucket_view, static_graph):
    """Compare DDP + ZeRO overlapped with the backward pass against DDP
    with a plain local SGD optimizer, and assert the parameters match.

    Args:
        device: torch.device to run on (CPU or a CUDA device).
        hook_constructor: callable building the DDP comm hook that fuses
            the ZeRO optimizer step into allreduce.
        gradient_as_bucket_view: forwarded to both DDP constructors.
        static_graph: whether to mark the DDP graphs as static (affects
            the number of warmup iterations required).
    """
    SGD_LR = 0.01
    SGD_MOMENTUM = 0.9
    SGD_WEIGHT_DECAY = 0.001
    NUM_INPUTS = 5
    # Seed both CPU and CUDA RNGs so all ranks build identical models/inputs
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    rank = self.rank
    is_gpu = device.type == "cuda"
    if BACKEND == dist.Backend.NCCL and is_gpu:
        torch.cuda.set_device(device)
    models_to_test = [
        (torch.nn.Sequential(torch.nn.Linear(1000, 2000), torch.nn.Linear(2000, 500)),
         [torch.randn(1, 1000).to(device) for _ in range(NUM_INPUTS)]),
    ]
    if HAS_TORCHVISION:
        models_to_test.append((torchvision.models.resnet50(), [
            torch.randn(1, 3, 3, 1000).to(device) for _ in range(NUM_INPUTS)
        ]))
    for (model, inputs) in models_to_test:
        # Enable determinism in cudnn operators
        with torch.backends.cudnn.flags(enabled=True, deterministic=True, benchmark=False):
            device_ids = [rank] if is_gpu else None

            # Set up the DDP model overlapping with ZeRO
            ddp_model_overlap = DDP(
                copy.deepcopy(model).to(device),
                device_ids=device_ids,
                gradient_as_bucket_view=gradient_as_bucket_view)
            if static_graph:
                ddp_model_overlap._set_static_graph()
            # overlap_with_ddp=True defers the optimizer work into the comm
            # hook; _allow_empty_param_list because some ranks may own no shard
            zero_optim = ZeroRedundancyOptimizer(
                ddp_model_overlap.parameters(),
                optimizer_class=torch.optim.SGD,
                overlap_with_ddp=True,
                lr=SGD_LR,
                momentum=SGD_MOMENTUM,
                weight_decay=SGD_WEIGHT_DECAY,
                _allow_empty_param_list=True)
            ddp_model_overlap.register_comm_hook(
                None,
                hook_constructor(allreduce_hook, ddp_model_overlap, zero_optim))

            # Set up the DDP model with local optimizer
            ddp_model_local = DDP(
                copy.deepcopy(model).to(device),
                device_ids=device_ids,
                gradient_as_bucket_view=gradient_as_bucket_view)
            if static_graph:
                ddp_model_local._set_static_graph()
            local_optim = torch.optim.SGD(
                ddp_model_local.parameters(),
                lr=SGD_LR,
                momentum=SGD_MOMENTUM,
                weight_decay=SGD_WEIGHT_DECAY)

            # Check that the parameters match initially
            for p1, p2 in zip(ddp_model_overlap.parameters(), ddp_model_local.parameters()):
                self.assertEqual(p1, p2)

            # Save the parameters to ensure they were updated
            init_params_overlap = copy.deepcopy(list(ddp_model_overlap.parameters()))

            # Ensure that this test runs independently
            dist.barrier()

            # Run the DDP model overlapping with ZeRO
            # NOTE: Overlapping currently requires 2 or 3 warmup iterations
            # to ensure DDP buckets have been rebuilt (depending on the
            # value of `static_graph`)
            num_warmup_inputs = 2 if not static_graph else 3
            for input in inputs[:num_warmup_inputs]:
                output = ddp_model_overlap(input)
                loss = output.sum()
                loss.backward()
                zero_optim.step()
            for input in inputs:
                zero_optim.zero_grad()
                output = ddp_model_overlap(input)
                loss = output.sum()
                loss.backward()
                zero_optim.step()

            # Run the DDP model with local optimizer
            for input in inputs:
                local_optim.zero_grad()
                output = ddp_model_local(input)
                loss = output.sum()
                loss.backward()
                local_optim.step()
            dist.barrier()

            # Check that the parameters are equal
            for p1, p2 in zip(ddp_model_overlap.parameters(), ddp_model_local.parameters()):
                self.assertEqual(p1, p2)

            # Check that the parameters were updated
            self.assertNotEqual(init_params_overlap, list(ddp_model_overlap.parameters()))

            # Ensure that this test runs independently
            dist.barrier()
def run_process(): '''Run process This is what is actually run on each process. ''' # Get distributed parameters rank = dist.get_rank() local_rank = int(os.environ["LOCAL_RANK"]) world_size = dist.get_world_size() # Initialize data_loader context_size = 512 batch_size = 32 corpus_length = 1024 vocab_size = 2**8 dataset = RandomCorpus(corpus_length, context_size, vocab_size) sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) data_loader = DataLoader( dataset=dataset, batch_size=batch_size, sampler=sampler, ) # Initialize model model = GPT(vocab_size, context_size, verbose=True) device = torch.device(f"cuda:{local_rank}") model.to(device) # Prepare for distributed data parallelism model = DistributedDataParallel(model, device_ids=[rank], output_device=rank) # The learning rate is adapted for the total batch_size in tokens learning_rate = 6e-4 * (batch_size * world_size * context_size / 5e5) # ZeroRedundancyOptimizer reduces the memory footprint of the Optimizer opt = ZeroRedundancyOptimizer( model.parameters(), optimizer_class=optim.Adam, lr=learning_rate, ) loss_func = nn.CrossEntropyLoss() # Initialize logger instance to see performance writer = BenchmarkWriter() # Actual training global_step = 0 n_epochs = 10 for epoch in range(n_epochs): model.train() sampler.set_epoch(epoch) # for correct shuffling for sequence, in data_loader: opt.zero_grad() # Shift so that prediction is next token for each token sequence = sequence.to(device) logits = model(sequence[..., :-1].contiguous()) target = sequence[..., 1:].contiguous() # Flatten the tokens when calculating loss loss = loss_func( logits.flatten(end_dim=-2), target.flatten(), ) loss.backward() opt.step() # This will also log the wall time if rank == 0: global_step += batch_size * world_size writer.add_scalar("Loss", loss.item(), global_step=global_step) if rank == 0: print("Epoch:", epoch) if rank == 0: writer.benchmark_results(burn_in_steps=2 * corpus_length, step_unit="seq") writer.close() return 
model
def train(
    rank: int,
    world_size: int,
    lr: float = 5e-4,
    batch_size: int = 1000,
    epochs: int = 500,
    interval: int = 10,
    save: int = 100,
    num_workers: int = 4,
    num_basis: int = 100,
    dataset: str = 'datasets',
    load_dir: Optional[str] = None,
    load_epoch: Optional[int] = None,
    coefficient_noise: bool = False,
    verbose: bool = False,
    use_zero: bool = False,
):
    """Distributed (one process per GPU) training loop for a normalizing
    flow over gravitational-wave basis coefficients.

    Args:
        rank: GPU index / process rank within the NCCL group.
        world_size: total number of GPUs.
        lr: Adam learning rate.
        batch_size: per-rank mini-batch size.
        epochs: number of training epochs.
        interval: validate (and sample posteriors) every `interval` epochs;
            0 disables validation.
        save: checkpoint every `save` epochs; 0 disables checkpointing.
        num_workers: DataLoader worker processes per rank.
        num_basis: number of reduced-basis coefficients used as context.
        dataset: name of the dataset directory on disk.
        load_dir / load_epoch: optionally resume weights/optimizer/scheduler
            from a previous run.
        coefficient_noise: whether validation data adds coefficient noise.
        verbose: show the tqdm progress bar on rank 0.
        use_zero: shard optimizer state with ZeroRedundancyOptimizer.
    """
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (
        interval <= epochs
    ), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (
        save <= epochs), "Save must be a non-negative integer between 0 and epochs."

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    if rank == 0:
        print(f"Loading data from {dataset}...")
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    val_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/validation/')
    noise_dir = Path('/mnt/datahole/daniel/gwosc/O1')
    psd_dir = Path(f"/mnt/datahole/daniel/gravflows/{dataset}/train/PSD/")
    basis_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/basis/')

    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"
    save_dir = Path('gwpe/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = str(data_dir / 'config_files/static_args.ini')

    # validation
    # training data
    # dataset = BasisCoefficientsDataset(
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=waveform_params_ini,
    # )

    # dataset = BasisEncoderDataset(
    #     n=num_basis,
    #     data_dir=data_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     intrinsics_ini=waveform_params_ini,
    #     extrinsics_ini=extrinsics_ini,
    #     psd_dir=psd_dir,
    #     ifos=['H1','L1'],
    #     ref_ifo='H1',
    #     downcast=True,
    #     add_noise=True,
    #     coefficient_noise=coefficient_noise,
    # )

    dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        distance_scale=True,
        time_shift=False,
    )
    # seed=rank gives each rank a distinct sampler stream
    sampler = DistributedSampler(
        dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )
    dataloader = DataLoader(
        dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=2,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # validation data
    val_dataset = LFIGWDataset(
        n=100,
        data_dir=data_dir,
        basis_dir=basis_dir,
        static_args_ini=static_args_ini,
        data_file='coefficients.npy',
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        psd_dir=psd_dir,
        ifos=['H1', 'L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )

    # val_dataset = BasisCoefficientsDataset(
    #     data_dir=val_dir,
    #     basis_dir=basis_dir,
    #     static_args_ini=static_args_ini,
    #     parameters_ini=[waveform_params_ini, extrinsics_ini],
    #     coefficient_noise=coefficient_noise,
    # )

    val_sampler = DistributedSampler(
        val_dataset,
        shuffle=False,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )
    val_loader = DataLoader(
        val_dataset,
        shuffle=False,
        num_workers=num_workers,
        batch_size=batch_size,
        sampler=val_sampler,
        pin_memory=True,
        prefetch_factor=4,
        worker_init_fn=val_dataset._worker_init_fn,
        collate_fn=val_dataset._collate_fn,
    )

    # validation data
    if interval != 0:
        # specify indices in validation dataset to validate samples
        min_idx = val_dataset.parameters.distance.argmin()
        max_idx = val_dataset.parameters.distance.argmax()
        # nearest-median distance row (quantile defaults to q=0.5)
        median_idx = val_dataset.parameters.loc[
            val_dataset.parameters.distance
            == val_dataset.parameters.distance.quantile(interpolation='nearest')].index[0]

        if rank == 0:
            figure_titles = [
                'GW150914', 'Min Distance', f'Median Distance', f'Max Distance'
            ]
            # validation ground truths for posterior sampling
            val_gts = torch.stack([
                torch.zeros(len(val_dataset.parameters.columns),
                            dtype=torch.float32),  # gw150914 dummy gt
                torch.tensor(val_dataset.parameters.iloc[min_idx].values,
                             dtype=torch.float32),  # rank 1
                torch.tensor(val_dataset.parameters.iloc[median_idx].values,
                             dtype=torch.float32),  # rank 2
                torch.tensor(val_dataset.parameters.iloc[max_idx].values,
                             dtype=torch.float32),  # rank 3
            ])

        # runs on every rank: each rank extracts its own validation context
        with torch.no_grad():
            # load data from file manually (rather than using val_dataset._worker_init_fn)
            val_coefficients = np.load(val_dataset.data_dir / val_dataset.data_file,
                                       mmap_mode='c')
            # generate coefficients on cpu - we want to send this to tensorboard (rank 0) before sending to gpus
            val_coefficients = torch.cat([
                torch.from_numpy(
                    generate_gw150914_context(num_basis, noise_dir, psd_dir,
                                              basis_dir, static_args_ini))[None],
                torch.tensor(val_coefficients[[min_idx, median_idx, max_idx]]),
            ], dim=0).to(dtype=torch.complex64)

            # place one of each stacked tensor onto corresponding gpu rank
            # NOTE(review): val_coefficients has exactly 4 rows, so indexing by
            # rank assumes world_size <= 4 — confirm for larger worlds.
            val_context = val_coefficients[rank] * val_dataset.standardization[:, :num_basis]
            val_context = val_context.to(device=rank)
            val_context = torch.cat([val_context.real, val_context.imag], dim=0)
            val_context = val_context.reshape(val_context.shape[0] * val_context.shape[1])[None]
    else:
        figure_titles = None
        val_gts = None
        val_coefficients = None

    # set torch profiling runs
    # wait = 1  # ignore first batch
    # warmup = 1
    # active = 4
    # repeat = 2

    # tensorboard: a separate process consumes (epoch, scalars, samples)
    # tuples from `queue` and writes figures/scalars asynchronously
    if rank == 0:
        # tb = SummaryWriter(f'gwpe/runs/{log_dir}')
        queue = mp.SimpleQueue()
        tb_process = mp.Process(target=tensorboard_writer, args=(
            queue,
            f'gwpe/runs/{log_dir}',
            val_dataset.generator.parameters,
            val_dataset.generator.latex,
            static_args_ini,
            basis_dir,
            num_basis,
            val_coefficients,
            val_gts,
            figure_titles,
        ))
        tb_process.start()

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time
        context_dim=4 * num_basis,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
        })
    flow = flow.to(rank)
    print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)
    print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        # https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
        # optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # optionally resume a previous run
    if load_dir is not None and load_epoch is not None:
        print(f'Loading model from {load_dir} at epoch {load_epoch}.')
        flow.module.load_state_dict(
            torch.load(f'gwpe/model_weights/{load_dir}/flow_{load_epoch}.pt',
                       map_location=rank))
        optimizer.load_state_dict(
            torch.load(
                f'gwpe/model_weights/{load_dir}/optimizer_{load_epoch}.pt',
                map_location=rank))
        if Path(f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt'
                ).is_file():
            scheduler.load_state_dict(
                torch.load(
                    f'gwpe/model_weights/{load_dir}/scheduler_{load_epoch}.pt',
                    map_location=rank))

    # run training loop
    flow.train()
    train_loss = torch.zeros((1, ), device=rank, requires_grad=False)
    val_loss = torch.zeros((1, ), device=rank, requires_grad=False)

    disable_pbar = False if verbose and (rank == 0) else True  # tqdm progress bar
    with tqdm(total=len(dataloader) * epochs,
              disable=disable_pbar,
              desc=f'[{log_dir}] Training',
              postfix={'epoch': 0}) as progress:
        # with torch.profiler.profile(
        #     activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        #     schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat),
        #     on_trace_ready=torch.profiler.tensorboard_trace_handler(f'gwpe/runs/{log_dir}'),
        #     record_shapes=True,
        #     with_stack=True
        # ) as profiler:

        for epoch in range(1, 1 + epochs):
            if rank == 0:
                progress.set_postfix({'epoch': epoch})
                progress.set_description(f'[{log_dir}] Training', refresh=True)

            # let all processes sync up before starting with a new epoch of training
            flow.train()
            distributed.barrier()

            # pre-fetch the first batch so the next batch's host->device copy
            # can overlap with the model forward below
            iterator = iter(dataloader)
            coefficients, parameters = next(iterator)
            coefficients = coefficients.to(rank, non_blocking=True)
            parameters = parameters.to(rank, non_blocking=True)

            complete = False
            while not complete:
                optimizer.zero_grad()

                # if profile:  https://github.com/guyang3532/kineto/blob/readme/tb_plugin/docs/gpu_utilization.md
                ## WARNING: profiler may not handle async pinned memory transfer properly?
                # i.e. may not record CPU vs GPU wall times correctly
                # may be related to reported blocks per SM/achieved occupancy negative bug
                # this was an open issue for pytorch 1.9 as of july 9 - nightly may fix it
                # https://github.com/pytorch/kineto/issues/325#issuecomment-869362218
                # if (step >= (wait + warmup + active) * repeat):
                #     break

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()

                try:
                    # async get data from CPU and move to GPU during model forward
                    coefficients, parameters = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)
                except StopIteration:
                    # exit while loop if iterator is complete
                    complete = True

                loss.backward()
                print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                print_peak_memory("Max memory allocated after optimizer step()", rank)

                # if profile: profiler.step()

                # total loss summed over each sample in batch
                # NOTE(review): `coefficients` was already advanced to the NEXT
                # batch here, so the scaling uses the next batch's size — verify
                # this weighting is intended.
                train_loss += loss.detach() * coefficients.shape[0]
                if rank == 0:
                    progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [
                torch.ones_like(train_loss) for _ in range(world_size)
            ]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if (interval != 0) and (epoch % interval == 0):
                # evaluate model on validation dataset
                flow.eval()
                with torch.no_grad():
                    iterator = iter(enumerate(val_loader))
                    step, (coefficients, parameters) = next(iterator)
                    coefficients = coefficients.to(rank, non_blocking=True)
                    parameters = parameters.to(rank, non_blocking=True)
                    if rank == 0:
                        val_progress = int(100 * step / len(val_loader))
                        progress.set_description(
                            f'[{log_dir}] Validating ({val_progress}%)',
                            refresh=True)

                    complete = False
                    while not complete:
                        # negative log-likelihood conditional on strain over mini-batch
                        loss = -flow.module.log_prob(
                            parameters, context=coefficients).mean()

                        try:
                            # async get data from CPU and move to GPU during model forward
                            step, (coefficients, parameters) = next(iterator)
                            coefficients = coefficients.to(rank, non_blocking=True)
                            parameters = parameters.to(rank, non_blocking=True)
                            if rank == 0:
                                val_progress = int(100 * step / len(val_loader))
                                progress.set_description(
                                    f'[{log_dir}] Validating ({val_progress}%)',
                                    refresh=True)
                        except StopIteration:
                            # exit while loop if iterator is complete
                            complete = True

                        # total loss summed over each sample in batch
                        val_loss += loss.detach() * coefficients.shape[0]

                    # gather total loss during epoch between each GPU worker as list of tensors
                    world_val_loss = [
                        torch.ones_like(val_loss) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_val_loss, val_loss)
                    val_loss *= 0.0  # reset loss for next epoch

                    # validation posteriors
                    if rank == 0:
                        progress.set_description(
                            f'[{log_dir}] Sampling posteriors', refresh=True)
                    # every rank samples from its own val_context, then gathers
                    samples = flows.sample_flow(
                        flow.module,
                        n=10000,
                        context=val_context,
                        output_device='cuda',
                        dtype=torch.float32,
                    )[0]

                    # gather samples from all gpus
                    world_samples = [
                        torch.ones_like(samples) for _ in range(world_size)
                    ]
                    distributed.all_gather(world_samples, samples)

            if (rank == 0):
                progress.set_description(f'[{log_dir}] Sending to TensorBoard',
                                         refresh=True)
                scalars = {
                    'loss/train':
                    torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                }

                # every "interval" we generate samples for vis, else None
                corner_samples = None  # reset to None for epochs where there is no corner plot
                if (interval != 0) and (epoch % interval == 0):
                    scalars['loss/validation'] = torch.cat(
                        world_val_loss).sum().item() / len(val_loader.dataset)

                    # convert gw150914 samples to cpu and undo standardization
                    corner_samples = torch.stack(world_samples).cpu()
                    corner_samples *= torch.from_numpy(val_dataset.std)
                    corner_samples += torch.from_numpy(val_dataset.mean)

                # send data to async process to generate matplotlib figures
                queue.put((epoch, scalars, corner_samples))

                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    torch.save(flow.module.state_dict(),
                               experiment_dir / f'flow_{epoch}.pt')

                    # if use_zero:
                    #     # needs to be called on all ranks
                    #     optimizer.consolidate_state_dict(to=0)
                    # NOTE(review): with use_zero=True, calling state_dict()
                    # without a prior all-rank consolidate_state_dict() raises
                    # in ZeroRedundancyOptimizer — confirm checkpointing path.
                    torch.save(optimizer.state_dict(),
                               experiment_dir / f'optimizer_{epoch}.pt')

                    if scheduler is not None:
                        torch.save(scheduler.state_dict(),
                                   experiment_dir / f'scheduler_{epoch}.pt')

        # destroy processes from distributed training
        if rank == 0:
            # to do - graceful way to shutdown workers
            # need to send message back from child process
            sleep_time = 120
            for i in range(sleep_time):
                progress.set_description(
                    f'[{log_dir}] Shutting down in {sleep_time - i}s',
                    refresh=True)
                time.sleep(1)
            tb_process.terminate()

    cleanup_nccl()
def train_distributed(
    rank: int,
    world_size: int,
    lr: float=2e-4,
    batch_size: int=1000,
    epochs: int=500,
    interval: int=10,
    save: int=0,
    num_workers: int=4,
    num_basis: int=100,
    dataset: str='datasets',
    coefficient_noise: bool=False,
    use_zero: bool=False,
    verbose: bool=False,
):
    """Distributed (one process per GPU) training loop for the lfigw
    variant: trains a normalizing flow on lfigw waveform coefficients,
    logging the epoch loss to TensorBoard on rank 0.

    Args mirror `train`, minus validation/resume options; `save=0`
    disables checkpointing by default.
    """
    assert 0 < batch_size, "batch_size must be a positive integer."
    assert 0 < epochs, "epochs must be a positive integer."
    assert (0 <= interval) and (interval <= epochs), "Interval must be a non-negative integer between 0 and epochs."
    assert (0 <= save) and (save <= epochs), "Save must be a non-negative integer between 0 and epochs."

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    # setup data distributed parallel training
    setup_nccl(rank, world_size)  # world size is total gpus
    torch.cuda.set_device(rank)  # rank is gpu index

    # directories
    data_dir = Path(f'/mnt/datahole/daniel/gravflows/{dataset}/train/')
    log_dir = f"{datetime.now().strftime('%b%d_%H-%M-%S')}_{os.uname().nodename}"
    save_dir = Path('lfigw/model_weights/')
    experiment_dir = save_dir / log_dir
    experiment_dir.mkdir(parents=True, exist_ok=True)

    # config files
    waveform_params_ini = str(data_dir / 'config_files/parameters.ini')
    extrinsics_ini = 'gwpe/config_files/extrinsics.ini'
    static_args_ini = 'gwpe/config_files/static_args.ini'

    # tensorboard
    if rank == 0:
        tb = SummaryWriter(f'lfigw/runs/{log_dir}')

    # training data (shadows the `dataset` name argument from here on)
    dataset = lfigwWaveformDataset(
        n=num_basis,
        data_dir='lfigw/data/train',
        basis_dir='lfigw/data/basis/',
        psd_dir='data/events/GW150914',
        data_file='coefficients.npy',
        static_args_ini=static_args_ini,
        intrinsics_ini=waveform_params_ini,
        extrinsics_ini=extrinsics_ini,
        ifos=['H1','L1'],
        ref_ifo='H1',
        downcast=True,
        add_noise=True,
        coefficient_noise=coefficient_noise,
        distance_scale=True,
        time_shift=False,
    )
    sampler = DistributedSampler(
        dataset,
        shuffle=True,
        num_replicas=world_size,
        rank=rank,
        seed=rank,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        pin_memory=True,
        persistent_workers=True,
        sampler=sampler,
        num_workers=num_workers,
        prefetch_factor=4,
        worker_init_fn=dataset._worker_init_fn,
        collate_fn=dataset._collate_fn,
    )

    # instantiate neural spline coupling flow
    flow = flows.create_NDE_model(
        input_dim=14,  # we do not predict coalescence time
        context_dim=4*100,
        num_flow_steps=15,
        base_transform_kwargs={
            'base_transform_type': 'rq-coupling',
            'batch_norm': True,
            'num_transform_blocks': 10,
            'activation': 'elu',
            'hidden_dim': 512,  # default
            'dropout_probability': 0.0,  # default
            'num_bins': 8,  # default
            'tail_bound': 1.0,  # default
            'apply_unconditional_transform': False,  # default
        }
    )
    flow = flow.to(rank)
    # print_peak_memory("Max memory allocated after creating local model", rank)

    # sync_bn_flow = nn.SyncBatchNorm.convert_sync_batchnorm(flow)
    flow = DDP(flow, device_ids=[rank], output_device=rank)
    # print_peak_memory("Max memory allocated after creating DDP", rank)

    if use_zero:
        # https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html
        from torch.distributed.optim import ZeroRedundancyOptimizer
        optimizer = ZeroRedundancyOptimizer(
            flow.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=lr,
            parameters_as_bucket_view=True,
        )
    else:
        optimizer = torch.optim.Adam(flow.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    disable_pbar = False if verbose and (rank == 0) else True  # tqdm progress bar
    with tqdm(
        total=len(dataloader)*epochs,
        disable=disable_pbar,
        ncols=150,
        postfix={'epoch': 0},
        desc=f'[{log_dir}] Training'
    ) as progress:
        # run training loop
        train_loss = torch.zeros((1,), device=rank, requires_grad=False)
        for epoch in range(1, 1+epochs):
            flow.train()
            # sync all ranks before each epoch
            distributed.barrier()

            for coefficients, parameters in dataloader:
                if rank == 0:
                    progress.set_postfix({'epoch': epoch})
                    progress.set_description(f'[{log_dir}] Training', refresh=True)

                optimizer.zero_grad()
                coefficients = coefficients.to(rank, non_blocking=True)
                parameters = parameters.to(rank, non_blocking=True)

                # negative log-likelihood conditional on strain over mini-batch
                loss = -flow.module.log_prob(parameters, context=coefficients).mean()
                loss.backward()
                # print_peak_memory("Max memory allocated before optimizer step()", rank)
                optimizer.step()
                # print_peak_memory("Max memory allocated after optimizer step()", rank)

                # total loss summed over each sample in batch (scaled to lfigw)
                train_loss += loss.detach() * coefficients.shape[0] * (15/14)
                progress.update(1)

            scheduler.step()

            # gather total loss during epoch between each GPU worker as list of tensors
            world_loss = [torch.ones_like(train_loss) for _ in range(world_size)]
            distributed.all_gather(world_loss, train_loss)
            train_loss *= 0.0  # reset loss for next epoch

            if rank == 0:
                epoch_loss = torch.cat(world_loss).sum().item() / len(dataloader.dataset)
                tb.add_scalar('loss/train', epoch_loss, epoch)

                # if (interval != 0) and (epoch % interval == 0):
                if (save != 0) and (epoch % save == 0):
                    # save checkpoint and write computationally expensive data to tb
                    # NOTE(review): with use_zero=True, ZeroRedundancyOptimizer
                    # requires consolidate_state_dict() on all ranks before
                    # state_dict() can be called — confirm this save path.
                    torch.save(flow.module.state_dict(), experiment_dir / f'flow_{epoch}.pt')
                    torch.save(optimizer.state_dict(), experiment_dir / f'optimizer_{epoch}.pt')
                    torch.save(scheduler.state_dict(), experiment_dir / f'scheduler_{epoch}.pt')

    cleanup_nccl()