# Assumed imports for this snippet; AvgrageMeter, accuracy, and reduce_mean
# are repo-local helpers (hedged sketches appear later in this section).
import random
import time

import torch
from tqdm import tqdm


def train(train_dataloader, val_dataloader, optimizer, scheduler, model, archloader, criterion, args, seed, epoch, writer=None):
    losses_, top1_, top5_ = AvgrageMeter(), AvgrageMeter(), AvgrageMeter()

    # for p in model.parameters():
    #     p.grad = torch.zeros_like(p)
    model.train()

    train_loader = tqdm(train_dataloader)
    train_loader.set_description(
        '[%s%04d/%04d %s%f]' % ('Epoch:', epoch + 1, args.epochs, 'lr:', scheduler.get_last_lr()[0]))
    for step, (image, target) in enumerate(train_loader):
        n = image.size(0)
        image = image.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        # Fair Sampling
        # [archloader.generate_niu_fair_batch(step)[-1]]
        # [16, 16, 16, 16, 16, 16, 16, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64, 64, 64]
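        # SPOS-style single-path sampling: one channel width is drawn
        # uniformly at random for each layer, so every path through the
        # supernet is trained with roughly equal probability.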
        spos_arc_list = archloader.generate_spos_like_batch().tolist()

        # for arc in fair_arc_list:
        # logits = model(image, archloader.convert_list_arc_str(arc))
        # loss = criterion(logits, target)
        # loss_reduce = reduce_tensor(loss, 0, args.world_size)
        # loss.backward()
        optimizer.zero_grad()
        logits = model(image, spos_arc_list[:-1])
        loss = criterion(logits, target)
        prec1, prec5 = accuracy(logits, target, topk=(1, 5))

        if torch.cuda.device_count() > 1:
            torch.distributed.barrier()
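            # Synchronize all ranks, then average the loss and metrics
            # across processes so every rank logs the same global values.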

            loss = reduce_mean(loss, args.nprocs)
            prec1 = reduce_mean(prec1, args.nprocs)
            prec5 = reduce_mean(prec5, args.nprocs)

        loss.backward()

        # nn.utils.clip_grad_value_(model.parameters(), args.grad_clip)

        optimizer.step()

        losses_.update(loss.item(), n)
        top1_.update(prec1.item(), n)
        top5_.update(prec5.item(), n)

        postfix = {'train_loss': '%.6f' % (
            losses_.avg), 'train_acc1': '%.6f' % top1_.avg, 'train_acc5': '%.6f' % top5_.avg}
        train_loader.set_postfix(log=postfix)

        if args.local_rank == 0 and step % 10 == 0 and writer is not None:
            # Global step counts batches: batches-per-epoch * epoch + step.
            global_step = step + len(train_dataloader) * epoch
            writer.add_scalar("Train/loss", losses_.avg, global_step)
            writer.add_scalar("Train/acc1", top1_.avg, global_step)
            writer.add_scalar("Train/acc5", top5_.avg, global_step)


def infer(train_loader, val_loader, model, criterion, val_iters, archloader,
          args):

    objs_, top1_, top5_ = AvgrageMeter(), AvgrageMeter(), AvgrageMeter()

    model.eval()
    now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))

    # [16, 16, 16, 16, 16, 16, 16, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64, 64, 64]
    fair_arc_list = archloader.generate_niu_fair_batch(random.randint(
        0, 100))[-1].tolist()
    # archloader.generate_spos_like_batch().tolist()

    print('{} |=> Test rng = {}'.format(now, fair_arc_list))  # only evaluate the last sampled architecture

    # BN calibration
    # retrain_bn(model, 15, train_loader, fair_arc_list, device=0)
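    # BN calibration would recompute each BatchNorm layer's running mean and
    # variance for the specific subnet under evaluation, by forwarding a few
    # training batches with the chosen architecture before measuring accuracy.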

    with torch.no_grad():
        t0 = time.time()
        for step, (image, target) in enumerate(val_loader):
            # Time spent waiting on the dataloader for this batch.
            datatime = time.time() - t0
            image = image.cuda(args.local_rank, non_blocking=True)
            target = target.cuda(args.local_rank, non_blocking=True)

            logits = model(image)  # fair_arc_list is sampled above but not passed here
            loss = criterion(logits, target)

            top1, top5 = accuracy(logits, target, topk=(1, 5))

            if torch.cuda.device_count() > 1:
                torch.distributed.barrier()

                loss = reduce_mean(loss, args.nprocs)
                top1 = reduce_mean(top1, args.nprocs)
                top5 = reduce_mean(top5, args.nprocs)

            n = image.size(0)
            objs_.update(loss.item(), n)
            top1_.update(top1.item(), n)
            top5_.update(top5.item(), n)
            t0 = time.time()

        now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
        print(
            '{} |=> valid: step={}, loss={:.2f}, val_acc1={:.2f}, val_acc5={:.2f}, datatime={:.2f}'
            .format(now, step, objs_.avg, top1_.avg, top5_.avg, datatime))

    return top1_.avg, top5_.avg, objs_.avg
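A note on the helper used above: reduce_mean is repo-local, not a torch builtin. A minimal sketch of what it presumably does, assuming the usual all-reduce-then-divide pattern over torch.distributed:

import torch.distributed as dist

def reduce_mean(tensor, nprocs):
    # Sum the metric tensor across all ranks, then divide by the number of
    # processes so every rank ends up holding the global mean.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= nprocs
    return rt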
Example #3

# This example assumes the same imports and helper utilities as Example #2.
def train(train_dataloader,
          val_dataloader,
          optimizer,
          scheduler,
          model,
          archloader,
          criterion,
          soft_criterion,
          args,
          seed,
          epoch,
          writer=None):
    losses_, top1_, top5_ = AvgrageMeter(), AvgrageMeter(), AvgrageMeter()

    model.train()
    widest = [
        16, 16, 16, 16, 16, 16, 16, 32, 32, 32, 32, 32, 32, 64, 64, 64, 64, 64,
        64, 64
    ]
    narrowest = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]

    train_loader = tqdm(train_dataloader)
    train_loader.set_description(
        '[%s%04d/%04d %s%f]' %
        ('Epoch:', epoch + 1, args.epochs, 'lr:', scheduler.get_last_lr()[0]))
    for step, (image, target) in enumerate(train_loader):
        n = image.size(0)
        image = image.cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        if args.model_type in ["dynamic", "independent", "slimmable"]:
            # sandwich rule
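            # Sandwich rule: each step trains the narrowest subnet, the
            # widest model, and several randomly sampled subnets,
            # accumulating all of their gradients before one optimizer step.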
            candidate_list = []
            candidate_list += [narrowest]
            candidate_list += [
                archloader.generate_spos_like_batch().tolist()
                for i in range(6)
            ]

            # archloader.generate_niu_fair_batch(step)
            # first pass: run the widest model once
            soft_target = model(image, widest)
            soft_loss = criterion(soft_target, target)
            soft_loss.backward()
            soft_target = torch.nn.functional.softmax(soft_target,
                                                      dim=1).detach()
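            # The widest model's softened outputs are kept as potential
            # distillation targets for the subnets (inplace distillation via
            # soft_criterion); as written, the hard-label loss is used instead.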

            # then run each of the sampled subnets
            for arc in candidate_list:
                logits = model(image, arc)
                # loss = soft_criterion(logits, soft_target.cuda(
                #     args.gpu, non_blocking=True))
                loss = criterion(logits, target)

                # loss_reduce = reduce_tensor(loss, 0, args.world_size)
                loss.backward()
        elif args.model_type == "original":
            logits = model(image)
            loss = criterion(logits, target)
            loss.backward()

        prec1, prec5 = accuracy(logits, target, topk=(1, 5))

        if torch.cuda.device_count() > 1:
            torch.distributed.barrier()

            loss = reduce_mean(loss, args.nprocs)
            prec1 = reduce_mean(prec1, args.nprocs)
            prec5 = reduce_mean(prec5, args.nprocs)

        optimizer.step()
        optimizer.zero_grad()

        losses_.update(loss.item(), n)
        top1_.update(prec1.item(), n)
        top5_.update(prec5.item(), n)

        postfix = {
            'train_loss': '%.6f' % (losses_.avg),
            'train_acc1': '%.6f' % top1_.avg,
            'train_acc5': '%.6f' % top5_.avg
        }

        train_loader.set_postfix(log=postfix)

        if args.local_rank == 0 and step % 10 == 0 and writer is not None:
            global_step = step + len(train_dataloader) * epoch
            writer.add_scalar("Train/loss", losses_.avg, global_step)
            writer.add_scalar("Train/acc1", top1_.avg, global_step)
            writer.add_scalar("Train/acc5", top5_.avg, global_step)
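Both examples also rely on AvgrageMeter and accuracy from the surrounding repo. Minimal sketches, assuming they follow the standard DARTS/PyTorch-example utilities (the names here mirror the calls above, not a confirmed implementation):

class AvgrageMeter:
    # Running (sample-weighted) average of a scalar metric.
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0.0
        self.sum = 0.0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    # Top-k accuracy in percent for each requested k.
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res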