def __init__(self, **dataloader_kwargs):
    super().__init__()
    if get_local_rank() == 0:
        self.prepare_data()

    # Wait until rank zero has prepared the data (download, preprocessing, ...)
    if dist.is_initialized():
        dist.barrier(device_ids=[get_local_rank()])

    self.dataloader_kwargs = {
        'pin_memory': True,
        'persistent_workers': dataloader_kwargs.get('num_workers', 0) > 0,
        **dataloader_kwargs
    }
    self.ds_train, self.ds_val, self.ds_test = None, None, None
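# --- Illustrative sketch, not part of the original module ---
# Shows how a data module like the one above could consume self.dataloader_kwargs when building
# the training DataLoader. The method name, the DistributedSampler usage and the assumption that
# DataLoader/DistributedSampler and torch.distributed (dist) are already imported are all
# assumptions for illustration only.
def train_dataloader(self):
    # Shard the training set across ranks when torch.distributed is initialized (assumed behaviour)
    sampler = DistributedSampler(self.ds_train, shuffle=True) if dist.is_initialized() else None
    return DataLoader(self.ds_train,
                      shuffle=(sampler is None),
                      sampler=sampler,
                      **self.dataloader_kwargs)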
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path,
               callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    if get_local_rank() == 0:
        state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
        checkpoint = {
            'state_dict': state_dict,
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        for callback in callbacks:
            callback.on_checkpoint_save(checkpoint)

        torch.save(checkpoint, str(path))
        logging.info(f'Saved checkpoint to {str(path)}')
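# --- Hedged sketch, not the repository's exact code ---
# A load_state counterpart matching the call site in train() below, i.e.
# load_state(model, optimizer, path, callbacks) -> starting epoch. The map_location mapping and
# the on_checkpoint_load hook name are assumptions made for illustration.
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path,
               callbacks: List[BaseCallback]) -> int:
    map_location = {'cuda:0': f'cuda:{get_local_rank()}'}
    checkpoint = torch.load(str(path), map_location=map_location)

    # Unwrap DDP before loading weights, mirroring how save_state stored them
    if isinstance(model, DistributedDataParallel):
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    for callback in callbacks:
        callback.on_checkpoint_load(checkpoint)

    return checkpoint['epoch']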
def evaluate(model: nn.Module,
             dataloader: DataLoader,
             callbacks: List[BaseCallback],
             args):
    model.eval()
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc='Evaluation',
                         leave=False, disable=(args.silent or get_local_rank() != 0)):
        *input, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*input)

            for callback in callbacks:
                callback.on_validation_step(input, target, pred)
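# --- Hedged sketch, not taken from the repository ---
# A minimal metric callback compatible with the evaluation loop above, accumulating mean absolute
# error across validation batches. The hook names (on_batch_start / on_validation_step /
# on_validation_end) match the calls in this file; the class name and internals are assumptions,
# and BaseCallback is assumed to provide no-op defaults for the remaining hooks.
class MAECallback(BaseCallback):
    def __init__(self):
        self.abs_error_sum = 0.0
        self.num_samples = 0

    def on_batch_start(self):
        pass

    def on_validation_step(self, input, target, pred):
        self.abs_error_sum += (pred.detach() - target).abs().sum().item()
        self.num_samples += target.numel()

    def on_validation_end(self, epoch_idx):
        logging.info(f'MAE: {self.abs_error_sum / max(self.num_samples, 1)}')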
def load(self):
    super().load()

    # Iterate through the dataset and compute bases (pairwise only)
    # Potential improvement: use multi-GPU and gather
    dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
                            collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
    bases = []
    for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
                         disable=get_local_rank() != 0):
        rel_pos = _get_relative_pos(graph)
        # Compute the bases with the GPU but convert the result to CPU to store in RAM
        bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
    self.bases = bases  # Assign at the end so that __getitem__ isn't confused
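# --- Hedged sketch, not the repository's exact code ---
# One plausible implementation of the _get_relative_pos helper used above: per-edge displacement
# vectors computed from the node coordinates of a batched DGL graph. The node feature key 'pos'
# and the src/dst sign convention are assumptions for illustration.
def _get_relative_pos(graph) -> torch.Tensor:
    x = graph.ndata['pos']          # node coordinates, shape (num_nodes, 3)
    src, dst = graph.edges()        # edge endpoints
    return x[dst] - x[src]          # one displacement vector per edge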
def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
        model._set_static_graph()

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank,
                           callbacks, args)
        if dist.is_initialized():
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            torch.distributed.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)

        if not args.benchmark and ((args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0)
                                   or epoch_idx + 1 == args.epochs):
            evaluate(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()
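# --- Hedged sketch: the real train_epoch is not shown in this excerpt ---
# One possible implementation matching the call above: a single mixed-precision training epoch
# returning the mean loss. Argument names mirror the call site; gradient accumulation/clipping are
# omitted and everything else (tqdm settings, averaging) is an assumption for illustration.
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
    losses = []
    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        *input, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*input)
            loss = loss_fn(pred, target)

        # Scale the loss for AMP, then step and update the scaler
        grad_scaler.scale(loss).backward()
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

        losses.append(loss.item())

    return sum(losses) / len(losses)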
def print_parameters_count(model):
    num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')


if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|      Training procedure     |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
    if args.wandb:
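# --- Hedged sketch, not the repository's code ---
# One way the seed_everything helper used above could be implemented to make runs reproducible
# across Python, NumPy and PyTorch; the exact scope of the real helper is an assumption.
import random
import numpy as np

def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)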
if __name__ == '__main__':
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything
    from se3_transformer.model import SE3TransformerPooled, Fiber
    from se3_transformer.data_loading import QM9DataModule
    import torch.distributed as dist
    import logging
    import sys

    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|  Inference on the test set  |')
    logging.info('===============================')

    if not args.benchmark and args.load_ckpt_path is None:
        logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
        sys.exit(1)
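# --- Hedged sketch, not the repository's code ---
# A plausible init_distributed() consistent with how it is used above (returns whether distributed
# mode is active). Reading WORLD_SIZE from the environment, the 'env://' init method and the
# NCCL/Gloo backend choice are assumptions for illustration.
import os

def init_distributed() -> bool:
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    distributed = world_size > 1
    if distributed:
        backend = 'nccl' if torch.cuda.is_available() else 'gloo'
        dist.init_process_group(backend=backend, init_method='env://')
        if torch.cuda.is_available():
            torch.cuda.set_device(get_local_rank())
    return distributed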