Example #1
0
    def __init__(self, **dataloader_kwargs):
        """Prepare the data on local rank zero, then store shared DataLoader defaults."""
        super().__init__()
        if get_local_rank() == 0:
            self.prepare_data()

        # Every other rank blocks here until rank zero has finished preparing
        # the data (download, preprocessing, ...)
        if dist.is_initialized():
            dist.barrier(device_ids=[get_local_rank()])

        num_workers = dataloader_kwargs.get('num_workers', 0)
        # Defaults first so explicit caller kwargs win on key collisions
        defaults = {'pin_memory': True, 'persistent_workers': num_workers > 0}
        self.dataloader_kwargs = {**defaults, **dataloader_kwargs}
        self.ds_train = self.ds_val = self.ds_test = None
Example #2
0
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int,
               path: pathlib.Path, callbacks: List[BaseCallback]):
    """ Saves model, optimizer and epoch states to path (only once per node) """
    # Guard clause: only local rank zero writes the checkpoint
    if get_local_rank() != 0:
        return

    # Unwrap DDP so the checkpoint can be loaded without a DDP wrapper
    module = model.module if isinstance(model, DistributedDataParallel) else model
    checkpoint = {
        'state_dict': module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
    }
    # Give callbacks a chance to add their own state before writing to disk
    for callback in callbacks:
        callback.on_checkpoint_save(checkpoint)

    torch.save(checkpoint, str(path))
    logging.info(f'Saved checkpoint to {str(path)}')
def evaluate(model: nn.Module, dataloader: DataLoader,
             callbacks: List[BaseCallback], args):
    """Run the model over `dataloader` in eval mode, feeding every prediction
    to the callbacks' `on_validation_step` hooks.

    The progress bar is shown only on local rank zero and when not silenced.
    """
    model.eval()
    # NOTE: renamed `input` -> `inputs` to avoid shadowing the builtin, and
    # dropped the pointless f-prefix on the constant progress-bar description.
    for i, batch in tqdm(enumerate(dataloader),
                         total=len(dataloader),
                         unit='batch',
                         desc='Evaluation',
                         leave=False,
                         disable=(args.silent or get_local_rank() != 0)):
        # Convention in this file: a batch is (*model_inputs, target)
        *inputs, target = to_cuda(batch)

        for callback in callbacks:
            callback.on_batch_start()

        # Mixed precision is opt-in via --amp
        with torch.cuda.amp.autocast(enabled=args.amp):
            pred = model(*inputs)

            for callback in callbacks:
                callback.on_validation_step(inputs, target, pred)
Example #4
0
 def load(self):
     """Load the dataset, then precompute the pairwise bases for every graph."""
     super().load()
     # Runs on a single process; a multi-GPU compute-and-gather would be faster
     collate = lambda samples: dgl.batch([sample[0] for sample in samples])
     dataloader = DataLoader(self,
                             shuffle=False,
                             batch_size=self.batch_size,
                             num_workers=self.num_workers,
                             collate_fn=collate)
     precomputed = []
     for step, batched_graph in tqdm(enumerate(dataloader),
                                     total=len(dataloader),
                                     desc='Precomputing QM9 bases',
                                     disable=get_local_rank() != 0):
         rel_pos = _get_relative_pos(batched_graph)
         # Compute bases on the GPU, then move each tensor to CPU to store in RAM
         basis = get_basis(rel_pos.cuda(), **self.bases_kwargs)
         precomputed.append({key: value.cpu() for key, value in basis.items()})
     # Assign only once fully built so __getitem__ never observes a partial list
     self.bases = precomputed
Example #5
0
def train(model: nn.Module, loss_fn: _Loss, train_dataloader: DataLoader,
          val_dataloader: DataLoader, callbacks: List[BaseCallback],
          logger: Logger, args):
    """Full training loop: optional DDP wrapping, optimizer setup, checkpoint
    resume, per-epoch training, periodic evaluation and checkpointing.

    Fix applied: use the existing `dist` alias for `all_reduce` instead of
    the inconsistent fully-qualified `torch.distributed.all_reduce` call.
    """
    device = torch.cuda.current_device()
    model.to(device=device)
    local_rank = get_local_rank()
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    if dist.is_initialized():
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        output_device=local_rank)
        # Graph topology does not change between iterations; enables DDP
        # optimizations (private API)
        model._set_static_graph()

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(),
                              lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(),
                              lr=args.learning_rate,
                              betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # Resume from a checkpoint if one was provided; otherwise start at epoch 0
    epoch_start = load_state(model, optimizer, args.load_ckpt_path,
                             callbacks) if args.load_ckpt_path else 0

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        # Reshuffle differently every epoch under distributed sampling
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx,
                           grad_scaler, optimizer, local_rank, callbacks, args)
        if dist.is_initialized():
            # Average the per-rank losses so every rank logs the same value
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            dist.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)

        for callback in callbacks:
            callback.on_epoch_end()

        # Periodic checkpointing (disabled in benchmark mode)
        if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
                and (epoch_idx + 1) % args.ckpt_interval == 0:
            save_state(model, optimizer, epoch_idx, args.save_ckpt_path,
                       callbacks)

        # Periodic evaluation, plus always once after the final epoch
        if not args.benchmark and ((args.eval_interval > 0 and
                                    (epoch_idx + 1) % args.eval_interval == 0)
                                   or epoch_idx + 1 == args.epochs):
            evaluate(model, val_dataloader, callbacks, args)
            model.train()  # evaluate() leaves the model in eval mode

            for callback in callbacks:
                callback.on_validation_end(epoch_idx)

    # Final checkpoint after training completes
    if args.save_ckpt_path is not None and not args.benchmark:
        save_state(model, optimizer, args.epochs, args.save_ckpt_path,
                   callbacks)

    for callback in callbacks:
        callback.on_fit_end()
Example #6
0
        save_state(model, optimizer, args.epochs, args.save_ckpt_path,
                   callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    """Log how many trainable (requires_grad) parameters `model` contains."""
    trainable = (param.numel() for param in model.parameters()
                 if param.requires_grad)
    num_params_trainable = sum(trainable)
    logging.info(f'Number of trainable parameters: {num_params_trainable}')


if __name__ == '__main__':
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(
        logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|      Training procedure     |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
    if args.wandb:
            for callback in callbacks:
                callback.on_validation_step(input, target, pred)


if __name__ == '__main__':
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything
    from se3_transformer.model import SE3TransformerPooled, Fiber
    from se3_transformer.data_loading import QM9DataModule
    import torch.distributed as dist
    import logging
    import sys

    is_distributed = init_distributed()
    local_rank = get_local_rank()
    args = PARSER.parse_args()

    logging.getLogger().setLevel(
        logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|  Inference on the test set  |')
    logging.info('===============================')

    if not args.benchmark and args.load_ckpt_path is None:
        logging.error(
            'No load_ckpt_path provided, you need to provide a saved model to evaluate'
        )
        sys.exit(1)