Example #1
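Both examples are excerpts from a pycls-style trainer, so the module imports are not shown. A plausible import header for this NPU variant is sketched below, assuming pycls's module layout (the paths are assumptions and may differ in the actual port). Note that apex's amp.scale_loss also assumes amp.initialize(model, optimizer, ...) was called during setup, and the .npu() / torch.npu calls require Huawei's Ascend PyTorch adapter.

# Assumed imports, following pycls's module layout (paths hypothetical for this port)
import torch
from apex import amp                          # loss scaling via amp.scale_loss
from pycls.core.config import cfg             # global config (cfg.NUM_GPUS, cfg.TRAIN.*)
import pycls.core.distributed as dist         # scaled_all_reduce
import pycls.core.meters as meters            # topk_errors
import pycls.core.net as net                  # smooth_one_hot_labels, mixup
import pycls.core.optimizer as optim          # get_epoch_lr, set_lr
import pycls.datasets.loader as data_loader   # shuffle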
def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
    """Performs one epoch of training."""
    # Shuffle the data
    data_loader.shuffle(loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    meter.reset()
    meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(loader):
        # Transfer the data to the current GPU device
        inputs = inputs.npu()
        # Cast labels to int32 before moving them to the NPU (some NPU operators
        # do not support int64 labels)
        labels = labels.to(torch.int32).npu()
        # Keep the original hard labels for the top-k error computation below
        # (net.mixup rebinds labels to the argmax of the mixed targets)
        p_labels = labels
        # Convert labels to smoothed one-hot vector
        labels_one_hot = net.smooth_one_hot_labels(labels).npu()
        # Apply mixup to the batch (no effect if mixup alpha is 0)
        inputs, labels_one_hot, labels = net.mixup(inputs, labels_one_hot)
        # Perform the forward pass and compute the loss
        preds = model(inputs)
        loss = loss_fun(preds, labels_one_hot)
        # Explicitly synchronize the NPU stream between stages (the same pattern
        # is repeated around zero_grad, backward, and the optimizer step below)
        stream = torch.npu.current_stream()
        stream.synchronize()
        # Perform the backward pass and update the parameters
        optimizer.zero_grad()
        stream = torch.npu.current_stream()
        stream.synchronize()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        stream = torch.npu.current_stream()
        stream.synchronize()
        optimizer.step()
        stream = torch.npu.current_stream()
        stream.synchronize()
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, p_labels, [1, 5])
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        # loss, top1_err, top5_err = dist.scaled_all_reduce([loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        meter.iter_toc()
        # Update and log stats
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        meter.log_iter_stats(cur_epoch, cur_iter)
        meter.iter_tic()
    # Log epoch stats
    meter.log_epoch_stats(cur_epoch)
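
For reference, the label-smoothing and mixup helpers called above do roughly the following. This is a minimal sketch in the spirit of pycls's net module, not the verbatim implementation; the config names cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING, and cfg.TRAIN.MIXUP_ALPHA are assumptions taken from pycls. It also shows why Example #1 keeps p_labels: mixup returns the argmax of the mixed one-hot targets as the new hard labels.

# Sketch of the label-smoothing and mixup helpers (pycls-style; config names assumed)
import numpy as np
import torch
from pycls.core.config import cfg  # assumed config module

def smooth_one_hot_labels(labels):
    """Convert hard labels to smoothed one-hot target vectors."""
    n_classes, smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING
    neg_val = smooth / n_classes
    pos_val = 1.0 - smooth + neg_val
    one_hot = torch.full((labels.shape[0], n_classes), neg_val,
                         dtype=torch.float, device=labels.device)
    one_hot.scatter_(1, labels.long().view(-1, 1), pos_val)
    return one_hot

def mixup(inputs, labels_one_hot):
    """Mix random pairs of examples; a no-op for the inputs if MIXUP_ALPHA is 0."""
    alpha = cfg.TRAIN.MIXUP_ALPHA
    if alpha > 0:
        m = np.random.beta(alpha, alpha)
        perm = torch.randperm(labels_one_hot.shape[0])
        inputs = m * inputs + (1.0 - m) * inputs[perm, :]
        labels_one_hot = m * labels_one_hot + (1.0 - m) * labels_one_hot[perm, :]
    # Return mixed inputs, mixed soft targets, and hard labels (argmax of targets)
    return inputs, labels_one_hot, labels_one_hot.argmax(1)
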
Example #2
def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter,
                cur_epoch):
    """Performs one epoch of training."""
    # Shuffle the data
    data_loader.shuffle(loader, cur_epoch)
    # Update the learning rate
    lr = optim.get_epoch_lr(cur_epoch)
    optim.set_lr(optimizer, lr)
    # Enable training mode
    model.train()
    ema.train()
    meter.reset()
    meter.iter_tic()
    for cur_iter, (inputs, labels) in enumerate(loader):
        # Transfer the data to the current GPU device
        inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
        # Convert labels to smoothed one-hot vector
        labels_one_hot = net.smooth_one_hot_labels(labels)
        # Apply mixup to the batch (no effect if mixup alpha is 0)
        inputs, labels_one_hot, labels = net.mixup(inputs, labels_one_hot)
        # Perform the forward pass and compute the loss
        with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
            preds = model(inputs)
            loss = loss_fun(preds, labels_one_hot)
        # Perform the backward pass and update the parameters
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Update ema weights
        net.update_model_ema(model, ema, cur_epoch, cur_iter)
        # Compute the errors
        top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
        # Combine the stats across the GPUs (no reduction if 1 GPU used)
        loss, top1_err, top5_err = dist.scaled_all_reduce(
            [loss, top1_err, top5_err])
        # Copy the stats from GPU to CPU (sync point)
        loss, top1_err, top5_err = loss.item(), top1_err.item(), top5_err.item()
        meter.iter_toc()
        # Update and log stats
        mb_size = inputs.size(0) * cfg.NUM_GPUS
        meter.update_stats(top1_err, top5_err, loss, lr, mb_size)
        meter.log_iter_stats(cur_epoch, cur_iter)
        meter.iter_tic()
    # Log epoch stats
    meter.log_epoch_stats(cur_epoch)
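
Two helpers used in Example #2 are worth sketching: net.update_model_ema maintains an exponential moving average of the model weights, and dist.scaled_all_reduce averages the loss and error tensors across processes. The versions below are minimal sketches, assuming pycls-style config options (cfg.OPTIM.EMA_ALPHA, cfg.OPTIM.EMA_UPDATE_PERIOD, cfg.OPTIM.WARMUP_EPOCHS, cfg.NUM_GPUS) and an already-initialized torch.distributed process group; pycls's real EMA update additionally rescales alpha based on the batch size and schedule length.

# Sketches of the EMA update and the cross-GPU reduction (pycls-style; not verbatim)
import torch
from pycls.core.config import cfg  # assumed config module

@torch.no_grad()
def update_model_ema(model, ema, cur_epoch, cur_iter):
    """Blend the live weights into the EMA copy every EMA_UPDATE_PERIOD iterations."""
    period = cfg.OPTIM.EMA_UPDATE_PERIOD
    if period == 0 or cur_iter % period != 0:
        return
    # During warmup, simply copy the weights instead of averaging them
    alpha = 1.0 if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS else cfg.OPTIM.EMA_ALPHA
    # A DDP-wrapped model would need unwrapping before taking its state_dict
    params = model.state_dict()
    for name, param in ema.state_dict().items():
        param.copy_(param * (1.0 - alpha) + params[name] * alpha)

def scaled_all_reduce(tensors):
    """Sum tensors across processes, then scale by 1 / NUM_GPUS (no-op for 1 GPU)."""
    if cfg.NUM_GPUS == 1:
        return tensors
    handles = [torch.distributed.all_reduce(t, async_op=True) for t in tensors]
    for handle in handles:
        handle.wait()
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS)
    return tensors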