Example #1
def check_logits_one_epoch(config, model, data_loader, epoch, mixup_fn):
    model.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets), (saved_logits_index, saved_logits_value,
                                   seeds)) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        softmax_prob = torch.softmax(outputs, -1)

        torch.cuda.synchronize()

        values, indices = softmax_prob.topk(k=topk,
                                            dim=-1,
                                            largest=True,
                                            sorted=True)

        meters['error'].update(
            (values - saved_logits_value.cuda()).abs().mean().item())
        meters['diff_rate'].update(
            torch.count_nonzero(
                (indices != saved_logits_index.cuda())).item() /
            indices.numel())

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Check: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} check logits takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
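
The check above re-runs the model on the same mixup-augmented batches (the saved seeds reproduce the augmentation) and compares the freshly computed top-k probabilities against the ones cached on disk: error is the mean absolute gap between the values, diff_rate the fraction of top-k indices that no longer agree. A toy illustration of the two metrics (the tensors below are made up, not taken from the pipeline):

import torch

new_values = torch.tensor([[0.60, 0.20, 0.10]])
saved_values = torch.tensor([[0.58, 0.21, 0.10]])
new_indices = torch.tensor([[3, 7, 1]])
saved_indices = torch.tensor([[3, 7, 2]])

error = (new_values - saved_values).abs().mean().item()  # 0.01
diff_rate = torch.count_nonzero(new_indices != saved_indices).item() / new_indices.numel()  # 1/3
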
Example #2
def validate(config, data_loader, model, logger):

    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    loss_meter.sync()
    acc1_meter.sync()
    acc5_meter.sync()

    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
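
Every loop in this file relies on an AverageMeter with val, avg, count, an update(value, n) method and, for distributed runs, a sync(). The project ships its own implementation; the sketch below only illustrates what such a meter typically looks like, and the all_reduce-based sync() is an assumption rather than the repo's exact code:

import torch
import torch.distributed as dist

class AverageMeter:
    # Tracks the most recent value and a running average, as used in the loops above.
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

    def sync(self):
        # Aggregate sum and count across ranks so avg reflects the full validation set.
        if dist.is_available() and dist.is_initialized():
            t = torch.tensor([self.sum, self.count], dtype=torch.float64, device='cuda')
            dist.all_reduce(t)
            self.sum, self.count = t[0].item(), int(t[1].item())
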
Example #3
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch,
                    mixup_fn, lr_scheduler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()

    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
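
When gradient clipping is disabled, the loop still logs a gradient norm through get_grad_norm, which is defined elsewhere in the project. A common implementation, shown here only as a sketch, accumulates the per-parameter L2 norms without touching the gradients:

import torch

def get_grad_norm(parameters, norm_type=2.0):
    # Total gradient norm, analogous to torch.nn.utils.clip_grad_norm_ but read-only.
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    total_norm = 0.0
    for p in parameters:
        total_norm += p.grad.detach().norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)
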
Example #4
def train_one_epoch_distill(config,
                            model,
                            model_teacher,
                            data_loader,
                            optimizer,
                            epoch,
                            mixup_fn,
                            lr_scheduler,
                            criterion_soft=None,
                            criterion_truth=None,
                            criterion_attn=None,
                            criterion_hidden=None):

    layer_id_s_list = config.DISTILL.STUDENT_LAYER_LIST
    layer_id_t_list = config.DISTILL.TEACHER_LAYER_LIST

    model.train()
    optimizer.zero_grad()

    model_teacher.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_soft_meter = AverageMeter()
    loss_truth_meter = AverageMeter()
    loss_attn_meter = AverageMeter()
    loss_hidden_meter = AverageMeter()

    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    teacher_acc1_meter = AverageMeter()
    teacher_acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        original_targets = targets

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if config.DISTILL.ATTN_LOSS and config.DISTILL.HIDDEN_LOSS:
            outputs, qkv_s, hidden_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.ATTN_LOSS:
            outputs, qkv_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=True,
                is_hidden_loss=False,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        elif config.DISTILL.HIDDEN_LOSS:
            outputs, hidden_s = model(
                samples,
                layer_id_s_list,
                is_attn_loss=False,
                is_hidden_loss=True,
                is_hidden_org=config.DISTILL.HIDDEN_RELATION)
        else:
            outputs = model(samples)

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
            if config.DISTILL.ATTN_LOSS or config.DISTILL.HIDDEN_LOSS:
                outputs_teacher, qkv_t, hidden_t = model_teacher(
                    samples,
                    layer_id_t_list,
                    is_attn_loss=True,
                    is_hidden_loss=True)
            else:
                outputs_teacher = model_teacher(samples)
            teacher_acc1, teacher_acc5 = accuracy(outputs_teacher,
                                                  original_targets,
                                                  topk=(1, 5))

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss_truth = config.DISTILL.ALPHA * criterion_truth(
                outputs, targets)
            loss_soft = (1.0 - config.DISTILL.ALPHA) * criterion_soft(
                outputs / config.DISTILL.TEMPERATURE,
                outputs_teacher / config.DISTILL.TEMPERATURE)
            if config.DISTILL.ATTN_LOSS:
                loss_attn = config.DISTILL.QKV_LOSS_WEIGHT * criterion_attn(
                    qkv_s, qkv_t, config.DISTILL.AR)
            else:
                loss_attn = torch.zeros(loss_truth.shape)
            if config.DISTILL.HIDDEN_LOSS:
                loss_hidden = config.DISTILL.HIDDEN_LOSS_WEIGHT * criterion_hidden(
                    hidden_s, hidden_t)
            else:
                loss_hidden = torch.zeros(loss_truth.shape)
            loss = loss_truth + loss_soft + loss_attn + loss_hidden

            optimizer.zero_grad()
            if config.AMP_OPT_LEVEL != "O0":
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(amp.master_params(optimizer))
            else:
                loss.backward()
                if config.TRAIN.CLIP_GRAD:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), config.TRAIN.CLIP_GRAD)
                else:
                    grad_norm = get_grad_norm(model.parameters())
            optimizer.step()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        loss_soft_meter.update(loss_soft.item(), targets.size(0))
        loss_truth_meter.update(loss_truth.item(), targets.size(0))
        loss_attn_meter.update(loss_attn.item(), targets.size(0))
        loss_hidden_meter.update(loss_hidden.item(), targets.size(0))
        norm_meter.update(grad_norm)
        batch_time.update(time.time() - end)
        end = time.time()

        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))
        teacher_acc1_meter.update(teacher_acc1.item(), targets.size(0))
        teacher_acc5_meter.update(teacher_acc5.item(), targets.size(0))

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}\t'
                f'Teacher_Acc@1 {teacher_acc1_meter.avg:.3f} Teacher_Acc@5 {teacher_acc5_meter.avg:.3f}\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'loss_soft {loss_soft_meter.val:.4f} ({loss_soft_meter.avg:.4f})\t'
                f'loss_truth {loss_truth_meter.val:.4f} ({loss_truth_meter.avg:.4f})\t'
                f'loss_attn {loss_attn_meter.val:.4f} ({loss_attn_meter.avg:.4f})\t'
                f'loss_hidden {loss_hidden_meter.val:.4f} ({loss_hidden_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
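
Note that the loop divides both student and teacher logits by DISTILL.TEMPERATURE before handing them to criterion_soft, so the criterion itself only compares the softened distributions. One plausible choice, sketched here rather than taken from the project, is a batch-mean KL divergence:

import torch
import torch.nn.functional as F

class SoftTargetKLLoss(torch.nn.Module):
    # F.kl_div(log_q, p) computes KL(p || q), i.e. the teacher distribution measured
    # against the student's. Inputs are assumed to be pre-scaled by the temperature.
    def forward(self, student_logits, teacher_logits):
        log_p_student = F.log_softmax(student_logits, dim=-1)
        p_teacher = F.softmax(teacher_logits, dim=-1)
        return F.kl_div(log_p_student, p_teacher, reduction='batchmean')
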
Example #5
File: main.py  Project: microsoft/AutoML
def validate(args, config, data_loader, model, num_classes=1000):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)
        if num_classes == 1000:
            output_num_classes = output.size(-1)
            if output_num_classes == 21841:
                output = remap_layer_22kto1k(output)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(f'Test: [{idx}/{len(data_loader)}]\t'
                        f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                        f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                        f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                        f'Mem {memory_used:.0f}MB')

    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(
        f' The number of validation samples is {int(acc1_meter.count)}')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
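
The magic number 21841 is the ImageNet-22k class count: when a 22k-pretrained checkpoint is evaluated on ImageNet-1k, remap_layer_22kto1k (defined elsewhere in the repo) presumably maps the 21,841-way output onto the 1,000 ImageNet-1k classes before the loss and accuracy are computed.
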
Example #6
File: main.py  Project: microsoft/AutoML
def train_one_epoch_distill_using_saved_logits(args, config, model, criterion,
                                               data_loader, optimizer, epoch,
                                               mixup_fn, lr_scheduler,
                                               loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    data_tic = time.time()

    num_classes = config.MODEL.NUM_CLASSES
    topk = config.DISTILL.LOGITS_TOPK

    for idx, ((samples, targets), (logits_index, logits_value,
                                   seeds)) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets
        meters['data_time'].update(time.time() - data_tic)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        # recover teacher logits
        logits_index = logits_index.long()
        logits_value = logits_value.float()
        logits_index = logits_index.cuda(non_blocking=True)
        logits_value = logits_value.cuda(non_blocking=True)
        minor_value = (1.0 - logits_value.sum(-1, keepdim=True)) / (
            num_classes - topk)
        minor_value = minor_value.repeat_interleave(num_classes, dim=-1)
        outputs_teacher = minor_value.scatter_(-1, logits_index, logits_value)

        loss = criterion(outputs, outputs_teacher)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss,
                                optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) %
                                config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        # compute accuracy
        real_batch_size = len(original_targets)
        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        meters['train_acc1'].update(acc1.item(), real_batch_size)
        meters['train_acc5'].update(acc5.item(), real_batch_size)
        teacher_acc1, teacher_acc5 = accuracy(outputs_teacher,
                                              original_targets,
                                              topk=(1, 5))
        meters['teacher_acc1'].update(teacher_acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(teacher_acc5.item(), real_batch_size)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), real_batch_size)
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()
        data_tic = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)

            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

            if is_main_process() and args.use_wandb:
                acc1_meter, acc5_meter = meters['train_acc1'], meters[
                    'train_acc5']
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)
    epoch_time = time.time() - start
    extra_meters_str = f'Train-Summary: [{epoch}/{config.TRAIN.EPOCHS}]\t'
    for k, v in meters.items():
        v.sync()
        extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
    logger.info(extra_meters_str)
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
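
The "recover teacher logits" block turns the stored top-k entries back into a full probability distribution: every class outside the top-k gets an equal share of the leftover probability mass, then scatter_ writes the saved values back into their class positions. A toy check of that construction (sizes and tensors below are illustrative only):

import torch

num_classes, topk = 10, 3
probs = torch.softmax(torch.randn(2, num_classes), dim=-1)
values, indices = probs.topk(topk, dim=-1)             # what gets cached on disk

minor = (1.0 - values.sum(-1, keepdim=True)) / (num_classes - topk)
dense = minor.repeat_interleave(num_classes, dim=-1)   # leftover mass spread over all classes
dense.scatter_(-1, indices, values)                    # exact top-k entries restored

assert torch.allclose(dense.sum(-1), torch.ones(2))    # still a valid distribution
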
Example #7
File: main.py  Project: microsoft/AutoML
def train_one_epoch(args, config, model, criterion, data_loader, optimizer,
                    epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    set_bn_state(config, model)
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        normal_global_idx = epoch * NORM_ITER_LEN + \
            (idx * NORM_ITER_LEN // num_steps)

        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        loss = criterion(outputs, targets)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(
            optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss,
                                optimizer,
                                clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(),
                                create_graph=is_second_order,
                                update_grad=(idx + 1) %
                                config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update(
                (epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        with torch.no_grad():
            acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        acc1_meter.update(acc1.item(), targets.size(0))
        acc5_meter.update(acc5.item(), targets.size(0))

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if is_valid_grad_norm(grad_norm):
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

            if is_main_process() and args.use_wandb:
                wandb.log(
                    {
                        "train/acc@1": acc1_meter.val,
                        "train/acc@5": acc5_meter.val,
                        "train/loss": loss_meter.val,
                        "train/grad_norm": norm_meter.val,
                        "train/loss_scale": scaler_meter.val,
                        "train/lr": lr,
                    },
                    step=normal_global_idx)
    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
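
Unlike examples #3 and #4, this loop does not use apex: loss_scaler wraps torch.cuda.amp.GradScaler, returns a gradient norm, and only steps the optimizer when update_grad is true, which is what makes gradient accumulation work. The project defines its own scaler class; the sketch below merely illustrates a wrapper with that call signature, under those assumptions:

import torch

class NativeScalerWithGradNorm:
    # Sketch of an AMP loss scaler matching the call sites above.
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        norm = None
        if update_grad:
            self._scaler.unscale_(optimizer)   # bring gradients back to their true scale
            if clip_grad:
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            self._scaler.step(optimizer)
            self._scaler.update()
        return norm

    def state_dict(self):
        return self._scaler.state_dict()
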
Example #8
def save_logits_one_epoch(config, model, data_loader, epoch, mixup_fn):
    model.eval()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    meters = defaultdict(AverageMeter)

    start = time.time()
    end = time.time()
    topk = config.DISTILL.LOGITS_TOPK

    logits_manager = data_loader.dataset.get_manager()

    for idx, ((samples, targets), (keys, seeds)) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets, seeds)
            original_targets = targets.argmax(dim=1)
        else:
            original_targets = targets

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)

        acc1, acc5 = accuracy(outputs, original_targets, topk=(1, 5))
        real_batch_size = len(samples)
        meters['teacher_acc1'].update(acc1.item(), real_batch_size)
        meters['teacher_acc5'].update(acc5.item(), real_batch_size)

        # save teacher logits
        softmax_prob = torch.softmax(outputs, -1)

        torch.cuda.synchronize()

        write_tic = time.time()
        values, indices = softmax_prob.topk(k=topk,
                                            dim=-1,
                                            largest=True,
                                            sorted=True)

        cpu_device = torch.device('cpu')
        values = values.detach().to(device=cpu_device, dtype=torch.float16)
        indices = indices.detach().to(device=cpu_device, dtype=torch.int16)

        seeds = seeds.numpy()
        values = values.numpy()
        indices = indices.numpy()

        # check data type
        assert seeds.dtype == np.int32, seeds.dtype
        assert indices.dtype == np.int16, indices.dtype
        assert values.dtype == np.float16, values.dtype

        for key, seed, indice, value in zip(keys, seeds, indices, values):
            bstr = seed.tobytes() + indice.tobytes() + value.tobytes()
            logits_manager.write(key, bstr)
        meters['write_time'].update(time.time() - write_tic)

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            extra_meters_str = ''
            for k, v in meters.items():
                extra_meters_str += f'{k} {v.val:.4f} ({v.avg:.4f})\t'
            logger.info(
                f'Save: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'{extra_meters_str}'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(
        f"EPOCH {epoch} save logits takes {datetime.timedelta(seconds=int(epoch_time))}"
    )
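
Each record handed to logits_manager.write is the raw concatenation of one int32 seed, topk int16 class indices, and topk float16 probabilities, exactly the dtypes asserted above (int16 is enough even for the 21,841 ImageNet-22k classes). A hypothetical reader for that byte layout, shown only to document the format:

import numpy as np

def decode_logits_record(bstr, topk):
    # int32 seed (4 bytes) | int16 indices (2 * topk bytes) | float16 values (2 * topk bytes)
    seed = np.frombuffer(bstr, dtype=np.int32, count=1, offset=0)[0]
    indices = np.frombuffer(bstr, dtype=np.int16, count=topk, offset=4)
    values = np.frombuffer(bstr, dtype=np.float16, count=topk, offset=4 + 2 * topk)
    return seed, indices, values
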