Example 1
# Standard and third-party imports this snippet relies on; `utils`,
# `models`, `cfg`, `logger`, and `validate` are project-local helpers
# not shown here.
import os

import timm
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def test_model():
    """Evaluate a model on the validation set under DistributedDataParallel."""

    utils.setup_distributed()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)

    utils.setup_logger(rank, local_rank)

    # Prefer a project-local architecture; fall back to timm if the
    # requested arch is unknown to models.build_model.
    try:
        net = models.build_model(
            arch=cfg.MODEL.ARCH,
            pretrained=cfg.MODEL.PRETRAINED,
            num_classes=cfg.MODEL.NUM_CLASSES,
        )
    except KeyError:
        net = timm.create_model(
            model_name=cfg.MODEL.ARCH,
            pretrained=cfg.MODEL.PRETRAINED,
            num_classes=cfg.MODEL.NUM_CLASSES,
        )
    net = net.to(device)
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    val_loader = utils.construct_val_loader()
    criterion = nn.CrossEntropyLoss().to(device)

    if cfg.MODEL.WEIGHTS:
        utils.load_checkpoint(cfg.MODEL.WEIGHTS, net)

    acc1, acck = validate(val_loader, net, criterion)
    if rank == 0:
        logger.info(
            f"ACCURACY: TOP1 {acc1:.3f}  |  TOP{cfg.TRAIN.TOPK} {acck:.3f}")
Example 2
# As in Example 1, `utils`, `models`, `cfg`, `logger`, `train_epoch`,
# and `validate` are project-local helpers not shown here.
import os

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_model():
    """Train a model under DistributedDataParallel, with checkpointing."""

    # Set up distributed device
    utils.setup_distributed()
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)

    utils.setup_seed(rank)
    utils.setup_logger(rank, local_rank)

    net = models.build_model(arch=cfg.MODEL.ARCH,
                             pretrained=cfg.MODEL.PRETRAINED)
    # SyncBN (https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html)
    if cfg.MODEL.SYNCBN:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.to(device)
    # DistributedDataParallel Wrapper
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    train_loader = utils.construct_train_loader()
    val_loader = utils.construct_val_loader()

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = utils.construct_optimizer(net)

    # Resume from a specific checkpoint or the last checkpoint
    best_acc1 = start_epoch = 0
    if cfg.TRAIN.AUTO_RESUME and utils.has_checkpoint():
        file = utils.get_last_checkpoint()
        start_epoch, best_acc1 = utils.load_checkpoint(file, net, optimizer)
    elif cfg.MODEL.WEIGHTS:
        load_opt = optimizer if cfg.TRAIN.LOAD_OPT else None
        start_epoch, best_acc1 = utils.load_checkpoint(cfg.MODEL.WEIGHTS, net,
                                                       load_opt)

    if rank == 0:
        # from torch.utils.collect_env import get_pretty_env_info
        # logger.debug(get_pretty_env_info())
        # logger.debug(net)
        logger.info("\n\n\n            =======  TRAINING  ======= \n\n")

    for epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
        # Train one epoch
        train_epoch(train_loader, net, criterion, optimizer, epoch)
        # Validate
        acc1, acck = validate(val_loader, net, criterion)
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        # Save model
        checkpoint_file = utils.save_checkpoint(net, optimizer, epoch,
                                                best_acc1, is_best)
        if rank == 0:
            logger.info(
                f"ACCURACY: TOP1 {acc1:.3f}(BEST {best_acc1:.3f}) | TOP{cfg.TRAIN.TOPK} {acck:.3f} | SAVED {checkpoint_file}"
            )
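validate and train_epoch are also project-local. For reference, here is a rough sketch of a validation loop matching the (acc1, acck) call sites above; the bookkeeping is an assumption, and a real multi-GPU version would additionally all-reduce the counters across ranks so every process reports the same numbers.

import torch


@torch.no_grad()
def validate(val_loader, net, criterion, topk=5):
    """Hypothetical sketch: average loss and top-1/top-k accuracy over
    this rank's shard of the validation set."""
    net.eval()
    loss_sum = correct1 = correctk = total = 0
    for images, targets in val_loader:
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        logits = net(images)
        loss_sum += criterion(logits, targets).item() * targets.size(0)  # tracked for logging
        _, pred = logits.topk(topk, dim=1)    # (N, topk) predicted class indices
        hits = pred.eq(targets.unsqueeze(1))  # broadcast compare -> (N, topk)
        correct1 += hits[:, 0].sum().item()
        correctk += hits.any(dim=1).sum().item()
        total += targets.size(0)
    return 100.0 * correct1 / total, 100.0 * correctk / total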