Example #1
import time

import torch

# utils, cfg, and logger are project-local modules (meter helpers, the global
# config, and a logger) provided by the surrounding repository.

def train_epoch(train_loader, net, criterion, optimizer, cur_epoch,
                start_epoch, tic):
    """Train one epoch"""
    rank = torch.distributed.get_rank()
    batch_time, data_time, losses, top1, topk = utils.construct_meters()
    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, topk],
        prefix=f"TRAIN:  [{cur_epoch+1}]",
    )

    # Set learning rate
    lr = utils.get_epoch_lr(cur_epoch)
    utils.set_lr(optimizer, lr)
    if rank == 0:
        logger.debug(
            f"CURRENT EPOCH: {cur_epoch+1:3d},   LR: {lr:.4f},   POLICY: {cfg.OPTIM.LR_POLICY}"
        )

    # Set sampler
    train_loader.sampler.set_epoch(cur_epoch)

    net.train()
    end = time.time()
    for idx, (inputs, targets) in enumerate(train_loader):
        data_time.update(time.time() - end)

        inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)

        outputs = net(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        acc_1, acc_k = utils.accuracy(outputs,
                                      targets,
                                      topk=(1, cfg.TRAIN.TOPK))
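        # Average the metrics across all distributed processes (scaled
        # all-reduce), so every rank records the same values.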
        loss, acc_1, acc_k = utils.scaled_all_reduce([loss, acc_1, acc_k])

        losses.update(loss.item(), batch_size)
        top1.update(acc_1[0].item(), batch_size)
        topk.update(acc_k[0].item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if rank == 0 and ((idx + 1) % cfg.TRAIN.PRINT_FREQ == 0 or
                          (idx + 1) == len(train_loader)):
            progress.cal_eta(idx + 1, len(train_loader), tic, cur_epoch,
                             start_epoch)
            progress.display(idx + 1)
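train_epoch calls train_loader.sampler.set_epoch(cur_epoch), so the loader is
expected to wrap a torch.utils.data.distributed.DistributedSampler (set_epoch
is what reshuffles each rank's shard every epoch). A minimal construction
sketch; the dataset object and the batch size are assumptions, not part of
the snippet:

import torch

sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,      # hypothetical dataset object
    batch_size=256,     # assumed per-process batch size
    sampler=sampler,    # makes train_loader.sampler.set_epoch(...) valid
    num_workers=8,
    pin_memory=True,    # pairs with targets.cuda(non_blocking=True) above
    drop_last=True,
)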
Example #2
import time

import torch

def validate(val_loader, net, criterion):
    """Validte the model"""
    rank = torch.distributed.get_rank()
    batch_time, data_time, losses, top1, topk = utils.construct_meters()
    progress = utils.ProgressMeter(
        len(val_loader),
        [batch_time, data_time, losses, top1, topk],
        prefix="VAL:  ",
    )

    net.eval()
    with torch.no_grad():
        end = time.time()
        for idx, (inputs, targets) in enumerate(val_loader):
            data_time.update(time.time() - end)

            inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
            outputs = net(inputs)

            loss = criterion(outputs, targets)

            acc_1, acc_k = utils.accuracy(outputs,
                                          targets,
                                          topk=(1, cfg.TRAIN.TOPK))
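            # Average metrics across processes, as in train_epoch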
            loss, acc_1, acc_k = utils.scaled_all_reduce([loss, acc_1, acc_k])

            batch_size = inputs.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc_1[0].item(), batch_size)
            topk.update(acc_k[0].item(), batch_size)
            batch_time.update(time.time() - end)
            end = time.time()

            if rank == 0 and ((idx + 1) % cfg.TEST.PRINT_FREQ == 0 or
                              (idx + 1) == len(val_loader)):
                progress.display(idx + 1)

    return top1.avg, topk.avg
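A rough sketch of how the two snippets above would typically be driven from a
DDP entry point. The process-group setup, the model/optimizer construction,
the loaders, and cfg.OPTIM.MAX_EPOCH are assumptions, not part of the
original code:

import time

import torch

def main():
    # Assumed single-node NCCL setup; train_loader and val_loader are built
    # elsewhere (see the sampler sketch after Example #1).
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(
        torch.distributed.get_rank() % torch.cuda.device_count())

    net = torch.nn.parallel.DistributedDataParallel(
        build_model().cuda(),  # build_model() is a hypothetical helper
        device_ids=[torch.cuda.current_device()],
    )
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

    start_epoch, tic = 0, time.time()  # tic feeds progress.cal_eta's ETA
    best_top1 = 0.0
    for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):  # field assumed
        train_epoch(train_loader, net, criterion, optimizer,
                    cur_epoch, start_epoch, tic)
        top1, topk = validate(val_loader, net, criterion)
        best_top1 = max(best_top1, top1)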