Code example #1
0
def main():
    """Train a quantized CIFAR-100 network via feature-based distillation.

    Pipeline: parse CLI options, seed all RNGs, build the (optionally
    pretrained) network, keep a full-precision deep copy as the teacher,
    swap the student's convolutions for quantized variants, then train in
    two phases:

    1. Block-wise feature alignment (``feature_train``) — per-stage
       optimizers minimize an MSE loss against the teacher's features.
    2. Knowledge-distillation fine-tuning (``kd_train``) of the whole net.

    Checkpoints and TensorBoard logs are written under ``args.log_dir``.
    """
    q_modes_choice = sorted(['kernel_wise', 'layer_wise'])
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--net', type=str, default='mobilenetv2',
                        help='net type')
    parser.add_argument('-g', '--gpu', action='store_true', default=False,
                        help='use gpu or not')
    parser.add_argument('-b', '--batch-size', type=int, default=256,
                        help='batch size for dataloader')
    parser.add_argument('--warm', type=int, default=5,
                        help='warm up training phase')
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.1,
                        help='initial learning rate')
    parser.add_argument('--resume', action='store_true', default=False,
                        help='resume training')
    parser.add_argument('-j', '--workers', type=int, default=4,
                        help='the process number')
    # Help text now matches the actual default (it previously said 10).
    parser.add_argument('-p', '--print-freq', default=50, type=int,
                        metavar='N', help='print frequency (default: 50)')
    parser.add_argument('-e', '--evaluate', dest='evaluate',
                        action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=3e-4,
                        help='weight decay')
    parser.add_argument('--epochs', default=100, type=int,
                        help='training epochs')
    parser.add_argument('--milestones', default=[30, 60, 90], nargs='+',
                        type=int, help='milestones of MultiStepLR')
    parser.add_argument('--gamma', default=0.2, type=float)
    parser.add_argument('--manual-seed', default=2, type=int,
                        help='random seed is settled')
    parser.add_argument('--checkpoint', type=str, default='checkpoint')
    parser.add_argument('--log-dir', default='logger', type=str)
    parser.add_argument('--save', default='EXP', type=str,
                        help='save for the tensor log')
    parser.add_argument("--kd-ratio", type=float, default=0.5,
                        help="learning from soft label distribution")
    parser.add_argument(
        '--save-model',
        type=str,
        default='initial-cifar100-mobilenetv2-HSQ-models/model_best.pth.tar')
    parser.add_argument('--pretrained', default=False, action='store_true',
                        help='load pretrained model')
    parser.add_argument('--quan-mode', type=str, default='Conv2dHSQ',
                        help='corresponding for the quantize conv')
    # Help text now reports the real default (layer_wise, not kernel-wise).
    parser.add_argument('--q-mode', choices=q_modes_choice,
                        default='layer_wise',
                        help='Quantization modes: ' +
                        ' | '.join(q_modes_choice) + ' (default: layer_wise)')
    parser.add_argument('--initial-epochs', type=int, default=20,
                        help='initial epochs for specific blocks')
    parser.add_argument('--initial-lr', default=2e-2, type=float,
                        help='learning rate for initial blocks')
    args = parser.parse_args()

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    # Unique run name: net / user tag / quantizer suffix / timestamp.
    args.save = 'cifar100-feature-{}-{}-{}-{}'.format(
        args.net, args.save, args.quan_mode[-3:],
        time.strftime("%Y%m%d-%H%M%S"))
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.save))
    input_tensor = torch.Tensor(1, 3, 32, 32).cuda()

    # Mirror log records to stdout and to <log_dir>/<run>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')

    fh = logging.FileHandler(
        os.path.join('{}/{}/log.txt'.format(args.log_dir, args.save)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.getLogger().setLevel(logging.INFO)

    net = get_network(args)
    if args.pretrained:
        checkpoint = torch.load(
            os.path.join(
                args.log_dir,
                '{}-{}-models/model_best.pth.tar'.format('cifar100',
                                                         args.net)))
        net.load_state_dict(checkpoint['net'])

    # Untouched full-precision copy serves as the distillation teacher.
    teacher_model = copy.deepcopy(net)

    # Replace the student's convolutions with quantized implementations.
    replace_conv_recursively(net, args.quan_mode, args)

    if not args.quan_mode == 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32),
                                                  print_per_layer_stat=False)
        logging.info(
            'the model after quantized flops is {} and its params is {} '.
            format(flops, params))
        writer.add_graph(net, input_tensor)

    logging.info('args = %s', args)

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095,
                           0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883,
                          0.27615047132568404)
    # Data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True)

    cifar100_test_loader = get_test_dataloader(CIFAR100_TRAIN_MEAN,
                                               CIFAR100_TRAIN_STD,
                                               num_workers=args.workers,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # Groups of parameter-name prefixes trained together during the
    # feature-alignment phase; one optimizer/scheduler pair per group.
    stage_blocks = []
    stage_blocks.append(['pre', 'stage1', 'stage2'])
    stage_blocks.append(['stage3', 'stage4'])
    stage_blocks.append(['stage5', 'stage6'])
    stage_blocks.append(['stage7', 'conv1'])

    optimizers, schedulers = [], []
    for block in stage_blocks:
        params = add_weight_decay(net,
                                  weight_decay=args.weight_decay,
                                  skip_keys=['expand_', 'running_scale'],
                                  grads=block)
        optimizer = optim.SGD(params,
                              lr=args.initial_lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        # Linear decay to zero over the feature-alignment epochs.
        scheduler = optim.lr_scheduler.LambdaLR(
            optimizer,
            lambda step: (1.0 - step / args.initial_epochs),
            last_epoch=-1)
        optimizers.append(optimizer)
        schedulers.append(scheduler)

    # BUGFIX: dropped the legacy `reduce=True` flag — it is deprecated,
    # redundant with reduction='mean', and passing both warns in modern
    # PyTorch.  reduction='mean' preserves the original behaviour.
    initial_criterion = nn.MSELoss(reduction='mean')

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    train_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones,
        gamma=args.gamma)  # learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    warmup_scheduler = WarmUpLR(optimizer, iter_per_epoch * args.warm)

    # Use tensorboard (SummaryWriter above normally creates this already).
    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)

    if args.evaluate:
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, 0)
        logging.info('the acc for validate currently is {}'.format(acc))
        return

    best_acc = 0.0
    args.initial_start_epoch, args.start_epoch = 0, 0
    if args.resume:
        if os.path.isfile('{}/{}'.format(args.log_dir, args.save_model)):
            logging.info("=> loading checkpoint '{}/{}'".format(
                args.log_dir, args.save_model))
            checkpoint = torch.load(os.path.join(args.log_dir,
                                                 args.save_model))
            logging.info('load best training file to test acc...')
            net.load_state_dict(checkpoint['net'])
            logging.info('best acc is {:0.2f}'.format(checkpoint['acc']))
            best_acc = checkpoint['acc']
            # An 'initial' checkpoint resumes the feature-alignment phase;
            # otherwise that phase is done and KD fine-tuning resumes.
            if 'initial' in args.save_model:
                args.initial_start_epoch = checkpoint['epoch']
            else:
                args.initial_start_epoch = args.initial_epochs
                args.start_epoch = checkpoint['epoch']
        else:
            logging.info("=> no checkpoint found at '{}/{}'".format(
                args.log_dir, args.save_model))
            raise Exception('No such model saved !')

    # Phase 1: block-wise feature alignment against the teacher.
    for epoch in range(args.initial_start_epoch, args.initial_epochs):
        feature_train(net, teacher_model, cifar100_training_loader, args,
                      optimizers, epoch, writer, initial_criterion)
        acc = eval_training(net,
                            cifar100_test_loader,
                            args,
                            writer,
                            loss_function,
                            epoch,
                            tb=False)
        for scheduler in schedulers:
            scheduler.step()

        is_best = False
        if acc > best_acc:
            best_acc = acc
            logging.info('the best acc is {} in epoch {}'.format(
                best_acc, epoch))
            is_best = True
        save_checkpoint(
            {
                'epoch': epoch,
                'net': net.state_dict(),
                'acc': best_acc,
            },
            is_best,
            save='logger/initial-{}-{}-{}-models'.format(
                'cifar100', args.net, args.quan_mode[-3:]))

    # Phase 2: knowledge-distillation fine-tuning of the whole network.
    for epoch in range(args.start_epoch, args.epochs):
        kd_train(net, teacher_model, cifar100_training_loader, args, optimizer,
                 epoch, writer, warmup_scheduler, loss_function)
        acc = eval_training(net, cifar100_test_loader, args, writer,
                            loss_function, epoch)
        if epoch > args.warm:
            # NOTE(review): step(epoch) is the deprecated epoch-passing form;
            # kept as-is because removing the argument would shift the LR
            # schedule when resuming.
            train_scheduler.step(epoch)

        is_best = False
        if acc > best_acc:
            best_acc = acc
            logging.info('the best acc is {} in epoch {}'.format(
                best_acc, epoch))
            is_best = True
            if args.quan_mode == 'Conv2dDPQ':
                # DPQ complexity changes during training, so re-measure at
                # every new best.
                flops, params = get_model_complexity_info(
                    net, (3, 32, 32), print_per_layer_stat=False)
                logging.info(
                    'the model after quantized flops is {} and its params is {} '
                    .format(flops, params))
        save_checkpoint(
            {
                'epoch': epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            save='logger/feature-{}-{}-{}-models'.format(
                'cifar100', args.net, args.quan_mode[-3:]))

    if args.quan_mode == 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32),
                                                  print_per_layer_stat=False)
        logging.info(
            'the model after quantized flops is {} and its params is {} '.
            format(flops, params))
        writer.add_graph(net, input_tensor)

    logging.info('the final best acc is {}'.format(best_acc))
    writer.close()
Code example #2
0
# net = ShuffleNetV2(1)
# net = EfficientNetB0()
# net = RegNetX_200MF()
# net = SimpleDLA()
# Move the model to the target device; on CUDA, wrap it in DataParallel so
# batches are split across all visible GPUs.
net = net.to(device)
if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True  # autotune conv kernels for fixed input sizes

# Load the full-precision checkpoint saved as ./checkpoint/<arch>-ckpt.pth.
print('==> Resuming from checkpoint..')
assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(args.arch))
net.load_state_dict(checkpoint['net'])

# Swap the network's convolutions for their quantized counterparts.
net = replace_conv_recursively(net, args.quan_mode, args)

if args.resume:
    # Load a previously saved *quantized* checkpoint (named with the last
    # three characters of the quantization mode) instead of starting
    # quantized training from scratch.
    print('==> Resuming from Quantized checkpoint..')
    assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
    checkpoint = torch.load('./checkpoint/{}-{}-ckpt.pth'.format(
        args.arch, args.quan_mode[-3:]))
    net.load_state_dict(checkpoint['net'])
    best_acc = checkpoint['acc']
    start_epoch = checkpoint['epoch']

# Report static model complexity; skipped for Conv2dDPQ, which changes its
# quantization during training.
if args.quan_mode != 'Conv2dDPQ':
    flops, params = get_model_complexity_info(net, (3, 32, 32))
    print('the total flops of mobilenetv2 is : {} and whole params is : {}'.
          format(flops, params))
Code example #3
0
File: ad_quan.py  Project: leliyliu/codes-for-papers
def main():
    """Alternately optimize the 'alpha' and 'xmax' parameter groups of a
    quantized CIFAR-100 network.

    Training runs in cycles of ``args.alter_epoch`` epochs: within each
    cycle the alpha-group optimizer trains first, then the xmax-group
    optimizer, until both groups have run ``args.epochs`` epochs.
    Checkpoints and TensorBoard logs are written under ``args.log_dir``.
    """
    q_modes_choice = sorted(['kernel_wise', 'layer_wise'])
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--net', type=str, default='mobilenetv2', help='net type')
    parser.add_argument('-g', '--gpu', action='store_true', default=False, help='use gpu or not')
    parser.add_argument('-b', '--batch-size', type=int, default=256, help='batch size for dataloader')
    parser.add_argument('--warm', type=int, default=5, help='warm up training phase')
    parser.add_argument('-lr', '--learning-rate', type=float, default=0.1, help='initial learning rate')
    parser.add_argument('--resume', action='store_true', default=False, help='resume training')
    parser.add_argument('-j', '--workers', type=int, default=4, help='the process number')
    # Help text now matches the actual default (it previously said 10).
    parser.add_argument('-p', '--print-freq', default=50, type=int,
                        metavar='N', help='print frequency (default: 50)')
    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')
    parser.add_argument('--epochs', default=50, type=int, help='training epochs')
    parser.add_argument('--alter-epoch', default=10, type=int, help='alternating epoch for evolution')
    parser.add_argument('--milestones', default=[15, 30, 40], nargs='+', type=int,
                        help='milestones of MultiStepLR')
    parser.add_argument('--manual-seed', default=2, type=int, help='random seed is settled')
    parser.add_argument('--checkpoint', type=str, default='checkpoint')
    parser.add_argument('--log-dir', default='logger', type=str)
    parser.add_argument('--save', default='EXP', type=str, help='save for the tensor log')
    parser.add_argument('--save-model', type=str, default='cifar100-mobilenetv2-HSQ-models/model_best.pth.tar')
    parser.add_argument('--pretrained', default=False, action='store_true', help='load pretrained model')
    parser.add_argument('--quan-mode', type=str, default='Conv2dDPQ',
                        help='corresponding for the quantize conv')
    # Help text now reports the real default (layer_wise, not kernel-wise).
    parser.add_argument('--q-mode', choices=q_modes_choice, default='layer_wise',
                        help='Quantization modes: ' + ' | '.join(q_modes_choice) +
                        ' (default: layer_wise)')
    args = parser.parse_args()

    # Seed every RNG source so runs are reproducible.
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)
    random.seed(args.manual_seed)

    # Unique run name: net / user tag / quantizer suffix / timestamp.
    args.save = 'ADmix-cifar100-{}-{}-{}-{}'.format(
        args.net, args.save, args.quan_mode[-3:],
        time.strftime("%Y%m%d-%H%M%S"))
    writer = SummaryWriter(log_dir=os.path.join(args.log_dir, args.save))
    input_tensor = torch.Tensor(1, 3, 32, 32).cuda()

    # Mirror log records to stdout and to <log_dir>/<run>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                        format=log_format, datefmt='%m/%d %I:%M:%S %p')

    fh = logging.FileHandler(
        os.path.join('{}/{}/log.txt'.format(args.log_dir, args.save)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.getLogger().setLevel(logging.INFO)

    net = get_network(args)
    if args.pretrained:
        checkpoint = torch.load(
            os.path.join(
                args.log_dir,
                '{}-{}-models/model_best.pth.tar'.format('cifar100', args.net)))
        net.load_state_dict(checkpoint['net'])

    flops, params = get_model_complexity_info(net, (3, 32, 32),
                                              print_per_layer_stat=False)
    logging.info('the original model {} flops is {} and its params is {} '.format(args.net, flops, params))

    # Replace the network's convolutions with quantized implementations.
    replace_conv_recursively(net, args.quan_mode, args)

    logging.info('args = %s', args)

    CIFAR100_TRAIN_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
    CIFAR100_TRAIN_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)
    # Data preprocessing:
    cifar100_training_loader = get_training_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True)

    cifar100_test_loader = get_test_dataloader(
        CIFAR100_TRAIN_MEAN,
        CIFAR100_TRAIN_STD,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=True)

    loss_function = nn.CrossEntropyLoss()
    # Split parameters into the two alternating groups: quantization scales
    # ('alpha') and clipping thresholds ('xmax').  Parameters that are
    # neither belong to both groups and are trained in both phases.
    alphaparams, xmaxparams = [], []
    for name, param in net.named_parameters():
        if not param.requires_grad:
            continue
        if 'xmax' in name:
            xmaxparams.append(param)
        elif 'alpha' in name:
            alphaparams.append(param)
        else:
            xmaxparams.append(param)
            alphaparams.append(param)

    # Per-group weight decay; the group value takes precedence over the
    # optimizer-level weight_decay argument below.
    alphaparams = [{'params': alphaparams, 'weight_decay': 5e-4}]
    xmaxparams = [{'params': xmaxparams, 'weight_decay': 5e-4}]

    alphaoptimizer = optim.SGD(alphaparams, lr=args.learning_rate, momentum=0.9, weight_decay=5e-4)
    alphascheduler = optim.lr_scheduler.MultiStepLR(alphaoptimizer, milestones=args.milestones, gamma=0.2)  # learning rate decay
    iter_per_epoch = len(cifar100_training_loader)
    alphawarmup_scheduler = WarmUpLR(alphaoptimizer, iter_per_epoch * args.warm)
    xmaxoptimizer = optim.SGD(xmaxparams, args.learning_rate, momentum=0.9, weight_decay=5e-4)
    xmaxscheduler = optim.lr_scheduler.MultiStepLR(xmaxoptimizer, milestones=args.milestones, gamma=0.2)
    xmaxwarmup_scheduler = WarmUpLR(xmaxoptimizer, iter_per_epoch * args.warm)

    # Use tensorboard (SummaryWriter above normally creates this already).
    if not os.path.exists(args.log_dir):
        os.mkdir(args.log_dir)

    # Since tensorboard can't overwrite old values,
    # the only way is to create a new tensorboard log.

    best_acc = 0.0
    args.iter, args.alpha_start_epoch, args.mix_start_epoch = 0, 0, 0
    if args.resume:
        if os.path.isfile('{}/{}'.format(args.log_dir, args.save_model)):
            logging.info("=> loading checkpoint '{}/{}'".format(args.log_dir, args.save_model))
            checkpoint = torch.load(os.path.join(args.log_dir, args.save_model))
            logging.info('load best training file to test acc...')
            net.load_state_dict(checkpoint['net'])
            logging.info('best acc is {:0.2f}'.format(checkpoint['acc']))
            best_acc = checkpoint['acc']
            # Reconstruct how far each alternating phase had progressed from
            # the checkpoint's epoch counter and the phase it was saved in.
            if 'alpha' in args.save_model:
                args.alpha_start_epoch = checkpoint['epoch']
                args.mix_start_epoch = args.alter_epoch * (args.alpha_start_epoch // args.alter_epoch)
            else:
                args.mix_start_epoch = checkpoint['epoch']
                args.alpha_start_epoch = args.alter_epoch * (args.mix_start_epoch // args.alter_epoch + 1)
        else:
            logging.info("=> no checkpoint found at '{}/{}'".format(args.log_dir, args.save_model))
            raise Exception('No such model saved !')

    if args.evaluate:
        acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, 0)
        # BUGFIX: report the accuracy just measured; the old message logged
        # best_acc as "the final best acc", which is 0.0 unless resuming.
        logging.info('the acc for validate currently is {}'.format(acc))
        return

    whole_alter = math.ceil(args.epochs / args.alter_epoch)
    # `cycle` renamed from `iter`, which shadowed the builtin of that name.
    for cycle in range(args.iter, whole_alter):
        logging.info('the iter in {}'.format(cycle))
        # Alpha phase of this cycle.
        while args.alpha_start_epoch < min((cycle + 1) * args.alter_epoch, args.epochs):
            train(net, cifar100_training_loader, args, alphaoptimizer, args.alpha_start_epoch, writer, alphawarmup_scheduler, loss_function)
            acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, args.alpha_start_epoch, tb=False)
            if args.alpha_start_epoch > args.warm:
                alphascheduler.step(args.alpha_start_epoch)
            args.alpha_start_epoch += 1

            is_best = False
            if acc > best_acc:
                best_acc = acc
                logging.info('the best acc is {} in epoch {}'.format(best_acc, args.alpha_start_epoch))
                is_best = True
                # Quantized complexity can change during training, so
                # re-measure at every new best.
                flops, params = get_model_complexity_info(net, (3, 32, 32), print_per_layer_stat=False)
                logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
            save_checkpoint({
                'epoch': args.alpha_start_epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': alphaoptimizer.state_dict(),
            }, is_best, save='logger/alpha-{}-{}-{}-models'.format('cifar100', args.net, args.quan_mode[-3:]))

        # Xmax ("mix") phase of this cycle.
        while args.mix_start_epoch < min((cycle + 1) * args.alter_epoch, args.epochs):
            train(net, cifar100_training_loader, args, xmaxoptimizer, args.mix_start_epoch, writer, xmaxwarmup_scheduler, loss_function)
            acc = eval_training(net, cifar100_test_loader, args, writer, loss_function, args.mix_start_epoch, tb=False)
            if args.mix_start_epoch > args.warm:
                xmaxscheduler.step(args.mix_start_epoch)
            args.mix_start_epoch += 1

            is_best = False
            if acc > best_acc:
                best_acc = acc
                logging.info('the best acc is {} in epoch {}'.format(best_acc, args.mix_start_epoch))
                is_best = True
                flops, params = get_model_complexity_info(net, (3, 32, 32), print_per_layer_stat=False)
                logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
            save_checkpoint({
                'epoch': args.mix_start_epoch,
                'net': net.state_dict(),
                'acc': best_acc,
                'optimizer': xmaxoptimizer.state_dict(),
            }, is_best, save='logger/mix-{}-{}-{}-models'.format('cifar100', args.net, args.quan_mode[-3:]))

    # Final complexity report and graph dump for the trained network.
    flops, params = get_model_complexity_info(net, (3, 32, 32), print_per_layer_stat=False)
    logging.info('the model after quantized flops is {} and its params is {} '.format(flops, params))
    writer.add_graph(net, input_tensor)
    logging.info('the final best acc is {}'.format(best_acc))
    writer.close()
Code example #4
0
def main():
    """Quantize a CIFAR model, align its blocks to a teacher, then
    fine-tune it with knowledge distillation.

    Relies on module-level globals defined elsewhere in the script:
    ``args``, ``device``, ``models``, ``writer``, ``trainloader`` and
    ``testloader``, plus the project helper functions it calls.
    """
    # Model
    logging.info('==> Building model {} ..'.format(args.arch))

    net = models.__dict__[args.arch]()

    # Partition the backbone into stages, starting a new stage at every
    # down-sampling block.  NOTE(review): assumes the project model stores
    # conv2.stride as an int (a stock nn.Conv2d keeps a tuple, for which
    # `== 2` is always False) — TODO confirm against the model definition.
    stage_blocks = []
    block = ['first_']
    for i in range(len(net.layers)):
        if net.layers[i].conv2.stride == 2:
            stage_blocks.append(block)
            block = []
        block.append('layers.{}'.format(i))

    stage_blocks.append(block)
    last_block = ['last_', 'linear']

    net = net.to(device)
    if device == 'cuda':
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = True

    # Load the full-precision checkpoint before quantizing.
    if args.pretrained:
        logging.info('==> Resuming from checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(args.arch))
        net.load_state_dict(checkpoint['net'])

    # Swap convolutions for their quantized counterparts.
    net = replace_conv_recursively(net, args.quan_mode, args)

    criterion = nn.CrossEntropyLoss()

    if args.evaluate:
        validate_loss, validate_top1 = validate(net,
                                                criterion,
                                                data_loader=testloader,
                                                epoch=0,
                                                valid=True)
        return

    # Teacher model for knowledge distillation.
    # NOTE(review): args.teacher_model is only created when kd_ratio > 0,
    # yet the training loops below use it unconditionally — confirm
    # kd_ratio is always positive for this script.
    if args.kd_ratio > 0:
        args.teacher_model = models.__dict__[args.teacher_arch]()
        args.teacher_model = torch.nn.DataParallel(args.teacher_model)
        checkpoint = torch.load('./checkpoint/{}-ckpt.pth'.format(
            args.teacher_arch))
        args.teacher_model.load_state_dict(checkpoint['net'])

    best_acc1 = 0
    args.initial_start_epoch, args.start_epoch = 0, 0
    if args.resume:
        # Load a previously saved quantized checkpoint; an 'initial'
        # checkpoint resumes the block-alignment phase, anything else
        # resumes the fine-tuning phase.
        logging.info('==> Resuming from Quantized checkpoint..')
        assert os.path.isdir(
            'checkpoint'), 'Error: no checkpoint directory found!'
        checkpoint = torch.load('logger/{}'.format(args.save_model))
        net.load_state_dict(checkpoint['net'])
        best_acc1 = checkpoint['acc']
        if 'initial' in args.save_model:
            args.initial_start_epoch = checkpoint['initial_epoch'] + 1
        else:
            args.initial_start_epoch = args.initial_epochs
            args.start_epoch = checkpoint['epoch'] + 1

    if args.quan_mode != 'Conv2dDPQ':
        flops, params = get_model_complexity_info(net, (3, 32, 32))
        logging.info(
            'the total flops of {} is : {} and whole params is : {}'.format(
                args.arch, flops, params))

    # One optimizer/scheduler pair per backbone stage for the initial
    # block-wise alignment phase.
    optimizers, schedulers = [], []
    for block in stage_blocks:
        params = add_weight_decay(net,
                                  weight_decay=args.weight_decay,
                                  skip_keys=['expand_', 'running_scale'],
                                  grads=block)
        optimizer = torch.optim.SGD(params,
                                    args.initial_lr,
                                    momentum=args.momentum)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.initial_epochs)
        optimizers.append(optimizer)
        schedulers.append(scheduler)

    # Separate optimizer for the classifier head during the initial phase.
    params = add_weight_decay(net,
                              weight_decay=args.weight_decay,
                              skip_keys=['expand_', 'running_scale'],
                              grads=last_block)
    final_optimizer = optim.SGD(params,
                                args.initial_lr,
                                momentum=args.momentum)
    final_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        final_optimizer, T_max=args.initial_epochs)

    # Whole-network optimizer for the fine-tuning phase.
    fin_optimizer = optim.SGD(net.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=5e-4)
    fin_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(fin_optimizer,
                                                               T_max=200)

    with open('cifar-feature-quan.txt', 'w+') as f:
        # Phase 1: align each quantized block with the teacher's features
        # while training the head on soft labels.
        for epoch in range(args.initial_start_epoch, args.initial_epochs):
            initial_blocks(net, args.teacher_model, trainloader, optimizers,
                           epoch)
            train_loss, train_top1 = finetune(net, args.teacher_model, args,
                                              trainloader, final_optimizer,
                                              epoch, criterion)
            validate_loss, validate_top1 = validate(net,
                                                    criterion,
                                                    data_loader=testloader,
                                                    epoch=epoch,
                                                    valid=True)
            logging.info(
                'the soft train loss is : {} ; For Train !  the top1 accuracy is : {} ;'
                .format(train_loss, train_top1))
            logging.info(
                'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} ;'
                .format(validate_loss, validate_top1))
            writer.add_scalars('Inital-block-Loss/Training-Validate', {
                'train_soft_loss': train_loss,
                'validate_loss': validate_loss
            }, epoch + 1)
            writer.add_scalars('Inital-block-Top1/Training-Validate', {
                'train_acc1': train_top1,
                'validate_acc1': validate_top1
            }, epoch + 1)
            writer.add_scalars(
                'Learning-Rate-For-Initial', {
                    'basic optimizer':
                    final_optimizer.state_dict()['param_groups'][0]['lr'],
                }, epoch + 1)
            is_best = False
            if validate_top1 > best_acc1:
                best_acc1 = validate_top1
                is_best = True
                logging.info(
                    'the best model top1 is : {} and its epoch is {} !'.format(
                        best_acc1, epoch))
            # BUGFIX: save under the keys the resume path above actually
            # reads ('net' / 'acc'); the old 'state_dict' / 'best_acc1'
            # keys made these checkpoints unresumable by this very script.
            save_checkpoint(
                {
                    'initial_epoch': epoch,
                    'net': net.state_dict(),
                    'acc': best_acc1,
                    'optimizer': final_optimizer.state_dict(),
                    'scheduler': final_scheduler.state_dict(),
                },
                is_best,
                save='logger/{}-{}-{}-initial-models'.format(
                    'cifar', args.arch, args.quan_mode[-3:]))
            for scheduler in schedulers:
                scheduler.step()
            final_scheduler.step()
        logging.info(
            'After initial from feature, we can get the final soft train loss and the top1 accuracy'
        )
        # Phase 2: knowledge-distillation fine-tuning of the whole network.
        for epoch in range(args.start_epoch, args.epochs):
            train_loss, train_top1 = finetune(net, args.teacher_model, args,
                                              trainloader, fin_optimizer,
                                              epoch, criterion)
            validate_loss, validate_top1 = validate(net,
                                                    criterion,
                                                    data_loader=testloader,
                                                    epoch=epoch,
                                                    valid=True)
            logging.info(
                'the train loss is : {} ; For Train !  the top1 accuracy is : {} ;'
                .format(train_loss, train_top1))
            logging.info(
                'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} ;'
                .format(validate_loss, validate_top1))
            writer.add_scalars('Quantization-Loss/Training-Validate', {
                'train_loss': train_loss,
                'validate_loss': validate_loss
            }, epoch + 1)
            writer.add_scalars('Quantization-Top1/Training-Validate', {
                'train_acc1': train_top1,
                'validate_acc1': validate_top1
            }, epoch + 1)
            writer.add_scalars(
                'Learning-Rate-For-Finetune', {
                    'basic optimizer':
                    fin_optimizer.state_dict()['param_groups'][0]['lr'],
                }, epoch + 1)
            is_best = False
            if validate_top1 > best_acc1:
                best_acc1 = validate_top1
                is_best = True
                logging.info(
                    'the best model top1 is : {} and its epoch is {} !'.format(
                        best_acc1, epoch))
            # BUGFIX: same key alignment as above so resume can read 'net',
            # 'acc' and 'epoch'.
            save_checkpoint(
                {
                    'epoch': epoch,
                    'net': net.state_dict(),
                    'acc': best_acc1,
                    'optimizer': fin_optimizer.state_dict(),
                    'scheduler': fin_scheduler.state_dict(),
                },
                is_best,
                save='logger/{}-{}-{}-models'.format('cifar', args.arch,
                                                     args.quan_mode[-3:]))
            fin_scheduler.step()
Code example #5
0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process entry point for (optionally distributed) quantized training.

    Builds the requested model, wraps its convolutions with quantized
    replacements, sets up data loaders/optimizer, then runs the
    train/validate loop, logging metrics to TensorBoard and checkpointing.

    Args:
        gpu: GPU index assigned to this process (or None for CPU/DataParallel).
        ngpus_per_node: number of GPUs on this node; used to derive the global
            rank and to split batch size / workers in single-GPU-per-process mode.
        args: parsed CLI namespace (arch, distributed flags, data paths, hyper-
            parameters, and an attached TensorBoard ``writer``).
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # ---- create model -------------------------------------------------
    if 'efficientnet' in args.arch:  # NEW
        if args.pretrained:
            model = EfficientNet.from_pretrained(args.arch,
                                                 advprop=args.advprop)
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
            model = EfficientNet.from_name(args.arch)
    elif 'mobilenetv3' in args.arch:
        if args.pretrained:
            model = localmodels.__dict__[args.arch]()
            if 'large' in args.arch:
                model.load_state_dict(
                    torch.load('checkpoint/mobilenetv3-large-1cd25616.pth'))
            else:
                model.load_state_dict(
                    torch.load('checkpoint/mobilenetv3-small-55df8e1f.pth'))
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
            # BUG FIX: the original assigned the model *class* (missing
            # call parentheses); instantiate it like the other branches do.
            model = localmodels.__dict__[args.arch]()
    else:
        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    # Swap eligible convolutions for their quantized counterparts.
    model = replace_conv_recursively(model, args.quan_mode, args)

    # ---- device placement / parallelism wrappers ----------------------
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # ---- optionally resume from a checkpoint ---------------------------
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # NOTE(review): loads onto the devices the checkpoint was saved
            # from; consider map_location for cross-device resume — confirm.
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                # (assumes best_acc1 is a tensor with .to() — TODO confirm)
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # ---- Data loading code ---------------------------------------------
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    if args.advprop:
        # AdvProp-pretrained models expect inputs in [-1, 1].
        normalize = transforms.Lambda(lambda img: img * 2.0 - 1.0)
    else:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

    if 'efficientnet' in args.arch:
        image_size = EfficientNet.get_image_size(args.arch)
    else:
        image_size = args.image_size

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    val_transforms = transforms.Compose([
        transforms.Resize(image_size, interpolation=PIL.Image.BICUBIC),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        normalize,
    ])
    print('Using image size', image_size)

    val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
        valdir, val_transforms),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        res = validate(val_loader, model, criterion, args)
        with open('res.txt', 'w') as f:
            print(res, file=f)
        return

    # ---- train / validate loop -----------------------------------------
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle shards deterministically per epoch.
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train_loss, train_acc1, train_acc5 = train(train_loader, model,
                                                   criterion, optimizer, epoch,
                                                   args)

        # evaluate on validation set
        validate_loss, validate_acc1, validate_acc5 = validate(
            val_loader, model, criterion, args)

        # Only the local rank-0 process logs / writes TensorBoard scalars.
        if not (args.distributed and args.local_rank != 0):
            logging.info(
                'the soft train loss is : {} ; For Train !  the top1 accuracy is : {} ; then the top5 accuracy is : {}'
                .format(train_loss, train_acc1, train_acc5))
            logging.info(
                'the validate loss is : {} ; For Validate !  the top1 accuracy is : {} ; then the top5 accuracy is : {}'
                .format(validate_loss, validate_acc1, validate_acc5))
            args.writer.add_scalars('Quantization-Loss/Training-Validate', {
                'train_loss': train_loss,
                'validate_loss': validate_loss
            }, epoch + 1)
            args.writer.add_scalars('Quantization-Top1/Training-Validate', {
                'train_acc1': train_acc1,
                'validate_acc1': validate_acc1
            }, epoch + 1)
            args.writer.add_scalars('Quantization-Top5/Training-Validate', {
                'train_acc5': train_acc5,
                'validate_acc5': validate_acc5
            }, epoch + 1)
            args.writer.add_scalars('Learning-Rate-For-Quan', {
                'basic optimizer':
                optimizer.state_dict()['param_groups'][0]['lr']
            }, epoch + 1)
        is_best = False
        if validate_acc1 > best_acc1:
            best_acc1 = validate_acc1
            is_best = True
            logging.info(
                'the best model top1 is : {} and its epoch is {} !'.format(
                    best_acc1, epoch))
            # remember best acc@1 and save checkpoint
        # Only rank 0 of each node (or the single process) writes checkpoints.
        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                },
                is_best,
                save='logger/{}-{}-models'.format(args.arch,
                                                  args.quan_mode[-3:]))