Example #1
def validate(args, valid_loader, model, epoch=0, criterion=False, cur_step=0):
    print(
        '-------------------validation_start at epoch {}---------------------'.
        format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    model.eval()
    model.to(device)
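    # eval mode freezes dropout and batch-norm statistics; no_grad disables
    # autograd bookkeeping, which saves memory during the validation pass.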
    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.size(0)

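            # drop incomplete final batches: in distributed mode each rank only
            # sees batch_size // world_size samples per step.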
            if args.distributed:
                if N < int(args.batch_size // world_size):
                    continue
            else:
                if N < args.batch_size:  # skip the last batch
                    continue

            logits = model(X)

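            # PCB (Part-based Convolutional Baseline) heads return one logits
            # tensor per body part; predictions come from the summed softmax
            # scores and the loss is summed over the six part classifiers.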
            if not args.PCB:
                _, preds = torch.max(logits.data, 1)
                loss = criterion(logits, y)
            else:
                part = {}
                sm = nn.Softmax(dim=1)
                num_part = 6
                for i in range(num_part):
                    part[i] = logits[i]

                score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(
                    part[3]) + sm(part[4]) + sm(part[5])
                _, preds = torch.max(score.data, 1)

                loss = criterion(part[0], y)
                for i in range(num_part - 1):
                    loss += criterion(part[i + 1], y)

            if args.PCB:
                prec1, prec5, prec10 = metrics.accuracy(score,
                                                        y,
                                                        topk=(1, 5, 10))
            else:
                prec1, prec5, prec10 = metrics.accuracy(logits,
                                                        y,
                                                        topk=(1, 5, 10))

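            # average loss and accuracies across ranks so every process updates
            # its meters with identical values.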
            if args.distributed:
                dist.simple_sync.allreducemean_list(
                    [loss, prec1, prec5, prec10])

            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)
            top10.update(prec10.item(), N)

            if args.distributed:
                if rank == 0:
                    if step % args.print_freq == 0 or step == len(
                            valid_loader) - 1:
                        logger.info(
                            "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                            "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".
                            format(epoch + 1,
                                   args.epochs,
                                   step,
                                   len(valid_loader) - 1,
                                   losses=losses,
                                   top1=top1,
                                   top5=top5))

            else:
                if step % args.print_freq == 0 or step == len(
                        valid_loader) - 1:
                    logger.info(
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(valid_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

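    # only rank 0 writes TensorBoard scalars and logs the epoch summary, so
    # distributed runs do not produce duplicate output.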
    if args.distributed:
        if rank == 0:
            writer.add_scalar('val/loss', losses.avg, cur_step)
            writer.add_scalar('val/top1', top1.avg, cur_step)
            writer.add_scalar('val/top5', top5.avg, cur_step)
            writer.add_scalar('val/top10', top10.avg, cur_step)

            logger.info(
                "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
                .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    else:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        writer.add_scalar('val/top10', top10.avg, cur_step)

        logger.info(
            "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
            .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    return top1.avg
def validate(args, valid_loader, model, epoch=0, cur_step=0):
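    # simplified variant: the loss comes from a criterion attached to the model
    # (model.criterion) and there is no PCB multi-part branch.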
    print(
        '-------------------validation_start at epoch {}---------------------'.
        format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    model.eval()
    model.to(device)
    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X, y = X.to(device, non_blocking=True), y.to(device,
                                                         non_blocking=True)
            N = X.size(0)

            ### The distributed check is required here; otherwise the skip
            ### condition below is always true and every batch is skipped.
            if args.distributed:
                if N < int(args.batch_size // world_size):
                    continue
            else:
                if N < args.batch_size:  # skip the last batch
                    continue

            logits = model(X)
            loss = model.criterion(logits, y)

            prec1, prec5, prec10 = metrics.accuracy(logits, y, topk=(1, 5, 10))

            if args.distributed:
                dist.simple_sync.allreducemean_list(
                    [loss, prec1, prec5, prec10])

            losses.update(loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)
            top10.update(prec10.item(), N)

            if args.distributed:
                if rank == 0:
                    if step % args.print_freq == 0 or step == len(
                            valid_loader) - 1:
                        logger.info(
                            "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                            "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".
                            format(epoch + 1,
                                   args.epochs,
                                   step,
                                   len(valid_loader) - 1,
                                   losses=losses,
                                   top1=top1,
                                   top5=top5))

            else:
                if step % args.print_freq == 0 or step == len(
                        valid_loader) - 1:
                    logger.info(
                        "Valid: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(valid_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

    if args.distributed:
        if rank == 0:
            writer.add_scalar('val/loss', losses.avg, cur_step)
            writer.add_scalar('val/top1', top1.avg, cur_step)
            writer.add_scalar('val/top5', top5.avg, cur_step)
            writer.add_scalar('val/top10', top10.avg, cur_step)

            logger.info(
                "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
                .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))
    else:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)
        writer.add_scalar('val/top10', top10.avg, cur_step)

        logger.info(
            "Valid: [{:2d}/{}] Final Prec@1 {:.4%}, Prec@5 {:.4%}, Prec@10 {:.4%}"
            .format(epoch + 1, args.epochs, top1.avg, top5.avg, top10.avg))

    return top1.avg
Example #3
def train(args,
          train_loader,
          valid_loader,
          model,
          woptimizer,
          lr_scheduler,
          epoch=0,
          criterion=False):
    print('-------------------training_start at epoch {}---------------------'.
          format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

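    # global TensorBoard step, continued from previous epochs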
    cur_step = epoch * len(train_loader)

    lr_scheduler.step()
    lr = lr_scheduler.get_lr()[0]
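    # note: newer PyTorch versions warn when scheduler.step() runs before
    # optimizer.step() and recommend get_last_lr() instead of get_lr().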

    if args.distributed:
        if rank == 0:
            writer.add_scalar('train/lr', lr, cur_step)
    else:
        writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    running_loss = 0.0
    running_corrects = 0.0
    step = 0
    # warm-up scale factor; assumed to start at 0.1 here, since the warm-up
    # branch below reads it before ever assigning it
    warm_up = 0.1

    for samples, labels in train_loader:
        step = step + 1
        now_batch_size, c, h, w = samples.shape
        if now_batch_size < args.batch_size:  # skip the last batch
            continue

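        # Variable is a no-op wrapper since PyTorch 0.4; tensors can be moved
        # to the device directly, as the GPU branch already does.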
        if use_gpu:
            #samples = Variable(samples.cuda().detach())
            #labels = Variable(labels.cuda().detach())
            samples, labels = samples.to(device), labels.to(device)
        else:
            samples, labels = Variable(samples), Variable(labels)

        model.to(device)
        woptimizer.zero_grad()
        logits = model(samples)

        if not args.PCB:
            _, preds = torch.max(logits.data, 1)
            loss = criterion(logits, labels)
        else:
            part = {}
            sm = nn.Softmax(dim=1)
            num_part = 6
            for i in range(num_part):
                part[i] = logits[i]

            score = sm(part[0]) + sm(part[1]) + sm(part[2]) + sm(part[3]) + sm(
                part[4]) + sm(part[5])
            _, preds = torch.max(score.data, 1)

            loss = criterion(part[0], labels)
            for i in range(num_part - 1):
                loss += criterion(part[i + 1], labels)

        # linear warm-up: scale the loss down early in training, ramping the
        # factor from 0.1 toward 1.0. Note that len(train_loader) already
        # counts batches, so the extra division by args.batch_size is likely
        # unintended.
        if epoch < args.warm_epoch and args.warm_up:
            warm_iteration = round(
                len(train_loader) /
                args.batch_size) * args.warm_epoch  # first args.warm_epoch epochs
            warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
            loss *= warm_up

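        # with apex amp, the loss is scaled before backward so fp16 gradients
        # do not underflow; gradients are unscaled again before the optimizer step.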
        if args.fp16:  # we use the optimizer to backward the loss
            with amp.scale_loss(loss, woptimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        if args.w_grad_clip != False:
            nn.utils.clip_grad_norm_(model.weights(), args.w_grad_clip)

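        # in distributed mode, aggregate gradients across workers before the
        # optimizer step and re-sync batch-norm running statistics afterwards.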
        if args.distributed:
            dist.simple_sync.sync_grad_sum(model)

        woptimizer.step()

        if args.distributed:
            dist.simple_sync.sync_bn_stat(model)

        if args.PCB:
            prec1, prec5, prec10 = metrics.accuracy(score,
                                                    labels,
                                                    topk=(1, 5, 10))
        else:
            prec1, prec5, prec10 = metrics.accuracy(logits,
                                                    labels,
                                                    topk=(1, 5, 10))

        if args.distributed:
            dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

        losses.update(loss.item(), samples.size(0))
        top1.update(prec1.item(), samples.size(0))
        top5.update(prec5.item(), samples.size(0))
        top10.update(prec10.item(), samples.size(0))

        running_loss += loss.item() * now_batch_size

        #y_loss['train'].append(losses)
        #y_err['train'].append(1.0-top1)

        if args.distributed:
            if rank == 0:
                if step % args.print_freq == 0 or step == len(
                        train_loader) - 1:
                    logger.info(
                        "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(train_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', prec1.item(), cur_step)
                writer.add_scalar('train/top5', prec5.item(), cur_step)
        else:
            if step % args.print_freq == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        args.epochs,
                        step,
                        len(train_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            writer.add_scalar('train/loss', loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            writer.add_scalar('train/top10', prec10.item(), cur_step)

        cur_step += 1
    if args.distributed:
        if rank == 0:
            logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
                epoch + 1, args.epochs, top1.avg))
    else:
        logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, args.epochs, top1.avg))

    if args.distributed:
        if rank == 0:
            if epoch % args.forcesave == 0:
                save_network(args, model, epoch, top1)
    else:
        if epoch % args.forcesave == 0:
            save_network(args, model, epoch, top1)
def train(args,
          train_loader,
          valid_loader,
          model,
          architect,
          w_optim,
          alpha_optim,
          lr_scheduler,
          epoch=0):
    print('-------------------training_start at epoch {}---------------------'.
          format(epoch))
    top1 = metrics.AverageMeter()
    top5 = metrics.AverageMeter()
    top10 = metrics.AverageMeter()
    losses = metrics.AverageMeter()

    cur_step = epoch * len(train_loader)

    lr_scheduler.step()
    lr = lr_scheduler.get_lr()[0]

    if args.distributed:
        if rank == 0:
            writer.add_scalar('train/lr', lr, cur_step)
    else:
        writer.add_scalar('train/lr', lr, cur_step)

    model.train()

    running_loss = 0.0
    running_corrects = 0.0
    #step = 0
    model.to(device)

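    # iterate over paired train/val batches; zip stops at the shorter loader,
    # so both loaders are expected to have the same length.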
    for step, ((trn_X, trn_y),
               (val_X, val_y)) in enumerate(zip(train_loader, valid_loader)):
        #step = step+1
        now_batch_size, c, h, w = trn_X.shape
        trn_X, trn_y = trn_X.to(device,
                                non_blocking=True), trn_y.to(device,
                                                             non_blocking=True)
        val_X, val_y = val_X.to(device,
                                non_blocking=True), val_y.to(device,
                                                             non_blocking=True)

        if args.distributed:
            if now_batch_size < int(args.batch_size // world_size):
                continue
        else:
            if now_batch_size < args.batch_size:  # skip the last batch
                continue

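        # DARTS-style bilevel step: first update the architecture parameters
        # (alpha) on the validation batch via architect.unrolled_backward, then
        # update the network weights (w) on the training batch below.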
        alpha_optim.zero_grad()
        architect.unrolled_backward(trn_X, trn_y, val_X, val_y, lr, w_optim)
        alpha_optim.step()

        w_optim.zero_grad()
        logits = model(trn_X)
        loss = model.criterion(logits, trn_y)
        loss.backward()

        # gradient clipping
        if args.w_grad_clip != False:
            nn.utils.clip_grad_norm_(model.weights(), args.w_grad_clip)

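        # gradients are aggregated across workers either by summing or by
        # averaging, selected by args.sync_grad_sum.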
        if args.distributed:
            if args.sync_grad_sum:
                dist.sync_grad_sum(model)
            else:
                dist.sync_grad_mean(model)

        w_optim.step()

        if args.distributed:
            dist.sync_bn_stat(model)

        prec1, prec5, prec10 = metrics.accuracy(logits, trn_y, topk=(1, 5, 10))

        if args.distributed:
            dist.simple_sync.allreducemean_list([loss, prec1, prec5, prec10])

        losses.update(loss.item(), now_batch_size)
        top1.update(prec1.item(), now_batch_size)
        top5.update(prec5.item(), now_batch_size)
        top10.update(prec10.item(), now_batch_size)

        #running_loss += loss.item() * now_batch_size

        #y_loss['train'].append(losses)
        #y_err['train'].append(1.0-top1)

        if args.distributed:
            if rank == 0:
                if step % args.print_freq == 0 or step == len(
                        train_loader) - 1:
                    logger.info(
                        "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                            epoch + 1,
                            args.epochs,
                            step,
                            len(train_loader) - 1,
                            losses=losses,
                            top1=top1,
                            top5=top5))

                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', prec1.item(), cur_step)
                writer.add_scalar('train/top5', prec5.item(), cur_step)
                writer.add_scalar('train/top10', prec10.item(), cur_step)
        else:
            if step % args.print_freq == 0 or step == len(train_loader) - 1:
                logger.info(
                    "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1,
                        args.epochs,
                        step,
                        len(train_loader) - 1,
                        losses=losses,
                        top1=top1,
                        top5=top5))

            writer.add_scalar('train/loss', loss.item(), cur_step)
            writer.add_scalar('train/top1', prec1.item(), cur_step)
            writer.add_scalar('train/top5', prec5.item(), cur_step)
            writer.add_scalar('train/top10', prec10.item(), cur_step)

        cur_step += 1
    if args.distributed:
        if rank == 0:
            logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
                epoch + 1, args.epochs, top1.avg))
    else:
        logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
            epoch + 1, args.epochs, top1.avg))

    if epoch % args.forcesave == 0:
        save_network(args, model, epoch, top1)