Example #1
def validate(val_loader, net):
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    net.eval()

    prefetcher = DataPrefetcher(val_loader)
    inputs, labels = prefetcher.next()
    with torch.no_grad():
        while inputs is not None:
            inputs = inputs.float().cuda()
            labels = labels.cuda()

            stu_outputs, _ = net(inputs)

            acc1, acc5 = accuracy(stu_outputs[-1], labels, topk=(1, 5))
            top1.update(acc1.item(), inputs.size(0))
            top5.update(acc5.item(), inputs.size(0))
            inputs, labels = prefetcher.next()

    return top1.avg, top5.avg
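
Both examples assume a few helper utilities that are not shown on this page. As a point of reference, here is a minimal sketch of AverageMeter and accuracy in the style of the well-known PyTorch ImageNet example; the repository these snippets come from may implement them differently.

import torch


class AverageMeter:
    """Tracks a running sum and count and exposes the sample-weighted average."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (in percent) of a batch of logits against integer labels."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [correct[:k].reshape(-1).float().sum(0).mul_(100.0 / batch_size)
                for k in topk]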
Example #2
def train(train_loader, net, criterion, optimizer, scheduler, epoch, logger):
    top1 = AverageMeter()
    top5 = AverageMeter()
    loss_total = AverageMeter()

    # One independent meter per loss term; [AverageMeter()] * n would alias
    # the same meter object n times.
    loss_ams = [AverageMeter() for _ in range(len(criterion))]
    loss_alphas = []
    for loss_item in Config.loss_list:
        loss_rate = loss_item["loss_rate"]
        factor = loss_item["factor"]
        loss_type = loss_item["loss_type"]
        loss_rate_decay = loss_item["loss_rate_decay"]
        loss_alphas.append(
            adjust_loss_alpha(loss_rate, epoch, factor, loss_type,
                              loss_rate_decay))

    # switch to train mode
    net.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = DataPrefetcher(train_loader)
    inputs, labels = prefetcher.next()
    iter_index = 1
    while inputs is not None:
        inputs, labels = inputs.float().cuda(), labels.cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        stu_outputs, tea_outputs = net(inputs)
        loss = 0
        loss_detail = []
        for i, loss_item in enumerate(Config.loss_list):
            loss_type = loss_item["loss_type"]
            if loss_type == "ce_family":
                # plain classification loss on the student's final logits
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[-1],
                                                         labels)
            elif loss_type == "kd_family":
                # distillation loss between student and teacher logits
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[-1],
                                                         tea_outputs[-1])
            elif loss_type == "gkd_family":
                # distillation variant that also conditions on the labels
                tmp_loss = loss_alphas[i] * criterion[i](
                    stu_outputs[-1], tea_outputs[-1], labels)
            elif loss_type == "fd_family":
                # feature-level distillation over the intermediate outputs
                tmp_loss = loss_alphas[i] * criterion[i](stu_outputs[:-1],
                                                         tea_outputs[:-1])
            else:
                raise ValueError(f"unknown loss_type: {loss_type}")
            loss_detail.append(tmp_loss.item())
            loss_ams[i].update(tmp_loss.item(), inputs.size(0))
            loss += tmp_loss

        loss = loss / args.accumulation_steps

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if iter_index % args.accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        acc1, acc5 = accuracy(stu_outputs[-1], labels, topk=(1, 5))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))
        loss_total.update(loss.item(), inputs.size(0))

        inputs, labels = prefetcher.next()

        loss_log = ""
        if iter_index % args.print_interval == 0:
            loss_log += f"train: epoch {epoch:0>3d}, iter [{iter_index:0>4d}, {iters:0>4d}], lr: {scheduler.get_lr()[0]:.6f}, top1 acc: {acc1.item():.2f}%, top5 acc: {acc5.item():.2f}%, loss_total: {loss.item():.2f}, "

            for i, loss_item in enumerate(Config.loss_list):
                loss_name = loss_item["loss_name"]
                loss_log += f"{loss_name}: {loss_detail[i]:.2f}, alpha: {loss_alphas[i]:.2f}, "

            logger.info(loss_log)

        iter_index += 1

    scheduler.step()

    return top1.avg, top5.avg, loss_total.avg
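
Both the training and validation loops iterate the DataLoader through DataPrefetcher.next() and stop once it returns None. A common way to provide that interface, sketched here after the NVIDIA Apex ImageNet example (the repository's own class may differ), is to stage the next batch on a side CUDA stream so the host-to-device copy overlaps with compute on the current batch.

import torch


class DataPrefetcher:
    """Asynchronously copies the next batch to the GPU on a separate CUDA
    stream so the transfer overlaps with computation on the current batch."""

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_inputs, self.next_labels = next(self.loader)
        except StopIteration:
            self.next_inputs, self.next_labels = None, None
            return
        with torch.cuda.stream(self.stream):
            self.next_inputs = self.next_inputs.cuda(non_blocking=True)
            self.next_labels = self.next_labels.cuda(non_blocking=True)

    def next(self):
        # Wait for the asynchronous copy issued in preload() to finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        inputs, labels = self.next_inputs, self.next_labels
        if inputs is not None:
            self.preload()
        return inputs, labels

Returning (None, None) once the underlying iterator is exhausted is what lets the "while inputs is not None" loops in both examples terminate cleanly.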