def test_model():
    """Evaluate a model on the validation set."""
    # Set up the distributed process group and per-process device
    utils.setup_distributed()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    utils.setup_logger(rank, local_rank)

    # Prefer a project-local architecture; fall back to timm if the arch
    # name is unknown to models.build_model
    try:
        net = models.build_model(
            arch=cfg.MODEL.ARCH,
            pretrained=cfg.MODEL.PRETRAINED,
            num_classes=cfg.MODEL.NUM_CLASSES,
        )
    except KeyError:
        net = timm.create_model(
            model_name=cfg.MODEL.ARCH,
            pretrained=cfg.MODEL.PRETRAINED,
            num_classes=cfg.MODEL.NUM_CLASSES,
        )
    net = net.to(device)
    # DistributedDataParallel wrapper
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    val_loader = utils.construct_val_loader()
    criterion = nn.CrossEntropyLoss().to(device)

    # Load evaluation weights if provided
    if cfg.MODEL.WEIGHTS:
        utils.load_checkpoint(cfg.MODEL.WEIGHTS, net)

    acc1, acck = validate(val_loader, net, criterion)
    if rank == 0:
        logger.info(
            f"ACCURACY: TOP1 {acc1:.3f} | TOP{cfg.TRAIN.TOPK} {acck:.3f}")
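
# NOTE: `validate` is called above and in `train_model` below but is not
# defined in this section. The function below is a minimal illustrative
# sketch, NOT the repository's actual implementation, assuming it returns
# averaged (top-1, top-k) accuracy in percent and that cfg.TRAIN.TOPK names
# the k for the second metric. `criterion` is kept in the signature to match
# the call sites; the loss computation is omitted for brevity.
@torch.no_grad()
def validate(val_loader, net, criterion):
    net.eval()
    device = next(net.parameters()).device
    top1 = topk = total = 0
    for images, targets in val_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = net(images)
        # Top-k class indices per sample, shape (batch, k), descending score
        _, preds = logits.topk(cfg.TRAIN.TOPK, dim=1)
        correct = preds.eq(targets.view(-1, 1))
        top1 += correct[:, 0].sum().item()
        topk += correct.any(dim=1).sum().item()
        total += targets.size(0)
    # Under DDP each rank sees only its shard of the val set, so sum the
    # per-rank counts before computing the global accuracy
    if torch.distributed.is_initialized():
        stats = torch.tensor([top1, topk, total],
                             dtype=torch.float64, device=device)
        torch.distributed.all_reduce(stats)
        top1, topk, total = stats.tolist()
    return 100.0 * top1 / total, 100.0 * topk / total
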
def train_model():
    """Train a model."""
    # Set up distributed device
    utils.setup_distributed()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    utils.setup_seed(rank)
    utils.setup_logger(rank, local_rank)

    net = models.build_model(arch=cfg.MODEL.ARCH, pretrained=cfg.MODEL.PRETRAINED)
    # SyncBN (https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net) if cfg.MODEL.SYNCBN else net
    net = net.to(device)
    # DistributedDataParallel wrapper
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    train_loader = utils.construct_train_loader()
    val_loader = utils.construct_val_loader()
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = utils.construct_optimizer(net)

    # Resume from a specific checkpoint or the last checkpoint
    best_acc1 = start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and utils.has_checkpoint():
        file = utils.get_last_checkpoint()
        start_epoch, best_acc1 = utils.load_checkpoint(file, net, optimizer)
    elif cfg.MODEL.WEIGHTS:
        load_opt = optimizer if cfg.TRAIN.LOAD_OPT else None
        start_epoch, best_acc1 = utils.load_checkpoint(cfg.MODEL.WEIGHTS, net, load_opt)

    if rank == 0:
        # from torch.utils.collect_env import get_pretty_env_info
        # logger.debug(get_pretty_env_info())
        # logger.debug(net)
        logger.info("\n\n\n ======= TRAINING ======= \n\n")

    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train one epoch
        train_epoch(train_loader, net, criterion, optimizer, epoch)
        # Validate
        acc1, acck = validate(val_loader, net, criterion)
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        # Save model
        checkpoint_file = utils.save_checkpoint(net, optimizer, epoch, best_acc1, is_best)
        if rank == 0:
            logger.info(
                f"ACCURACY: TOP1 {acc1:.3f}(BEST {best_acc1:.3f}) | "
                f"TOP{cfg.TRAIN.TOPK} {acck:.3f} | SAVED {checkpoint_file}"
            )
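
# NOTE: an illustrative entry point, not part of the original module. It
# assumes the script is launched one process per GPU via torchrun, which sets
# the RANK and LOCAL_RANK environment variables read above, e.g.:
#   torchrun --nproc_per_node=8 train.py
# cfg.TRAIN.EVAL_ONLY is a hypothetical flag standing in for however the
# project actually selects between training and evaluation.
if __name__ == "__main__":
    if cfg.TRAIN.EVAL_ONLY:
        test_model()
    else:
        train_model()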