Example #1
0
def test(epoch):
    """Evaluate `net` on the test set; checkpoint it when accuracy improves.

    Relies on module-level globals: net, testloader, criterion, device,
    args, best_acc, and the helper get_model_complexity_info.
    """
    global best_acc
    net.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            logits = net(images)
            # Accumulate loss and top-1 hit counts over the whole split.
            running_loss += criterion(logits, labels).item()
            preds = logits.max(1)[1]
            n_seen += labels.size(0)
            n_correct += (preds == labels).sum().item()

    # Save a checkpoint whenever the accuracy beats the best seen so far.
    acc = 100. * n_correct / n_seen
    logging.info('the test acc is : {}'.format(acc))
    if acc > best_acc:
        logging.info('Saving..')
        if args.quan_mode == 'Conv2dDPQ':
            flops, params = get_model_complexity_info(net, (3, 32, 32))
            logging.info(
                'the total flops of {} is : {} and whole params is : {}'.
                format(args.arch, flops, params))
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(
            state,
            './checkpoint/{}-{}-ckpt.pth'.format(args.arch,
                                                 args.quan_mode[-3:]))
        best_acc = acc
Example #2
0
def main():
    """Train a CIFAR-100 classifier with SGD + MultiStepLR.

    Parses CLI args, builds the network via get_network, wires tensorboard
    and per-run file logging, builds the CIFAR-100 loaders, optionally
    resumes from a saved checkpoint, then trains for --epochs epochs,
    checkpointing the best validation accuracy after every epoch.

    Relies on module-level helpers: get_network, get_training_dataloader,
    get_test_dataloader, WarmUpLR, train, eval_training, save_checkpoint,
    get_model_complexity_info.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('-n',
                        '--net',
                        type=str,
                        default='mobilenetv2',
                        help='net type')
    parser.add_argument('-g',
                        '--gpu',
                        action='store_true',
                        default=False,
                        help='use gpu or not')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=256,
                        help='batch size for dataloader')
    parser.add_argument('--warm',
                        type=int,
                        default=5,
                        help='warm up training phase')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=0.1,
                        help='initial learning rate')
    parser.add_argument('--resume',
                        action='store_true',
                        default=False,
                        help='resume training')
    parser.add_argument('-j',
                        '--workers',
                        type=int,
                        default=4,
                        help='the process number')
    parser.add_argument('-p',
                        '--print-freq',
                        default=50,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--epochs',
                        default=200,
                        type=int,
                        help='training epochs')
    parser.add_argument('--milestones',
                        default=[60, 120, 180],
                        nargs='+',
                        type=int,
                        help='milestones of MultiStepLR')
    parser.add_argument('--checkpoint', type=str, default='checkpoint')
    parser.add_argument('--log-dir', default='logger', type=str)
    parser.add_argument('--save',
                        default='EXP',
                        type=str,
                        help='save for the tensor log')
    parser.add_argument(
        '--save_model',
        type=str,
        default='cifar100-mobilenetv3_large-models/model_best.pth.tar')
    args = parser.parse_args()

    net = get_network(args)

    # Unique run name: dataset-net-tag-timestamp.
    args.save = 'cifar100-{}-{}-{}'.format(args.net, args.save,
                                           time.strftime("%Y%m%d-%H%M%S"))
    # NOTE(review): 'logger' is hard-coded here while the SummaryWriter below
    # uses args.log_dir — the two paths diverge if --log-dir is changed from
    # its default; confirm this is intended.
    log_dir = '{}/{}'.format('logger', args.save)
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.save))
    # NOTE(review): unconditional .cuda() ignores the --gpu flag and will
    # fail on a CPU-only machine — verify.
    input_tensor = torch.Tensor(1, 3, 32, 32).cuda()
    writer.add_graph(net, input_tensor)

    # Log both to stdout and to <log_dir>/log.txt with the same format.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')

    fh = logging.FileHandler(os.path.join('{}/log.txt'.format(log_dir)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.getLogger().setLevel(logging.INFO)

    # Report the full-precision model complexity before training.
    flops, params = get_model_complexity_info(net, (3, 32, 32),
                                              print_per_layer_stat=False)
    logging.info(
        'the model for train(fp) flops is {} and its params is {} '.format(
            flops, params))

    # CIFAR-100 per-channel normalization statistics.
    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095,
                           0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883,
                          0.27615047132568404)
    #data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True)

    cifar100_test_loader = get_test_dataloader(CIFAR100_TRAIN_MEAN,
                                               CIFAR100_TRAIN_STD,
                                               num_workers=args.workers,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.learning_rate,
                          momentum=0.9,
                          weight_decay=5e-4)
    train_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=0.2)  #learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    # Linear LR warmup over the first --warm epochs (per-iteration steps).
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    #use tensorboard
    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)

    #since tensorboard can't overwrite old values
    #so the only way is to create a new tensorboard log

    best_acc = 0.0
    args.start_epoch = 0
    if args.resume:
        # Resume: restore weights, best accuracy, and starting epoch.
        if os.path.isfile('{}/{}'.format(args.log_dir, args.save_model)):
            logging.info("=> loading checkpoint '{}/{}'".format(
                args.log_dir, args.save_model))
            checkpoint = torch.load(os.path.join(args.log_dir,
                                                 args.save_model))
            logging.info('load best training file to test acc...')
            net.load_state_dict(checkpoint['net'])
            logging.info('best acc is {:0.2f}'.format(checkpoint['acc']))
            best_acc = checkpoint['acc']
            args.start_epoch = checkpoint['epoch']
        else:
            logging.info("=> no checkpoint found at '{}/{}'".format(
                args.log_dir, args.save_model))
            raise Exception('No such model saved !')

    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, 0)
        logging.info('the final best acc is {}'.format(best_acc))
        return

    for epoch in range(args.start_epoch, args.epochs):
        train(net, cifar100_training_loader, args, optimizer, epoch, writer,
              warmup_scheduler, loss_function)
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, epoch)
        # Step the milestone scheduler only after warmup is over.
        # NOTE(review): passing `epoch` to step() is deprecated in newer
        # PyTorch versions — verify against the pinned torch version.
        if epoch > args.warm:
            train_scheduler.step(epoch)

        is_best = False
        if acc > best_acc:
            best_acc = acc
            logging.info('the best acc is {} in epoch {}'.format(
                best_acc, epoch))
            is_best = True
        # Checkpoint every epoch; save_checkpoint marks the best separately.
        save_checkpoint(
            {
                'epoch': epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            save='logger/{}-{}-models'.format('cifar100', args.net))

    logging.info('the final best acc is {}'.format(best_acc))
    writer.close()
Example #3
0
def main():
    """Train a CIFAR-10 classifier named by --arch for 200 epochs.

    Seeds all RNG sources, wires tensorboardX and file logging, builds the
    CIFAR-10 loaders, builds the model from `models` (falling back to
    `BModels`), optionally resumes from ./checkpoint, then trains with
    SGD + cosine annealing, checkpointing the best validation top-1.

    Relies on module-level names: parser, models, BModels, train,
    validate, get_model_complexity_info.
    """
    args = parser.parse_args()

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    # Unique run name: prefix-arch-tag-timestamp.
    args.save = 'cifar-train-{}-{}-{}'.format(args.arch, args.save,
                                              time.strftime("%Y%m%d-%H%M%S"))

    from tensorboardX import SummaryWriter
    writer_comment = args.save
    log_dir = '{}/{}'.format('logger', args.save)
    writer = SummaryWriter(log_dir=log_dir, comment=writer_comment)

    # Log both to stdout and to <log_dir>/log.txt with the same format.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')

    # SummaryWriter above has already created log_dir, so the file handler
    # can open its log file here.
    fh = logging.FileHandler(os.path.join('{}/log.txt'.format(log_dir)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    best_acc = 0  # best test accuracy seen so far
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # Data
    print('==> Preparing data..')
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465),
                             (0.2023, 0.1994, 0.2010)),
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=True,
                                            transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=256,
                                              shuffle=True,
                                              num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=True,
                                           transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=256,
                                             shuffle=False,
                                             num_workers=2)

    # Model
    logging.info('==> Building model..')
    if args.arch in models.__dict__:
        net = models.__dict__[args.arch]()
    else:
        net = BModels.__dict__[args.arch]()

    net = net.to(device)
    flops, params = get_model_complexity_info(net, (3, 32, 32),
                                              print_per_layer_stat=False)
    logging.info(
        'the total flops of {} is : {} and whole params is : {}'.format(
            args.arch, flops, params))
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    if args.resume:
        # Load checkpoint: restore weights, best accuracy, starting epoch.
        logging.info('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(args.arch))
        # NOTE(review): the checkpoint below is saved from
        # net.module.state_dict(), but on CUDA `net` here is already
        # DataParallel-wrapped, so key prefixes may not match — verify.
        net.load_state_dict(checkpoint['net'])
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=200)

    if args.evaluate:
        validate(testloader, net, criterion, args)
        return

    # Bug fix: `best_acc` is no longer reset to 0.0 here.  Resetting it
    # discarded the accuracy restored from a resumed checkpoint above,
    # which let worse models overwrite the saved best checkpoint.
    for epoch in range(start_epoch, start_epoch + 200):
        train_loss, train_top1 = train(trainloader, net, criterion, optimizer,
                                       epoch, args)
        validate_loss, validate_top1 = validate(testloader, net, criterion,
                                                args)
        logging.info(
            'the train loss is : {} ; For Train !  the top1 accuracy is : {} '.
            format(train_loss, train_top1))
        logging.info(
            'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} '
            .format(validate_loss, validate_top1))
        writer.add_scalars('Train-Loss/Training-Validate', {
            'train_loss': train_loss,
            'validate_loss': validate_loss
        }, epoch + 1)
        writer.add_scalars('Train-Top1/Training-Validate', {
            'train_acc1': train_top1,
            'validate_acc1': validate_top1
        }, epoch + 1)
        writer.add_scalar('Learning-Rate-For-Train',
                          optimizer.state_dict()['param_groups'][0]['lr'],
                          epoch + 1)
        if validate_top1 > best_acc:
            # New best: save the unwrapped module weights.
            best_acc = validate_top1
            logging.info(
                'the best model top1 is : {} and its epoch is {} !'.format(
                    best_acc, epoch))
            state = {
                'net': net.module.state_dict(),
                'acc': best_acc,
                'epoch': epoch
            }
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(state, './checkpoint/{}-ckpt.pth'.format(args.arch))
        scheduler.step()

    logging.info('the final best model top1 is : {} !'.format(best_acc))
Example #4
0
def main():
    """Quantization-aware fine-tuning of a CIFAR-100 network.

    Parses CLI args, seeds RNGs, builds the network via get_network,
    optionally loads a pretrained float checkpoint, replaces convolutions
    with the quantized variant named by --quan-mode, then runs two phases:
    (1) per-stage feature distillation against a float teacher copy for
    --initial-epochs, and (2) knowledge-distillation training for --epochs.
    The best accuracy of each phase is checkpointed.

    Relies on module-level helpers: get_network, replace_conv_recursively,
    add_weight_decay, get_training_dataloader, get_test_dataloader,
    WarmUpLR, feature_train, kd_train, eval_training, save_checkpoint,
    get_model_complexity_info.
    """
    q_modes_choice = sorted(['kernel_wise', 'layer_wise'])
    parser = argparse.ArgumentParser()
    parser.add_argument('-n',
                        '--net',
                        type=str,
                        default='mobilenetv2',
                        help='net type')
    parser.add_argument('-g',
                        '--gpu',
                        action='store_true',
                        default=False,
                        help='use gpu or not')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=256,
                        help='batch size for dataloader')
    parser.add_argument('--warm',
                        type=int,
                        default=5,
                        help='warm up training phase')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=0.1,
                        help='initial learning rate')
    parser.add_argument('--resume',
                        action='store_true',
                        default=False,
                        help='resume training')
    parser.add_argument('-j',
                        '--workers',
                        type=int,
                        default=4,
                        help='the process number')
    parser.add_argument('-p',
                        '--print-freq',
                        default=50,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('-e',
                        '--evaluate',
                        dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=3e-4,
                        help='weight decay')
    parser.add_argument('--epochs',
                        default=100,
                        type=int,
                        help='training epochs')
    parser.add_argument('--milestones',
                        default=[30, 60, 90],
                        nargs='+',
                        type=int,
                        help='milestones of MultiStepLR')
    parser.add_argument('--gamma', default=0.2, type=float)
    parser.add_argument('--manual-seed',
                        default=2,
                        type=int,
                        help='random seed is settled')
    parser.add_argument('--checkpoint', type=str, default='checkpoint')
    parser.add_argument('--log-dir', default='logger', type=str)
    parser.add_argument('--save',
                        default='EXP',
                        type=str,
                        help='save for the tensor log')
    parser.add_argument("--kd-ratio",
                        type=float,
                        default=0.5,
                        help="learning from soft label distribution")
    parser.add_argument(
        '--save-model',
        type=str,
        default='initial-cifar100-mobilenetv2-HSQ-models/model_best.pth.tar')
    parser.add_argument('--pretrained',
                        default=False,
                        action='store_true',
                        help='load pretrained model')
    parser.add_argument('--quan-mode',
                        type=str,
                        default='Conv2dHSQ',
                        help='corresponding for the quantize conv')
    parser.add_argument('--q-mode',
                        choices=q_modes_choice,
                        default='layer_wise',
                        help='Quantization modes: ' +
                        ' | '.join(q_modes_choice) + ' (default: kernel-wise)')
    parser.add_argument('--initial-epochs',
                        type=int,
                        default=20,
                        help='initial epochs for specific blocks')
    parser.add_argument('--initial-lr',
                        default=2e-2,
                        type=float,
                        help='learning rate for initial blocks')
    args = parser.parse_args()

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)  # seed Python's RNG as well

    # Unique run name including the quantizer suffix and a timestamp.
    args.save = 'cifar100-feature-{}-{}-{}-{}'.format(
        args.net, args.save, args.quan_mode[-3:],
        time.strftime("%Y%m%d-%H%M%S"))
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.save))
    # NOTE(review): unconditional .cuda() ignores the --gpu flag and will
    # fail on a CPU-only machine — verify.
    input_tensor = torch.Tensor(1, 3, 32, 32).cuda()

    # Log both to stdout and to <log-dir>/<save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')

    fh = logging.FileHandler(
        os.path.join('{}/{}/log.txt'.format(args.log_dir, args.save)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.getLogger().setLevel(logging.INFO)

    net = get_network(args)
    if args.pretrained:
        # Start from the best full-precision checkpoint for this net.
        checkpoint = torch.load(
            os.path.join(
                args.log_dir,
                '{}-{}-models/model_best.pth.tar'.format('cifar100',
                                                         args.net)))
        net.load_state_dict(checkpoint['net'])

    # Keep an untouched float copy as the distillation teacher.
    teacher_model = copy.deepcopy(net)

    # Swap every conv for the quantized variant named by --quan-mode.
    replace_conv_recursively(net, args.quan_mode, args)

    if not args.quan_mode == 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32),
                                                  print_per_layer_stat=False)
        logging.info(
            'the model after quantized flops is {} and its params is {} '.
            format(flops, params))
        writer.add_graph(net, input_tensor)

    logging.info('args = %s', args)

    # CIFAR-100 per-channel normalization statistics.
    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095,
                           0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883,
                          0.27615047132568404)
    #data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True)

    cifar100_test_loader = get_test_dataloader(CIFAR100_TRAIN_MEAN,
                                               CIFAR100_TRAIN_STD,
                                               num_workers=args.workers,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # Stage groups used for phase 1: each group gets its own optimizer and
    # linearly-decaying scheduler over --initial-epochs.
    stage_blocks = []
    stage_blocks.append(['pre', 'stage1', 'stage2'])
    stage_blocks.append(['stage3', 'stage4'])
    stage_blocks.append(['stage5', 'stage6'])
    stage_blocks.append(['stage7', 'conv1'])

    optimizers, schedulers = [], []
    for block in stage_blocks:
        params = add_weight_decay(net,
                                  weight_decay=args.weight_decay,
                                  skip_keys=['expand_', 'running_scale'],
                                  grads=block)
        optimizer = optim.SGD(params,
                              lr=args.initial_lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.initial_epochs),
            last_epoch=-1)
        optimizers.append(optimizer)
        schedulers.append(scheduler)

    # Feature-matching loss for phase 1.
    # NOTE(review): the `reduce` argument is deprecated in modern PyTorch;
    # reduction='mean' alone suffices — verify against the pinned version.
    initial_criterion = nn.MSELoss(reduce=True, reduction='mean')

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    train_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones,
        gamma=args.gamma)  #learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    # Linear LR warmup over the first --warm epochs (per-iteration steps).
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    #use tensorboard
    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)

    if args.evaluate:
        # Evaluation-only mode: run one validation pass and exit.
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, 0)
        logging.info('the acc for validate currently is {}'.format(acc))
        return

    best_acc = 0.0
    args.initial_start_epoch, args.start_epoch = 0, 0
    if args.resume:
        if os.path.isfile('{}/{}'.format(args.log_dir, args.save_model)):
            logging.info("=> loading checkpoint '{}/{}'".format(
                args.log_dir, args.save_model))
            checkpoint = torch.load(os.path.join(args.log_dir,
                                                 args.save_model))
            logging.info('load best training file to test acc...')
            net.load_state_dict(checkpoint['net'])
            logging.info('best acc is {:0.2f}'.format(checkpoint['acc']))
            best_acc = checkpoint['acc']
            # 'initial' checkpoints resume phase 1; others skip straight
            # to phase 2 at the stored epoch.
            if 'initial' in args.save_model:
                args.initial_start_epoch = checkpoint['epoch']
            else:
                args.initial_start_epoch = args.initial_epochs
                args.start_epoch = checkpoint['epoch']
        else:
            logging.info("=> no checkpoint found at '{}/{}'".format(
                args.log_dir, args.save_model))
            raise Exception('No such model saved !')

    # Phase 1: per-stage feature distillation against the float teacher.
    for epoch in range(args.initial_start_epoch, args.initial_epochs):
        feature_train(net, teacher_model, cifar100_training_loader, args,
                      optimizers, epoch, writer, initial_criterion)
        acc = eval_training(net,
                            cifar100_test_loader,
                            args,
                            writer,
                            loss_function,
                            epoch,
                            tb=False)
        for scheduler in schedulers:
            scheduler.step()

        is_best = False
        if acc > best_acc:
            best_acc = acc
            logging.info('the best acc is {} in epoch {}'.format(
                best_acc, epoch))
            is_best = True
        save_checkpoint(
            {
                'epoch': epoch,
                'net': net.state_dict(),
                'acc': best_acc,
            },
            is_best,
            save='logger/initial-{}-{}-{}-models'.format(
                'cifar100', args.net, args.quan_mode[-3:]))

    # Phase 2: knowledge-distillation training of the whole network.
    for epoch in range(args.start_epoch, args.epochs):
        kd_train(net, teacher_model, cifar100_training_loader, args, optimizer,
                 epoch, writer, warmup_scheduler, loss_function)
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, epoch)
        # NOTE(review): passing `epoch` to step() is deprecated in newer
        # PyTorch versions — verify against the pinned torch version.
        if epoch > args.warm:
            train_scheduler.step(epoch)

        is_best = False
        if acc > best_acc:
            best_acc = acc
            logging.info('the best acc is {} in epoch {}'.format(
                best_acc, epoch))
            is_best = True
            # DPQ changes model complexity during training; re-measure it
            # whenever a new best model appears.
            if args.quan_mode == 'Conv2dDPQ':
                flops, params = get_model_complexity_info(
                    net, (3, 32, 32), print_per_layer_stat=False)
                logging.info(
                    'the model after quantized flops is {} and its params is {} '
                    .format(flops, params))
        save_checkpoint(
            {
                'epoch': epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            save='logger/feature-{}-{}-{}-models'.format(
                'cifar100', args.net, args.quan_mode[-3:]))

    if args.quan_mode == 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32),
                                                  print_per_layer_stat=False)
        logging.info(
            'the model after quantized flops is {} and its params is {} '.
            format(flops, params))
        writer.add_graph(net, input_tensor)

    logging.info('the final best acc is {}'.format(best_acc))
    writer.close()
Example #5
0
# Script-level setup: load float weights, quantize the convolutions,
# optionally resume a quantized checkpoint, then build the training
# criterion/optimizer/scheduler.  Depends on `checkpoint`, `args`, and
# `replace_conv_recursively` defined earlier in the file.
net.load_state_dict(checkpoint['net'])

# Replace every conv with the quantized variant named by --quan-mode.
net = replace_conv_recursively(net, args.quan_mode, args)

if args.resume:
    # Load checkpoint.
    print('==> Resuming from Quantized checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/{}-{}-ckpt.pth'.format(
        args.arch, args.quan_mode[-3:]))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# DPQ models change shape during training, so complexity is only
# reported up front for the other quantizers.
if args.quan_mode != 'Conv2dDPQ':
    flops, params = get_model_complexity_info(net, (3, 32, 32))
    print('the total flops of mobilenetv2 is : {} and whole params is : {}'.
          format(flops, params))

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)


# Training
def train(epoch):
    """Log the epoch header and put `net` into training mode.

    NOTE(review): this body appears truncated in this excerpt — it only
    logs and calls net.train(); the actual training loop is not visible
    here. Confirm against the full file before relying on it.
    """
    logging.info('\nEpoch: %d' % epoch)
    net.train()
def main():
    """Run the feature-distillation quantization pipeline for ``args.arch``.

    Phase 1 initializes each stride-delimited stage of the quantized network
    block-wise against a full-precision teacher, finetuning the head each
    epoch; phase 2 finetunes the whole quantized network.  The best
    validation top-1 checkpoint of each phase is written via
    ``save_checkpoint``.
    """
    # Model
    logging.info('==> Building model {} ..'.format(args.arch))

    net = models.__dict__[args.arch]()

    # Partition net.layers into stages, starting a new stage whenever a
    # stride-2 conv downsamples the feature map.  Each entry is a list of
    # parameter-name prefixes that add_weight_decay() uses below to select
    # which weights each stage optimizer trains.
    stage_blocks = []
    block = ['first_']
    for i in range(len(net.layers)):
        if net.layers[i].conv2.stride == 2:
            stage_blocks.append(block)
            block = []
        block.append('layers.{}'.format(i))

    stage_blocks.append(block)
    last_block = ['last_', 'linear']

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # Load the full-precision checkpoint BEFORE the convs are replaced with
    # their quantized counterparts, so the state-dict keys still match.
    if args.pretrained:
        logging.info('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(args.arch))
        net.load_state_dict(checkpoint['net'])

    net = replace_conv_recursively(net, args.quan_mode, args)

    criterion = nn.CrossEntropyLoss()

    if args.evaluate:
        validate_loss, validate_top1 = validate(net,
                                                criterion,
                                                data_loader=testloader,
                                                epoch=0,
                                                valid=True)
        return

    if args.kd_ratio > 0:  # set up the teacher model (supernet) - in practice the largest network
        args.teacher_model = models.__dict__[args.teacher_arch]()
        args.teacher_model = torch.nn.DataParallel(args.teacher_model)
        checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(
            args.teacher_arch))
        args.teacher_model.load_state_dict(checkpoint['net'])

    best_acc1 = 0
    args.initial_start_epoch, args.start_epoch = 0, 0
    if args.resume:
        # Load checkpoint.
        logging.info('==> Resuming from Quantized checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('logger/{}'.format(args.save_model))
        # BUGFIX: save_checkpoint() below stores the weights under
        # 'state_dict' and the accuracy under 'best_acc1', but this path
        # used to read 'net'/'acc' and raised KeyError on its own
        # checkpoints.  Accept both layouts for backward compatibility.
        net.load_state_dict(
            checkpoint.get('state_dict', checkpoint.get('net')))
        best_acc1 = checkpoint.get('best_acc1', checkpoint.get('acc', 0))
        if 'initial' in args.save_model:
            args.initial_start_epoch = checkpoint['initial_epoch'] + 1
        else:
            # Initialization phase already finished; resume finetuning only.
            args.initial_start_epoch = args.initial_epochs
            args.start_epoch = checkpoint['epoch'] + 1

    # Conv2dDPQ learns its precision, so a static complexity report is only
    # meaningful for the fixed-precision modes.
    if args.quan_mode != 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32))
        logging.info(
            'the total flops of {} is : {} and whole params is : {}'.format(
                args.arch, flops, params))

    # One SGD/cosine-scheduler pair per stage for block-wise initialization.
    optimizers, schedulers = [], []
    for block in stage_blocks:
        params = add_weight_decay(net,
                                  weight_decay=args.weight_decay,
                                  skip_keys=['expand_', 'running_scale'],
                                  grads=block)
        optimizer = torch.optim.SGD(params,
                                    args.initial_lr,
                                    momentum=args.momentum)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.initial_epochs)
        optimizers.append(optimizer)
        schedulers.append(scheduler)

    # Dedicated optimizer for the stem/classifier ('last_', 'linear') weights.
    params = add_weight_decay(net,
                              weight_decay=args.weight_decay,
                              skip_keys=['expand_', 'running_scale'],
                              grads=last_block)
    final_optimizer = optim.SGD(params,
                                args.initial_lr,
                                momentum=args.momentum)
    final_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        final_optimizer, T_max=args.initial_epochs)

    # Whole-network optimizer/scheduler for the final finetuning phase.
    fin_optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
    fin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(fin_optimizer,
                                                               T_max=200)

    with open('cifar-feature-quan.txt', 'w+') as f:
        # Phase 1: initialize blocks from teacher features, then finetune the
        # head with final_optimizer each epoch.
        for epoch in range(args.initial_start_epoch, args.initial_epochs):
            initial_blocks(net, args.teacher_model, trainloader, optimizers,
                           epoch)
            train_loss, train_top1 = finetune(net, args.teacher_model, args,
                                              trainloader, final_optimizer,
                                              epoch, criterion)
            validate_loss, validate_top1 = validate(net,
                                                    criterion,
                                                    data_loader=testloader,
                                                    epoch=epoch,
                                                    valid=True)
            logging.info(
                'the soft train loss is : {} ; For Train !  the top1 accuracy is : {} ;'
                .format(train_loss, train_top1))
            logging.info(
                'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} ;'
                .format(validate_loss, validate_top1))
            writer.add_scalars('Inital-block-Loss/Training-Validate', {
                'train_soft_loss': train_loss,
                'validate_loss': validate_loss
            }, epoch + 1)
            writer.add_scalars('Inital-block-Top1/Training-Validate', {
                'train_acc1': train_top1,
                'validate_acc1': validate_top1
            }, epoch + 1)
            writer.add_scalars(
                'Learning-Rate-For-Initial', {
                    'basic optimizer':
                    final_optimizer.state_dict()['param_groups'][0]['lr'],
                }, epoch + 1)
            is_best = False
            if validate_top1 > best_acc1:
                best_acc1 = validate_top1
                is_best = True
                logging.info(
                    'the best model top1 is : {} and its epoch is {} !'.format(
                        best_acc1, epoch))
            save_checkpoint(
                {
                    'initial_epoch': epoch,
                    'state_dict': net.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': final_optimizer.state_dict(),
                    'scheduler': final_scheduler.state_dict(),
                },
                is_best,
                save='logger/{}-{}-{}-initial-models'.format(
                    'cifar', args.arch, args.quan_mode[-3:]))
            # Advance every per-stage schedule plus the head schedule.
            for scheduler in schedulers:
                scheduler.step()
            final_scheduler.step()
        logging.info(
            'After initial from feature, we can get the final soft train loss and the top1 accuracy'
        )
        # Phase 2: whole-network finetuning with fin_optimizer.
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, train_top1 = finetune(net, args.teacher_model, args,
                                              trainloader, fin_optimizer,
                                              epoch, criterion)
            validate_loss, validate_top1 = validate(net,
                                                    criterion,
                                                    data_loader=testloader,
                                                    epoch=epoch,
                                                    valid=True)
            logging.info(
                'the train loss is : {} ; For Train !  the top1 accuracy is : {} ;'
                .format(train_loss, train_top1))
            logging.info(
                'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} ;'
                .format(validate_loss, validate_top1))
            writer.add_scalars('Quantization-Loss/Training-Validate', {
                'train_loss': train_loss,
                'validate_loss': validate_loss
            }, epoch + 1)
            writer.add_scalars('Quantization-Top1/Training-Validate', {
                'train_acc1': train_top1,
                'validate_acc1': validate_top1
            }, epoch + 1)
            writer.add_scalars(
                'Learning-Rate-For-Finetune', {
                    'basic optimizer':
                    fin_optimizer.state_dict()['param_groups'][0]['lr'],
                }, epoch + 1)
            is_best = False
            if validate_top1 > best_acc1:
                best_acc1 = validate_top1
                is_best = True
                logging.info(
                    'the best model top1 is : {} and its epoch is {} !'.format(
                        best_acc1, epoch))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'state_dict': net.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': fin_optimizer.state_dict(),
                    'scheduler': fin_scheduler.state_dict(),
                },
                is_best,
                save='logger/{}-{}-{}-models'.format('cifar', args.arch,
                                                     args.quan_mode[-3:]))
            fin_scheduler.step()
Example #7
0
def main():
    """Alternating-optimization quantized training on CIFAR-100.

    Builds the network, replaces its convolutions with the quantized variant
    chosen by ``--quan-mode``, then alternates every ``--alter-epoch`` epochs
    between training the 'alpha' quantization parameters and the 'xmax'
    (mixed) parameters, checkpointing the best test accuracy of each phase.
    """
    q_modes_choice = sorted(['kernel_wise', 'layer_wise'])
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--net', type=str, default='mobilenetv2', help='net type')
    parser.add_argument('-g', '--gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-b','--batch-size', type=int, default=256, help='batch size for dataloader')
    parser.add_argument('--warm', type=int, default=5, help='warm up training phase')
    parser.add_argument('-lr','--learning-rate', type=float, default=0.1, help='initial learning rate')
    parser.add_argument('--resume', action='store_true', default=False, help='resume training')
    parser.add_argument('-j','--workers', type=int, default=4, help='the process number')
    parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: 10)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
    parser.add_argument('--epochs', default=50, type=int, help='training epochs')
    parser.add_argument('--alter-epoch', default=10, type=int, help='alternating epoch for evolution')
    parser.add_argument('--milestones', default=[15, 30, 40], nargs='+', type=int,
                    help='milestones of MultiStepLR')
    parser.add_argument('--manual-seed', default=2, type=int, help='random seed is settled')
    parser.add_argument('--checkpoint', type=str, default='checkpoint')
    parser.add_argument('--log-dir', default='logger', type=str)
    parser.add_argument('--save', default='EXP', type=str, help='save for the tensor log')
    parser.add_argument('--save-model', type=str, default='cifar100-mobilenetv2-HSQ-models/model_best.pth.tar')
    parser.add_argument('--pretrained', default=False, action='store_true', help='load pretrained model')
    parser.add_argument('--quan-mode', type=str, default='Conv2dDPQ',
                    help='corresponding for the quantize conv')
    parser.add_argument('--q-mode', choices=q_modes_choice, default='layer_wise',
                    help='Quantization modes: ' + ' | '.join(q_modes_choice) +
                            ' (default: kernel-wise)')
    args = parser.parse_args()

    # Fix every RNG so runs are reproducible for a given --manual-seed.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)  # fix the random seed

    # Run tag: net / user tag / last 3 chars of quan mode / timestamp.
    args.save = 'ADmix-cifar100-{}-{}-{}-{}'.format(args.net, args.save, args.quan_mode[-3:], time.strftime("%Y%m%d-%H%M%S"))
    writer = SummaryWriter(log_dir=os.path.join(
            args.log_dir, args.save))
    input_tensor = torch.Tensor(1, 3, 32, 32).cuda()

    # Mirror logging to stdout and to <log_dir>/<save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%m/%d %I:%M:%S %p')

    fh = logging.FileHandler(os.path.join('{}/{}/log.txt'.format(args.log_dir, args.save)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.getLogger().setLevel(logging.INFO)

    net = get_network(args)
    if args.pretrained:
        checkpoint = torch.load(os.path.join(args.log_dir, '{}-{}-models/model_best.pth.tar'.format('cifar100', args.net)))
        net.load_state_dict(checkpoint['net'])

    # Report full-precision complexity before quantized convs are swapped in.
    flops, params = get_model_complexity_info(net, (3,32,32), print_per_layer_stat=False)
    logging.info('the original model {} flops is {} and its params is {} '.format(args.net, flops, params))

    replace_conv_recursively(net, args.quan_mode, args)

    logging.info('args = %s', args)

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    #data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True
    )

    cifar100_test_loader = get_test_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True
    )

    loss_function = nn.CrossEntropyLoss()
    # Split parameters into the two alternating groups: 'alpha' quantization
    # params vs 'xmax' clipping params; every other weight trains in BOTH
    # phases, so it is appended to both lists.
    alphaparams, xmaxparams = [], []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if 'xmax' in name:
            xmaxparams.append(param)
        elif 'alpha' in name:
            alphaparams.append(param)
        else:
            xmaxparams.append(param)
            alphaparams.append(param)

    alphaparams = [{'params': alphaparams, 'weight_decay': 5e-4}]
    xmaxparams = [{'params': xmaxparams, 'weight_decay': 5e-4}]

    alphaoptimizer = optim.SGD(alphaparams, lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)
    alphascheduler = optim.lr_scheduler.MultiStepLR(alphaoptimizer, milestones=args.milestones, gamma=0.2) #learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    alphawarmup_scheduler = WarmUpLR(alphaoptimizer, iter_per_epoch * args.warm)
    xmaxoptimizer = optim.SGD(xmaxparams, args.learning_rate, momentum=0.9, weight_decay=5e-4)
    xmaxscheduler = optim.lr_scheduler.MultiStepLR(xmaxoptimizer, milestones=args.milestones, gamma=0.2)
    xmaxwarmup_scheduler = WarmUpLR(xmaxoptimizer, iter_per_epoch * args.warm)

    #use tensorboard
    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)

    #since tensorboard can't overwrite old values
    #so the only way is to create a new tensorboard log

    best_acc = 0.0
    args.iter, args.alpha_start_epoch, args.mix_start_epoch = 0, 0, 0
    if args.resume:
        if os.path.isfile('{}/{}'.format(args.log_dir, args.save_model)):
            logging.info("=> loading checkpoint '{}/{}'".format(args.log_dir, args.save_model))
            checkpoint = torch.load(os.path.join(args.log_dir, args.save_model))
            logging.info('load best training file to test acc...')
            net.load_state_dict(checkpoint['net'])
            logging.info('best acc is {:0.2f}'.format(checkpoint['acc']))
            best_acc = checkpoint['acc']
            # Recover which alternation phase the checkpoint came from so
            # both phase counters restart at the right epoch.
            if 'alpha' in args.save_model:
                args.alpha_start_epoch = checkpoint['epoch']
                args.mix_start_epoch = args.alter_epoch * (args.alpha_start_epoch // args.alter_epoch)
            else:
                args.mix_start_epoch = checkpoint['epoch']
                args.alpha_start_epoch = args.alter_epoch * (args.mix_start_epoch // args.alter_epoch + 1)
        else:
            logging.info("=> no checkpoint found at '{}/{}'".format(args.log_dir, args.save_model))
            raise Exception('No such model saved !')

    if args.evaluate:
        acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, 0)
        logging.info('the final best acc is {}'.format(best_acc))
        return

    whole_alter = math.ceil(args.epochs / args.alter_epoch)
    # BUGFIX(idiom): loop variable was named 'iter', shadowing the builtin;
    # renamed to alter_round (log message text kept unchanged).
    for alter_round in range(args.iter, whole_alter):
        logging.info('the iter in {}'.format(alter_round))
        # Alpha phase: train quantization scale params for up to alter_epoch epochs.
        while args.alpha_start_epoch < min((alter_round+1) * args.alter_epoch, args.epochs):
            train(net, cifar100_training_loader, args, alphaoptimizer, args.alpha_start_epoch, writer, alphawarmup_scheduler, loss_function)
            acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, args.alpha_start_epoch, tb=False)
            if args.alpha_start_epoch > args.warm:
                alphascheduler.step(args.alpha_start_epoch)
            args.alpha_start_epoch += 1

            is_best = False
            if acc > best_acc:
                best_acc = acc
                logging.info('the best acc is {} in epoch {}'.format(best_acc, args.alpha_start_epoch))
                is_best = True
                flops, params = get_model_complexity_info(net, (3,32,32), print_per_layer_stat=False)
                logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
            save_checkpoint({
                'epoch': args.alpha_start_epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': alphaoptimizer.state_dict(),
            }, is_best, save='logger/alpha-{}-{}-{}-models'.format('cifar100', args.net, args.quan_mode[-3:]))

        # Mix phase: train xmax clipping params for up to alter_epoch epochs.
        while args.mix_start_epoch < min((alter_round+1) * args.alter_epoch, args.epochs):
            train(net, cifar100_training_loader, args, xmaxoptimizer, args.mix_start_epoch, writer, xmaxwarmup_scheduler, loss_function)
            acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, args.mix_start_epoch, tb=False)
            if args.mix_start_epoch > args.warm:
                xmaxscheduler.step(args.mix_start_epoch)
            args.mix_start_epoch += 1

            is_best = False
            if acc > best_acc:
                best_acc = acc
                logging.info('the best acc is {} in epoch {}'.format(best_acc, args.mix_start_epoch))
                is_best = True
                flops, params = get_model_complexity_info(net, (3,32,32), print_per_layer_stat=False)
                logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
            save_checkpoint({
                'epoch': args.mix_start_epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': xmaxoptimizer.state_dict(),
            }, is_best, save='logger/mix-{}-{}-{}-models'.format('cifar100', args.net, args.quan_mode[-3:]))

    # Final report: quantized complexity, model graph, and best accuracy.
    flops, params = get_model_complexity_info(net, (3,32,32), print_per_layer_stat=False)
    logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
    writer.add_graph(net, input_tensor)
    logging.info('the final best acc is {}'.format(best_acc))
    writer.close()