Example #1
    def __init__(self):
        global args

        if 'cuda' in args.device and torch.cuda.is_available():
            torch.cuda.set_device(args.device_ids[0])
            cudnn.benchmark = True
        else:
            args.device_ids = None

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))

        self.model = models.__dict__[args.arch](pretrained=True)

        # BatchNorm folding
        print("Perform BN folding")
        search_absorbe_bn(self.model)

        self.model.to(args.device)

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.criterion.to(args.device)

        cudnn.benchmark = True

        # Data loading code
        valdir = os.path.join(args.data, '')

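        # Standard ImageNet preprocessing: mean/std normalization with a center
        # crop of 224 (299 for Inception v3, which expects larger inputs).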
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        resize = 256 if args.arch != 'inception_v3' else 299
        crop_size = 224 if args.arch != 'inception_v3' else 299
        tfs = [
            transforms.Resize(resize),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]

        self.val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose(tfs)),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.workers,
            pin_memory=True)
Example #2
    def eval_pnorm(p):
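        # Build a fresh pre-trained model, quantize it with Lp-norm clipping
        # (qtype='lp_norm') for the given p, and return the resulting clipping
        # values, the calibration loss and the validation accuracy.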
        args.qtype = 'lp_norm'
        args.lp = p
        # Fix the seed
        random.seed(args.seed)
        if not args.dont_fix_np_seed:
            np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        inf_model = CnnModel(args.arch,
                             custom_resnet,
                             custom_inception,
                             args.pretrained,
                             args.dataset,
                             args.gpu_ids,
                             args.datapath,
                             batch_size=args.batch_size,
                             shuffle=True,
                             workers=args.workers,
                             print_freq=args.print_freq,
                             cal_batch_size=args.cal_batch_size,
                             cal_set_size=args.cal_set_size,
                             args=args)

        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            # pdb.set_trace()
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(inf_model.model)
        mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)
        loss = inf_model.evaluate_calibration()
        point = mq.get_clipping()

        # evaluate
        acc = inf_model.validate()

        del inf_model
        del mq

        return point, loss, acc
Example #3
def process_model(arch, num_bits, base_dir, task='quantize'):
    model = models.__dict__[arch](pretrained=True)
    search_absorbe_bn(model)

    # Quantize the model with k-means (non-uniform) quantization
    model_qkmeans = copy.deepcopy(model)
    if task == 'quantize':
        quantize_model_parameters(model_qkmeans, num_bits=num_bits)
    elif task == 'clip':
        clip_model_parameters(model_qkmeans, num_bits=num_bits)
    else:
        print("Invalid argument task=%s" % task)
        exit(-1)

    # Save model to home dir
    model_path = os.path.join(base_dir, 'models')
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    model_path = os.path.join(model_path,
                              arch + ('_kmeans%dbit.pt' % num_bits))
    print("Saving quantized model to %s" % model_path)
    torch.save(model_qkmeans, model_path)

    # Per channel bias correction
    model_bcorr = copy.deepcopy(model_qkmeans)
    p_km = list(model_bcorr.named_parameters())
    p_orig = list(model.named_parameters())
    for i in tqdm(range(len(p_km))):
        if not is_ignored(p_km[i][0], p_km[i][1]):
            w_km = p_km[i][1]
            w_orig = p_orig[i][1]
            mean_delta = w_km.view(w_km.shape[0], -1).mean(
                dim=-1) - w_orig.view(w_orig.shape[0], -1).mean(dim=-1)
            p_km[i][1].data = (w_km.view(w_km.shape[0], -1) -
                               mean_delta.view(mean_delta.shape[0], 1)).view(
                                   w_orig.shape)

    model_path = os.path.splitext(model_path)[0] + '_bcorr.pt'
    print("Saving quantized model with bias correction to %s" % model_path)
    torch.save(model_bcorr, model_path)
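
# Illustrative sketch (not part of the original script): the per-channel bias
# correction above shifts each quantized output channel so that its mean matches
# the float channel mean, i.e. w_q <- w_q - (E[w_q] - E[w]) per channel.
# The helper below is hypothetical and simply mirrors the loop in process_model
# for a single weight tensor.
def _bias_correct_channelwise(w_orig, w_q):
    # Per-channel difference between the quantized and original means.
    delta = (w_q.view(w_q.shape[0], -1).mean(dim=-1)
             - w_orig.view(w_orig.shape[0], -1).mean(dim=-1))
    # Remove the shift so each channel's mean matches the original.
    return (w_q.view(w_q.shape[0], -1) - delta.view(-1, 1)).view(w_orig.shape)
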
def main_worker(args, ml_logger):
    global best_acc1

    if args.gpu_ids is not None:
        print("Use GPU: {} for training".format(args.gpu_ids))

    if args.log_stats:
        from utils.stats_trucker import StatsTrucker as ST
        ST("W{}A{}".format(args.bit_weights, args.bit_act))

    if 'resnet' in args.arch and args.custom_resnet:
        model = custom_resnet(arch=args.arch,
                              pretrained=args.pretrained,
                              depth=arch2depth(args.arch),
                              dataset=args.dataset)
    elif 'inception_v3' in args.arch and args.custom_inception:
        model = custom_inception(pretrained=args.pretrained)
    else:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=args.pretrained)

    device = torch.device('cuda:{}'.format(args.gpu_ids[0]))
    cudnn.benchmark = True

    torch.cuda.set_device(args.gpu_ids[0])
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, device)
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU
            # best_acc1 = best_acc1.to(device)
            checkpoint['state_dict'] = {
                normalize_module_name(k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if len(args.gpu_ids) > 1:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features,
                                                   args.gpu_ids)
        else:
            model = torch.nn.DataParallel(model, args.gpu_ids)

    default_transform = {
        'train': get_transform(args.dataset, augment=True),
        'eval': get_transform(args.dataset, augment=False)
    }

    val_data = get_dataset(args.dataset, 'val', default_transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    train_data = get_dataset(args.dataset, 'train', default_transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # TODO: replace this call by initialization on small subset of training data
    # TODO: enable for activations
    # validate(val_loader, model, criterion, args, device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    mq = None
    if args.quantize:
        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        all_convs = [
            n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
        ]
        # all_convs = [l for l in all_convs if 'downsample' not in l]
        all_relu = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU)
        ]
        all_relu6 = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)
        ]
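        # The slicing below skips the first/last activations and the first conv,
        # presumably to keep the network's entry and exit points at full precision.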
        layers = all_relu[1:-1] + all_relu6[1:-1] + all_convs[1:]
        replacement_factory = {
            nn.ReLU: ActivationModuleWrapper,
            nn.ReLU6: ActivationModuleWrapper,
            nn.Conv2d: ParameterModuleWrapper
        }
        mq = ModelQuantizer(
            model, args, layers, replacement_factory,
            OptimizerBridge(optimizer,
                            settings={
                                'algo': 'SGD',
                                'dataset': args.dataset
                            }))

        if args.resume:
            # Load quantization parameters from state dict
            mq.load_state_dict(checkpoint['state_dict'])

        mq.log_quantizer_state(ml_logger, -1)

        if args.model_freeze:
            mq.freeze()

    if args.evaluate:
        if args.log_stats:
            mean = []
            var = []
            skew = []
            kurt = []
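            # Per-layer weight statistics: for each quantized conv weight log its
            # mean, variance and the standardized third/fourth moments
            # (skewness and kurtosis).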
            for n, p in model.named_parameters():
                if n.replace('.weight', '') in all_convs[1:]:
                    mu = p.mean()
                    std = p.std()
                    mean.append((n, mu.item()))
                    var.append((n, (std**2).item()))
                    skew.append((n, torch.mean(((p - mu) / std)**3).item()))
                    kurt.append((n, torch.mean(((p - mu) / std)**4).item()))
            for i in range(len(mean)):
                ml_logger.log_metric(mean[i][0] + '.mean', mean[i][1])
                ml_logger.log_metric(var[i][0] + '.var', var[i][1])
                ml_logger.log_metric(skew[i][0] + '.skewness', skew[i][1])
                ml_logger.log_metric(kurt[i][0] + '.kurtosis', kurt[i][1])

            ml_logger.log_metric('weight_mean', np.mean([s[1] for s in mean]))
            ml_logger.log_metric('weight_var', np.mean([s[1] for s in var]))
            ml_logger.log_metric('weight_skewness',
                                 np.mean([s[1] for s in skew]))
            ml_logger.log_metric('weight_kurtosis',
                                 np.mean([s[1] for s in kurt]))

        acc = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc)
        if args.log_stats:
            stats = ST().get_stats()
            for s in stats:
                ml_logger.log_metric(s, np.mean(stats[s]))
        return

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, args, device)
    ml_logger.log_metric('Val Acc1', acc1, -1)

    # evaluate with k-means quantization
    # if args.model_freeze:
    # with mq.disable():
    #     acc1_nq = validate(val_loader, model, criterion, args, device)
    #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq, -1)

    for epoch in range(0, args.epochs):
        # train for one epoch
        print('Timestamp Start epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))
        train(train_loader, model, criterion, optimizer, epoch, args, device,
              ml_logger, val_loader, mq)
        print('Timestamp End epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))

        if not args.lr_freeze:
            lr_scheduler.step()

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc1, step='auto')

        # evaluate with k-means quantization
        # if args.model_freeze:
        # with mq.quantization_method('kmeans'):
        #     acc1_kmeans = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 kmeans', acc1_kmeans, epoch)

        # with mq.disable():
        #     acc1_nq = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq,  step='auto')

        if args.quantize:
            mq.log_quantizer_state(ml_logger, epoch)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
                if len(args.gpu_ids) == 1 else model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
def main(args, ml_logger):
    # Fix the seed
    random.seed(args.seed)
    if not args.dont_fix_np_seed:
        np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    if args.tag is not None:
        ml_logger.mlflow.log_param('tag', args.tag)

    args.qtype = 'max_static'
    # create model
    # Always enable shuffling to avoid misleadingly bad results caused by weak calibration statistics
    custom_resnet = True
    custom_inception = True
    inf_model = CnnModel(args.arch,
                         custom_resnet,
                         custom_inception,
                         args.pretrained,
                         args.dataset,
                         args.gpu_ids,
                         args.datapath,
                         batch_size=args.batch_size,
                         shuffle=True,
                         workers=args.workers,
                         print_freq=args.print_freq,
                         cal_batch_size=args.cal_batch_size,
                         cal_set_size=args.cal_set_size,
                         args=args)

    layers = []
    # TODO: make it more generic
    if args.bit_weights is not None:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.Conv2d)
        ][1:-1]
    if args.bit_act is not None:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.ReLU)
        ][1:-1]
    if args.bit_act is not None and 'mobilenet' in args.arch:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.ReLU6)
        ][1:-1]

    replacement_factory = {
        nn.ReLU: ActivationModuleWrapperPost,
        nn.ReLU6: ActivationModuleWrapperPost,
        nn.Conv2d: ParameterModuleWrapperPost
    }

    mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)
    loss = inf_model.evaluate_calibration()

    # evaluate
    max_acc = inf_model.validate()
    max_point = mq.get_clipping()
    ml_logger.log_metric('Loss max', loss.item(), step='auto')
    ml_logger.log_metric('Acc max', max_acc, step='auto')
    data = {'max': {'alpha': max_point.cpu().numpy(), 'loss': loss.item()}}
    print("max loss: {:.4f}, max_acc: {:.4f}".format(loss.item(), max_acc))

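    # Loss-aware search for clipping values: evaluate Lp-norm clipping at a few
    # values of p (2, 2.5, 3), interpolate the p that maximizes accuracy, and then
    # refine the resulting clipping scales with a local search on the calibration loss.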
    def eval_pnorm(p):
        args.qtype = 'lp_norm'
        args.lp = p
        # Fix the seed
        random.seed(args.seed)
        if not args.dont_fix_np_seed:
            np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        inf_model = CnnModel(
            args.arch,
            custom_resnet,
            custom_inception,
            args.pretrained,
            args.dataset,
            args.gpu_ids,
            args.datapath,
            batch_size=args.batch_size,
            shuffle=True,
            workers=args.workers,
            print_freq=args.print_freq,
            cal_batch_size=args.cal_batch_size,
            cal_set_size=args.cal_set_size,
            args=args,
        )

        mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)
        loss = inf_model.evaluate_calibration()
        point = mq.get_clipping()

        # evaluate
        acc = inf_model.validate()

        del inf_model
        del mq

        return point, loss, acc

    del inf_model
    del mq

    print("Evaluate L2 norm optimization")
    l2_point, l2_loss, l2_acc = eval_pnorm(2.)
    print("loss l2: {:.4f}".format(l2_loss.item()))
    ml_logger.log_metric('Loss l2', l2_loss.item(), step='auto')
    ml_logger.log_metric('Acc l2', l2_acc, step='auto')
    data['l2'] = {
        'alpha': l2_point.cpu().numpy(),
        'loss': l2_loss.item(),
        'acc': l2_acc
    }

    print("Evaluate L2.5 norm optimization")
    l25_point, l25_loss, l25_acc = eval_pnorm(2.5)
    print("loss l2.5: {:.4f}".format(l25_loss.item()))
    ml_logger.log_metric('Loss l2.5', l25_loss.item(), step='auto')
    ml_logger.log_metric('Acc l2.5', l25_acc, step='auto')
    data['l2.5'] = {
        'alpha': l25_point.cpu().numpy(),
        'loss': l25_loss.item(),
        'acc': l25_acc
    }

    print("Evaluate L3 norm optimization")
    l3_point, l3_loss, l3_acc = eval_pnorm(3.)
    print("loss l3: {:.4f}".format(l3_loss.item()))
    ml_logger.log_metric('Loss l3', l3_loss.item(), step='auto')
    ml_logger.log_metric('Acc l3', l3_acc, step='auto')
    data['l3'] = {
        'alpha': l3_point.cpu().numpy(),
        'loss': l3_loss.item(),
        'acc': l3_acc
    }

    # Interpolate optimal p
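    # Fit a quadratic through the (p, accuracy) points at p = 2, 2.5, 3 and take
    # the p in [1, 5] that maximizes the fitted parabola.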
    xp = np.linspace(1, 5, 50)
    z = np.polyfit([2, 2.5, 3], [l2_acc, l25_acc, l3_acc], 2)
    y = np.poly1d(z)
    p_intr = xp[np.argmax(y(xp))]
    print("p intr: {:.2f}".format(p_intr))
    ml_logger.log_metric('p intr', p_intr, step='auto')

    args.qtype = 'lp_norm'
    args.lp = p_intr
    # Fix the seed
    random.seed(args.seed)
    if not args.dont_fix_np_seed:
        np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    inf_model = CnnModel(args.arch,
                         custom_resnet,
                         custom_inception,
                         args.pretrained,
                         args.dataset,
                         args.gpu_ids,
                         args.datapath,
                         batch_size=args.batch_size,
                         shuffle=True,
                         workers=args.workers,
                         print_freq=args.print_freq,
                         cal_batch_size=args.cal_batch_size,
                         cal_set_size=args.cal_set_size,
                         args=args)

    if args.bn_folding:
        print(
            "Applying batch-norm folding ahead of post-training quantization")
        # pdb.set_trace()
        from utils.absorb_bn import search_absorbe_bn
        search_absorbe_bn(inf_model.model)
    mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)

    # Evaluate with optimal p
    lp_loss = inf_model.evaluate_calibration()
    lp_point = mq.get_clipping()
    # evaluate
    lp_acc = inf_model.validate()

    print("loss p intr: {:.4f}".format(lp_loss.item()))
    ml_logger.log_metric('Loss p intr', lp_loss.item(), step='auto')
    ml_logger.log_metric('Acc p intr', lp_acc, step='auto')

    global _eval_count, _min_loss
    _min_loss = lp_loss.item()

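    # Start the local search from the clipping values of the best-performing
    # initialization evaluated above (L2, L2.5, L3 or the interpolated p).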
    idx = np.argmax([l2_acc, l25_acc, l3_acc, lp_acc])
    init = [l2_point, l25_point, l3_point, lp_point][idx]

    # run optimizer
    min_options = {}
    if args.maxiter is not None:
        min_options['maxiter'] = args.maxiter
    if args.maxfev is not None:
        min_options['maxfev'] = args.maxfev

    _iter = count(0)

    def local_search_callback(x):
        it = next(_iter)
        mq.set_clipping(x, inf_model.device)
        loss = inf_model.evaluate_calibration()
        print("\n[{}]: Local search callback".format(it))
        print("loss: {:.4f}\n".format(loss.item()))
        print(x)
        ml_logger.log_metric('Loss {}'.format(args.min_method),
                             loss.item(),
                             step='auto')

        # evaluate
        acc = inf_model.validate()
        ml_logger.log_metric('Acc {}'.format(args.min_method),
                             acc,
                             step='auto')

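    # Refine the per-layer clipping scales with a derivative-free local search
    # over the calibration loss (opt.minimize, using Powell's method here).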
    args.min_method = "Powell"
    method = coord_descent if args.min_method == 'CD' else args.min_method
    res = opt.minimize(
        lambda scales: evaluate_calibration_clipped(scales, inf_model, mq),
        init.cpu().numpy(),
        method=method,
        options=min_options,
        callback=local_search_callback)

    print(res)

    scales = res.x
    mq.set_clipping(scales, inf_model.device)
    loss = inf_model.evaluate_calibration()
    ml_logger.log_metric('Loss {}'.format(args.min_method),
                         loss.item(),
                         step='auto')

    # evaluate
    acc = inf_model.validate()
    ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto')
    data['powell'] = {'alpha': scales, 'loss': loss.item(), 'acc': acc}

    print("Starting coordinate descent")
    args.min_method = "CD"
    # Perform only one iteration of coordinate descent to avoid divergence
    min_options['maxiter'] = 1
    _iter = count(0)
    global _eval_count
    _eval_count = count(0)
    _min_loss = lp_loss.item()
    mq.set_clipping(init, inf_model.device)
    # Run coordinate descent for comparison
    method = coord_descent
    res = opt.minimize(
        lambda scales: evaluate_calibration_clipped(scales, inf_model, mq),
        init.cpu().numpy(),
        method=method,
        options=min_options,
        callback=local_search_callback)

    print(res)

    scales = res.x
    mq.set_clipping(scales, inf_model.device)
    loss = inf_model.evaluate_calibration()
    ml_logger.log_metric('Loss {}'.format("CD"), loss.item(), step='auto')

    # evaluate
    acc = inf_model.validate()
    ml_logger.log_metric('Acc {}'.format("CD"), acc, step='auto')
    data['cd'] = {'alpha': scales, 'loss': loss.item(), 'acc': acc}

    # save scales
    f_name = "scales_{}_W{}A{}.pkl".format(args.arch, args.bit_weights,
                                           args.bit_act)
    with open(os.path.join(proj_root_dir, 'data', f_name), 'wb') as f:
        pickle.dump(data, f)
    print("Data saved to {}".format(f_name))
def main():
    global args, best_prec1

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if 'cuda' in args.device and torch.cuda.is_available():
        if args.seed is not None:
            torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    model = models.__dict__[args.arch](pretrained=True)
    model.to(args.device)

    if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet':
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features, args.device_ids)
        else:
            model = torch.nn.DataParallel(model, args.device_ids)

    # BatchNorm folding
    if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3':
        print("Perform BN folding")
        search_absorbe_bn(model)
        QM().bn_folding = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    criterion.to(args.device)

    cudnn.benchmark = True

    # Data loading code
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    resize = 256 if args.arch != 'inception_v3' else 299
    crop_size = 224 if args.arch != 'inception_v3' else 299

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.eval_precision:
        elog = EvalLog(['dtype', 'val_prec1', 'val_prec5'])
        print("\nFloat32 no quantization")
        QM().disable()
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)
        elog.log('fp32', val_prec1, val_prec5)
        logging.info('\nValidation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'
                     .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
        print("--------------------------------------------------------------------------")

        for q in [8, 7, 6, 5, 4]:
            args.qtype = 'int{}'.format(q)
            print("\nQuantize to %s" % args.qtype)
            QM().quantize = True
            QM().reload(args, qparams)
            val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)
            elog.log(args.qtype, val_prec1, val_prec5)
            logging.info('\nValidation Loss {val_loss:.4f} \t'
                         'Validation Prec@1 {val_prec1:.3f} \t'
                         'Validation Prec@5 {val_prec5:.3f} \n'
                         .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
            print("--------------------------------------------------------------------------")
        print(elog)
        elog.save('results/precision/%s_%s_clipping.csv' % (args.arch, args.threshold))
    elif args.custom_test:
        log_name = 'results/custom_test/%s_max_mse_%s_cliping_layer_selection.csv' % (args.arch, args.threshold)
        elog = EvalLog(['num_8bit_layers', 'indexes', 'val_prec1', 'val_prec5'], log_name, auto_save=True)
        for i in range(len(max_mse_order_id)+1):
            _8bit_layers = ['conv0_activation'] + max_mse_order_id[0:i]
            print("it: %d, 8 bit layers: %d" % (i, len(_8bit_layers)))
            QM().set_8bit_list(_8bit_layers)
            val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)
            elog.log(i+1, str(_8bit_layers), val_prec1, val_prec5)
        print(elog)
    else:
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)
    def __init__(self, ml_logger=None):
        self.ml_logger = ml_logger
        global args, best_prec1

        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        if 'cuda' in args.device and torch.cuda.is_available():
            if args.seed is not None:
                torch.cuda.manual_seed_all(args.seed)
            torch.cuda.set_device(args.device_ids[0])
            cudnn.benchmark = True
        else:
            args.device_ids = None

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch == 'shufflenet':
            import models.ShuffleNet as shufflenet
            self.model = shufflenet.ShuffleNet(groups=8)
            params = torch.load('ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar')
            self.model = torch.nn.DataParallel(self.model, args.device_ids)
            self.model.load_state_dict(params)
        # elif args.arch == 'mobilenetv2':
        #     from models.MobileNetV2 import MobileNetV2 as mobilenetv2
        #     self.model = mobilenetv2()
        #     params = torch.load('mobilenetv2_Top1_71.806_Top2_90.410.pth.tar')
        #     self.model = torch.nn.DataParallel(self.model, args.device_ids)
        #     self.model.load_state_dict(params)
        # elif args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
        #     self.model = pretrainedmodels.__dict__[args.arch](num_classes=1000, pretrained='imagenet')
        else:
            self.model = models.__dict__[args.arch](pretrained=True)

        set_node_names(self.model)

        # Mark layers before relu for fusing
        if 'resnet' in args.arch:
            resnet_mark_before_relu(self.model)

        # BatchNorm folding
        if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3':
            print("Perform BN folding")
            search_absorbe_bn(self.model)
            QM().bn_folding = True

        # if args.qmodel is not None:
        #     model_q_path = os.path.join(os.path.join(home, 'mxt-sim/models'), args.arch + '_lowp_pcq%dbit%s.pt' % (args.qmodel, ('' if args.no_bias_corr else '_bcorr')))
        #     model_q = torch.load(model_q_path)
        #     qldict = set_node_names(model_q, create_ldict=True)
        #     QM().ql_dict = qldict
        #     model_q.to(args.device)
        #     self.model.load_state_dict(model_q.state_dict())
        #     del model_q

        self.model.to(args.device)
        QM().quantize_model(self.model)

        if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet' and args.arch != 'mobilenetv2':
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                self.model.features = torch.nn.DataParallel(self.model.features, args.device_ids)
            else:
                self.model = torch.nn.DataParallel(self.model, args.device_ids)

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.criterion.to(args.device)

        cudnn.benchmark = True

        # Data loading code
        valdir = os.path.join(args.data, 'val')

        if args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
            dataparallel = args.device_ids is not None and len(args.device_ids) > 1
            tfs = [mutils.TransformImage(self.model.module if dataparallel else self.model)]
        else:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            resize = 256 if args.arch != 'inception_v3' else 299
            crop_size = 224 if args.arch != 'inception_v3' else 299
            tfs = [
                transforms.Resize(resize),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                normalize,
            ]

        self.val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose(tfs)),
            batch_size=args.batch_size,
            shuffle=bool(args.kld_threshold or args.aciq_cal or args.shuffle),
            num_workers=args.workers, pin_memory=True)
Example #8
def main_worker(args, ml_logger):
    global best_acc1
    datatime_str = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    suf_name = "_" + args.experiment

    if args.gpu_ids is not None:
        print("Use GPU: {} for training".format(args.gpu_ids))

    if args.log_stats:
        from utils.stats_trucker import StatsTrucker as ST
        ST("W{}A{}".format(args.bit_weights, args.bit_act))

    if 'resnet' in args.arch and args.custom_resnet:
        # pdb.set_trace()
        model = custom_resnet(arch=args.arch,
                              pretrained=args.pretrained,
                              depth=arch2depth(args.arch),
                              dataset=args.dataset)
    elif 'inception_v3' in args.arch and args.custom_inception:
        model = custom_inception(pretrained=args.pretrained)
    else:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=args.pretrained)

    device = torch.device('cuda:{}'.format(args.gpu_ids[0]))
    cudnn.benchmark = True

    torch.cuda.set_device(args.gpu_ids[0])
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, device)
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU
            # best_acc1 = best_acc1.to(device)
            checkpoint['state_dict'] = {
                normalize_module_name(k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if len(args.gpu_ids) > 1:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features,
                                                   args.gpu_ids)
        else:
            model = torch.nn.DataParallel(model, args.gpu_ids)

    default_transform = {
        'train': get_transform(args.dataset, augment=True),
        'eval': get_transform(args.dataset, augment=False)
    }

    val_data = get_dataset(args.dataset, 'val', default_transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    train_data = get_dataset(args.dataset, 'train', default_transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # TODO: replace this call by initialization on small subset of training data
    # TODO: enable for activations
    # validate(val_loader, model, criterion, args, device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    # pdb.set_trace()
    mq = None
    if args.quantize:
        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        all_convs = [
            n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
        ]
        # all_convs = [l for l in all_convs if 'downsample' not in l]
        all_relu = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU)
        ]
        all_relu6 = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)
        ]
        layers = all_relu[1:-1] + all_relu6[1:-1] + all_convs[1:]
        replacement_factory = {
            nn.ReLU: ActivationModuleWrapper,
            nn.ReLU6: ActivationModuleWrapper,
            nn.Conv2d: ParameterModuleWrapper
        }
        mq = ModelQuantizer(
            model, args, layers, replacement_factory,
            OptimizerBridge(optimizer,
                            settings={
                                'algo': 'SGD',
                                'dataset': args.dataset
                            }))

        if args.resume:
            # Load quantization parameters from state dict
            mq.load_state_dict(checkpoint['state_dict'])

        mq.log_quantizer_state(ml_logger, -1)

        if args.model_freeze:
            mq.freeze()

    # pdb.set_trace()
    if args.evaluate:
        acc = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc)
        return

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, args, device)
    ml_logger.log_metric('Val Acc1', acc1, -1)

    # evaluate with k-means quantization
    # if args.model_freeze:
    # with mq.disable():
    #     acc1_nq = validate(val_loader, model, criterion, args, device)
    #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq, -1)

    # pdb.set_trace()
    # Kurtosis regularization on weight tensors
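    # Collect the conv weight tensors to be hooked with a kurtosis penalty during
    # training; pushing weight kurtosis toward a target flattens heavy tails, which
    # tends to make the tensors easier to quantize.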
    weight_to_hook = {}
    if args.w_kurtosis:
        if args.weight_name[0] == 'all':
            all_convs = [
                n.replace(".wrapped_module", "") + '.weight'
                for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
            ]
            weight_name = all_convs[1:]
            if args.remove_weight_name:
                for rm_name in args.remove_weight_name:
                    weight_name.remove(rm_name)
        else:
            weight_name = args.weight_name
        for name in weight_name:
            # pdb.set_trace()
            curr_param = fine_weight_tensor_by_name(model, name)
            # if not curr_param:
            #     name = 'float_' + name # QAT name
            #     curr_param = fine_weight_tensor_by_name(self.model, name)
            # if curr_param is not None:
            weight_to_hook[name] = curr_param

    for epoch in range(0, args.epochs):
        # train for one epoch
        print('Timestamp Start epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))
        train(train_loader, model, criterion, optimizer, epoch, args, device,
              ml_logger, val_loader, mq, weight_to_hook)
        print('Timestamp End epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))

        if not args.lr_freeze:
            lr_scheduler.step()

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc1, step='auto')

        # evaluate with k-means quantization
        # if args.model_freeze:
        # with mq.quantization_method('kmeans'):
        #     acc1_kmeans = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 kmeans', acc1_kmeans, epoch)

        # with mq.disable():
        #     acc1_nq = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq,  step='auto')

        if args.quantize:
            mq.log_quantizer_state(ml_logger, epoch)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict()
                if len(args.gpu_ids) == 1 else model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            datatime_str=datatime_str,
            suf_name=suf_name)
def main_worker(args, ml_logger):
    global best_acc1

    if args.gpu_ids is not None:
        print("Use GPU: {} for training".format(args.gpu_ids))

    # create model
    if 'resnet' in args.arch and args.custom_resnet:
        model = custom_resnet(arch=args.arch,
                              pretrained=args.pretrained,
                              depth=arch2depth(args.arch),
                              dataset=args.dataset)
    elif 'inception_v3' in args.arch and args.custom_inception:
        model = custom_inception(pretrained=args.pretrained)

    elif args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    device = torch.device('cuda:{}'.format(args.gpu_ids[0]))
    cudnn.benchmark = True

    torch.cuda.set_device(args.gpu_ids[0])
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            # mq = ModelQuantizer(model, args)
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, device)
            args.start_epoch = checkpoint['epoch']
            if 'best_acc1' in checkpoint.keys():
                best_acc1 = checkpoint['best_acc1']
            else:
                best_acc1 = 0

            # best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU
            # best_acc1 = best_acc1.to(device)
            checkpoint['state_dict'] = {
                normalize_module_name(k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if len(args.gpu_ids) > 1:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features,
                                                   args.gpu_ids)
        else:
            model = torch.nn.DataParallel(model, args.gpu_ids)

    val_data = get_dataset(
        args.dataset,
        'val',
        get_transform(args.dataset,
                      augment=False,
                      scale_size=299 if 'inception' in args.arch else None,
                      input_size=299 if 'inception' in args.arch else None),
        datasets_path=args.datapath)
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=args.shuffle,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    if args.quantize:
        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        all_convs = [
            n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
        ]
        all_relu = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU)
        ]
        all_relu6 = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)
        ]
        layers = all_relu[1:-1] + all_relu6[1:-1] + all_convs[1:-1]
        replacement_factory = {
            nn.ReLU: ActivationModuleWrapperPost,
            nn.ReLU6: ActivationModuleWrapperPost,
            nn.Conv2d: ParameterModuleWrapperPost
        }
        mq = ModelQuantizer(model, args, layers, replacement_factory)
        mq.log_quantizer_state(ml_logger, -1)

    acc = validate(val_loader, model, criterion, args, device)
    ml_logger.log_metric('Val Acc1', acc, step='auto')
Example #10
    def load_maybe_calibrate(checkpoint):
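        # Try to load the checkpoint directly; if that fails for a quantized model,
        # either load a previously measured checkpoint or absorb batch-norm into the
        # float weights, run one calibration pass in measure mode to collect the
        # quantization parameters, and save the measured checkpoint for reuse.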
        try:
            model.load_state_dict(checkpoint)
        except BaseException as e:
            if model_config.get('quantize'):
                measure_name = '{}-{}.measure'.format(args.model,
                                                      model_config['depth'])
                measure_path = os.path.join(save_path, measure_name)
                if os.path.exists(measure_path):
                    logging.info("loading checkpoint '%s'", args.resume)
                    checkpoint = torch.load(measure_path)
                    if 'state_dict' in checkpoint:
                        best_prec1 = checkpoint['best_prec1']
                        checkpoint = checkpoint['state_dict']
                        logging.info(
                            f"Measured checkpoint loaded, reference score top1 {best_prec1:.3f}"
                        )
                    model.load_state_dict(checkpoint)
                else:
                    if model_config.get('absorb_bn'):
                        from utils.absorb_bn import search_absorbe_bn
                        logging.info('absorbing batch normalization')
                        model_config.update({
                            'absorb_bn': False,
                            'quantize': False
                        })
                        model_bn = model_builder(**model_config)
                        model_bn.load_state_dict(checkpoint)
                        search_absorbe_bn(model_bn, verbose=True)
                        model_config.update({
                            'absorb_bn': True,
                            'quantize': True
                        })
                        checkpoint = model_bn.state_dict()
                    model.load_state_dict(checkpoint, strict=False)
                    logging.info("set model measure mode")
                    # set_bn_is_train(model,False)
                    set_measure_mode(model, True, logger=logging)
                    logging.info(
                        "calibrating apprentice model to get quant params")
                    model.to(args.device, dtype)
                    with torch.no_grad():
                        losses_avg, top1_avg, top5_avg = forward(
                            val_loader,
                            model,
                            criterion,
                            0,
                            training=False,
                            optimizer=None)
                    logging.info('Measured float results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    set_measure_mode(model, False, logger=logging)
                    # logging.info("test quant model accuracy")
                    # losses_avg, top1_avg, top5_avg = validate(val_loader, model, criterion, 0)
                    # logging.info('Quantized results:\nLoss {loss:.4f}\t'
                    #              'Prec@1 {top1:.3f}\t'
                    #              'Prec@5 {top5:.3f}'.format(loss=losses_avg, top1=top1_avg, top5=top5_avg))

                    save_checkpoint(
                        {
                            'epoch': 0,
                            'model': args.model,
                            'config': args.model_config,
                            'state_dict': model.state_dict(),
                            'best_prec1': top1_avg,
                            'regime': regime
                        },
                        True,
                        path=save_path,
                        save_all=True,
                        filename=measure_name)

            else:
                raise e
Example #11
def main():
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model_builder = models.__dict__[args.model]

    model_config = {
        'input_size': args.input_size,
        'dataset': args.dataset if args.dataset != 'imaginet' else 'imagenet'
    }
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))
    model = model_builder(**model_config)
    model.to(args.device, dtype)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    logging.info("created model with configuration: %s", model_config)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.to(args.device, dtype)
    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    def load_maybe_calibrate(checkpoint):
        try:
            model.load_state_dict(checkpoint)
        except BaseException as e:
            if model_config.get('quantize'):
                measure_name = '{}-{}.measure'.format(args.model,
                                                      model_config['depth'])
                measure_path = os.path.join(save_path, measure_name)
                if os.path.exists(measure_path):
                    logging.info("loading checkpoint '%s'", args.resume)
                    checkpoint = torch.load(measure_path)
                    if 'state_dict' in checkpoint:
                        best_prec1 = checkpoint['best_prec1']
                        checkpoint = checkpoint['state_dict']
                        logging.info(
                            f"Measured checkpoint loaded, reference score top1 {best_prec1:.3f}"
                        )
                    model.load_state_dict(checkpoint)
                else:
                    if model_config.get('absorb_bn'):
                        from utils.absorb_bn import search_absorbe_bn
                        logging.info('absorbing batch normalization')
                        model_config.update({
                            'absorb_bn': False,
                            'quantize': False
                        })
                        model_bn = model_builder(**model_config)
                        model_bn.load_state_dict(checkpoint)
                        search_absorbe_bn(model_bn, verbose=True)
                        model_config.update({
                            'absorb_bn': True,
                            'quantize': True
                        })
                        checkpoint = model_bn.state_dict()
                    model.load_state_dict(checkpoint, strict=False)
                    logging.info("set model measure mode")
                    # set_bn_is_train(model,False)
                    set_measure_mode(model, True, logger=logging)
                    logging.info(
                        "calibrating apprentice model to get quant params")
                    model.to(args.device, dtype)
                    with torch.no_grad():
                        losses_avg, top1_avg, top5_avg = forward(
                            val_loader,
                            model,
                            criterion,
                            0,
                            training=False,
                            optimizer=None)
                    logging.info('Measured float results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    set_measure_mode(model, False, logger=logging)
                    # logging.info("test quant model accuracy")
                    # losses_avg, top1_avg, top5_avg = validate(val_loader, model, criterion, 0)
                    # logging.info('Quantized results:\nLoss {loss:.4f}\t'
                    #              'Prec@1 {top1:.3f}\t'
                    #              'Prec@5 {top5:.3f}'.format(loss=losses_avg, top1=top1_avg, top5=top5_avg))

                    save_checkpoint(
                        {
                            'epoch': 0,
                            'model': args.model,
                            'config': args.model_config,
                            'state_dict': model.state_dict(),
                            'best_prec1': top1_avg,
                            'regime': regime
                        },
                        True,
                        path=save_path,
                        save_all=True,
                        filename=measure_name)

            else:
                raise e

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        # model.load_state_dict(checkpoint['state_dict'])
        # logging.info("loaded checkpoint '%s' (epoch %s)",
        #              args.evaluate, checkpoint['epoch'])
        load_maybe_calibrate(checkpoint)
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            if 'state_dict' in checkpoint:
                if checkpoint['epoch'] > 0:
                    args.start_epoch = checkpoint['epoch'] - 1
                best_prec1 = checkpoint['best_prec1']
                checkpoint = checkpoint['state_dict']

            try:
                model.load_state_dict(checkpoint)
            except BaseException as e:
                if model_config.get('quantize'):
                    if model_config.get('absorb_bn'):
                        from utils.absorb_bn import search_absorbe_bn
                        logging.info('absorbing batch normalization')
                        model_config.update({
                            'absorb_bn': False,
                            'quantize': False
                        })
                        model_bn = model_builder(**model_config)
                        model_bn.load_state_dict(checkpoint)
                        search_absorbe_bn(model_bn, verbose=True)
                        model_config.update({
                            'absorb_bn': True,
                            'quantize': True
                        })
                        checkpoint = model_bn.state_dict()
                    model.load_state_dict(checkpoint, strict=False)
                    model.to(args.device, dtype)
                    logging.info("set model measure mode")
                    # set_bn_is_train(model,False)
                    set_measure_mode(model, True, logger=logging)
                    logging.info(
                        "calibrating apprentice model to get quant params")
                    model.to(args.device, dtype)
                    with torch.no_grad():
                        losses_avg, top1_avg, top5_avg = forward(
                            val_loader,
                            model,
                            criterion,
                            0,
                            training=False,
                            optimizer=None)
                    logging.info('Measured float results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    set_measure_mode(model, False, logger=logging)
                    logging.info("test quant model accuracy")
                    losses_avg, top1_avg, top5_avg = validate(
                        val_loader, model, criterion, 0)
                    logging.info('Quantized results:\nLoss {loss:.4f}\t'
                                 'Prec@1 {top1:.3f}\t'
                                 'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                            top1=top1_avg,
                                                            top5=top5_avg))
                    save_checkpoint(
                        {
                            'epoch': 0,
                            'model': args.model,
                            'config': args.model_config,
                            'state_dict': model.state_dict(),
                            'best_prec1': top1_avg,
                            'regime': regime
                        },
                        True,
                        path=save_path,
                        save_freq=5)
                    #save_checkpoint(model.state_dict(), is_best=True, path=save_path, save_all=True)
                    logging.info(
                        f'overwriting quantization method with {args.q_method}'
                    )
                    set_global_quantization_method(model, args.q_method)
                else:
                    raise e

            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         args.start_epoch)
        else:
            logging.error("no checkpoint found at '%s'", args.resume)
    if args.evaluate:
        if model_config.get('quantize'):
            logging.info(
                f'overwriting quantization method with {args.q_method}')
            set_global_quantization_method(model, args.q_method)
        losses_avg, top1_avg, top5_avg = validate(val_loader, model, criterion,
                                                  0)
        logging.info('Evaluation results:\nLoss {loss:.4f}\t'
                     'Prec@1 {top1:.3f}\t'
                     'Prec@5 {top5:.3f}'.format(loss=losses_avg,
                                                top1=top1_avg,
                                                top5=top5_avg))
        return

    optimizer = OptimRegime(model, regime)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        results.save()
Example #12
def main_ratio(args, ml_logger):
    # Fix the seed
    random.seed(args.seed)
    if not args.dont_fix_np_seed:
        np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    curr_best_acc = 0
    curr_best_scale_point = None

    args.qtype = 'max_static'
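    # 'max_static' quantizes with unclipped (maximum) ranges; this run serves as the baseline for the clipping search below.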
    # create model
    # Always enable shuffling to avoid issues where we get bad results due to weak statistics
    custom_resnet = True
    custom_inception = True
    inf_model = CnnModel(args.arch,
                         custom_resnet,
                         custom_inception,
                         args.pretrained,
                         args.dataset,
                         args.gpu_ids,
                         args.datapath,
                         batch_size=args.batch_size,
                         shuffle=True,
                         workers=args.workers,
                         print_freq=args.print_freq,
                         cal_batch_size=args.cal_batch_size,
                         cal_set_size=args.cal_set_size,
                         args=args)

    # pdb.set_trace()
    if args.bn_folding:
        print(
            "Applying batch-norm folding ahead of post-training quantization")
        # pdb.set_trace()
        from utils.absorb_bn import search_absorbe_bn
        search_absorbe_bn(inf_model.model)
    # pdb.set_trace()

    layers = []
    # TODO: make it more generic
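    # The [1:-1] slices below quantize only intermediate layers, leaving the first and last Conv2d/ReLU of the network untouched.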
    if args.bit_weights is not None:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.Conv2d)
        ][1:-1]
    if args.bit_act is not None:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.ReLU)
        ][1:-1]
    if args.bit_act is not None and 'mobilenet' in args.arch:
        layers += [
            n for n, m in inf_model.model.named_modules()
            if isinstance(m, nn.ReLU6)
        ][1:-1]

    replacement_factory = {
        nn.ReLU: ActivationModuleWrapperPost,
        nn.ReLU6: ActivationModuleWrapperPost,
        nn.Conv2d: ParameterModuleWrapperPost
    }

    mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)
    loss = inf_model.evaluate_calibration()

    # evaluate
    max_acc = inf_model.validate()
    max_point = mq.get_clipping()
    # pdb.set_trace()
    if max_acc > curr_best_acc:
        curr_best_acc = max_acc
        curr_best_scale_point = max_point
    ml_logger.log_metric('Loss max', loss.item(), step='auto')
    ml_logger.log_metric('Acc max', max_acc, step='auto')
    data = {'max': {'alpha': max_point.cpu().numpy(), 'loss': loss.item()}}
    print("max loss: {:.4f}, max_acc: {:.4f}".format(loss.item(), max_acc))

    def eval_pnorm(p):
        args.qtype = 'lp_norm'
        args.lp = p
        # Fix the seed
        random.seed(args.seed)
        if not args.dont_fix_np_seed:
            np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        inf_model = CnnModel(args.arch,
                             custom_resnet,
                             custom_inception,
                             args.pretrained,
                             args.dataset,
                             args.gpu_ids,
                             args.datapath,
                             batch_size=args.batch_size,
                             shuffle=True,
                             workers=args.workers,
                             print_freq=args.print_freq,
                             cal_batch_size=args.cal_batch_size,
                             cal_set_size=args.cal_set_size,
                             args=args)

        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            # pdb.set_trace()
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(inf_model.model)
        mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)
        loss = inf_model.evaluate_calibration()
        point = mq.get_clipping()

        # evaluate
        acc = inf_model.validate()

        del inf_model
        del mq

        return point, loss, acc

    del inf_model
    del mq

    l2_point, l2_loss, l2_acc = eval_pnorm(2.)
    print("loss l2: {:.4f}".format(l2_loss.item()))
    ml_logger.log_metric('Loss l2', l2_loss.item(), step='auto')
    ml_logger.log_metric('Acc l2', l2_acc, step='auto')
    data['l2'] = {
        'alpha': l2_point.cpu().numpy(),
        'loss': l2_loss.item(),
        'acc': l2_acc
    }
    if l2_acc > curr_best_acc:
        curr_best_acc = l2_acc
        curr_best_scale_point = l2_point

    l25_point, l25_loss, l25_acc = eval_pnorm(2.5)
    print("loss l2.5: {:.4f}".format(l25_loss.item()))
    ml_logger.log_metric('Loss l2.5', l25_loss.item(), step='auto')
    ml_logger.log_metric('Acc l2.5', l25_acc, step='auto')
    data['l2.5'] = {
        'alpha': l25_point.cpu().numpy(),
        'loss': l25_loss.item(),
        'acc': l25_acc
    }
    if l25_acc > curr_best_acc:
        curr_best_acc = l25_acc
        curr_best_scale_point = l25_point

    l3_point, l3_loss, l3_acc = eval_pnorm(3.)
    print("loss l3: {:.4f}".format(l3_loss.item()))
    ml_logger.log_metric('Loss l3', l3_loss.item(), step='auto')
    ml_logger.log_metric('Acc l3', l3_acc, step='auto')
    data['l3'] = {
        'alpha': l3_point.cpu().numpy(),
        'loss': l3_loss.item(),
        'acc': l3_acc
    }
    if l3_acc > curr_best_acc:
        curr_best_acc = l3_acc
        curr_best_scale_point = l3_point

    # Interpolate optimal p
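    # Fit a quadratic through the three (p, accuracy) samples and pick the p that maximizes the fitted curve on a dense grid over [1, 5].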
    xp = np.linspace(1, 5, 50)
    z = np.polyfit([2, 2.5, 3], [l2_acc, l25_acc, l3_acc], 2)
    y = np.poly1d(z)
    p_intr = xp[np.argmax(y(xp))]
    print("p intr: {:.2f}".format(p_intr))
    ml_logger.log_metric('p intr', p_intr, step='auto')

    args.qtype = 'lp_norm'
    args.lp = p_intr
    # Fix the seed
    random.seed(args.seed)
    if not args.dont_fix_np_seed:
        np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    inf_model = CnnModel(args.arch,
                         custom_resnet,
                         custom_inception,
                         args.pretrained,
                         args.dataset,
                         args.gpu_ids,
                         args.datapath,
                         batch_size=args.batch_size,
                         shuffle=True,
                         workers=args.workers,
                         print_freq=args.print_freq,
                         cal_batch_size=args.cal_batch_size,
                         cal_set_size=args.cal_set_size,
                         args=args)

    if args.bn_folding:
        print(
            "Applying batch-norm folding ahead of post-training quantization")
        # pdb.set_trace()
        from utils.absorb_bn import search_absorbe_bn
        search_absorbe_bn(inf_model.model)
    mq = ModelQuantizer(inf_model.model, args, layers, replacement_factory)

    # Evaluate with optimal p
    lp_loss = inf_model.evaluate_calibration()
    lp_point = mq.get_clipping()
    # evaluate
    lp_acc = inf_model.validate()

    print("loss p intr: {:.4f}".format(lp_loss.item()))
    ml_logger.log_metric('Loss p intr', lp_loss.item(), step='auto')
    ml_logger.log_metric('Acc p intr', lp_acc, step='auto')
    if lp_acc > curr_best_acc:
        curr_best_acc = lp_acc
        curr_best_scale_point = lp_point

    global _eval_count, _min_loss
    _min_loss = lp_loss.item()

    idx = np.argmax([l2_acc, l25_acc, l3_acc, lp_acc])
    init = [l2_point, l25_point, l3_point, lp_point][idx]
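    # 'init' is the clipping point with the highest accuracy so far; it seeds the Powell and coordinate-descent searches below.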

    # run optimizer
    min_options = {}
    if args.maxiter is not None:
        min_options['maxiter'] = args.maxiter
    if args.maxfev is not None:
        min_options['maxfev'] = args.maxfev

    _iter = count(0)

    def local_search_callback(x):
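        # Callback passed to opt.minimize: apply the candidate clipping values, then log the calibration loss and validation accuracy of this iterate.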
        it = next(_iter)
        mq.set_clipping(x, inf_model.device)
        loss = inf_model.evaluate_calibration()
        print("\n[{}]: Local search callback".format(it))
        print("loss: {:.4f}\n".format(loss.item()))
        print(x)
        ml_logger.log_metric('Loss {}'.format(args.min_method),
                             loss.item(),
                             step='auto')

        # evaluate
        acc = inf_model.validate()
        ml_logger.log_metric('Acc {}'.format(args.min_method),
                             acc,
                             step='auto')

    args.min_method = "Powell"
    method = coord_descent if args.min_method == 'CD' else args.min_method
    res = opt.minimize(
        lambda scales: evaluate_calibration_clipped(scales, inf_model, mq),
        init.cpu().numpy(),
        method=method,
        options=min_options,
        callback=local_search_callback)

    print(res)

    scales = res.x
    mq.set_clipping(scales, inf_model.device)
    loss = inf_model.evaluate_calibration()
    ml_logger.log_metric('Loss {}'.format(args.min_method),
                         loss.item(),
                         step='auto')

    # evaluate
    acc = inf_model.validate()
    ml_logger.log_metric('Acc {}'.format(args.min_method), acc, step='auto')
    data['powell'] = {'alpha': scales, 'loss': loss.item(), 'acc': acc}
    if acc > curr_best_acc:
        curr_best_acc = acc
        curr_best_scale_point = scales

    print("Starting coordinate descent")
    args.min_method = "CD"
    _iter = count(0)
    _eval_count = count(0)
    _min_loss = lp_loss.item()
    mq.set_clipping(init, inf_model.device)
    # Run coordinate descent for comparison
    method = coord_descent
    res = opt.minimize(
        lambda scales: evaluate_calibration_clipped(scales, inf_model, mq),
        init.cpu().numpy(),
        method=method,
        options=min_options,
        callback=local_search_callback)

    print(res)

    scales = res.x
    mq.set_clipping(scales, inf_model.device)
    loss = inf_model.evaluate_calibration()
    ml_logger.log_metric('Loss {}'.format("CD"), loss.item(), step='auto')

    # evaluate
    acc = inf_model.validate()
    ml_logger.log_metric('Acc {}'.format("CD"), acc, step='auto')
    data['cd'] = {'alpha': scales, 'loss': loss.item(), 'acc': acc}
    if acc > curr_best_acc:
        curr_best_acc = acc
        curr_best_scale_point = scales

    # pdb.set_trace()
    if torch.is_tensor(curr_best_scale_point) and curr_best_scale_point.is_cuda:
        curr_best_scale_point = curr_best_scale_point.cpu()
    best_point = np.concatenate(
        [np.asarray(curr_best_scale_point),
         np.array([curr_best_acc])])
    print("**** START LOSS GENERATION ****")
    print("best point:" + str(best_point))
    best_point_values = best_point[:-1]
    mq.set_clipping(best_point_values, inf_model.device)
    loss = inf_model.evaluate_calibration()
    # evaluate
    top1 = inf_model.validate()
    print("best point: loss, top1: {:.4f}, {}".format(loss.item(), top1))

    # best_point = curr_best_scale_point
    # best_point = mq.get_clipping()
    # best_point_values = curr_best_scale_point[:-1]
    # pdb.set_trace()
    n = args.grid_resolution

    min_ratio = args.min_ratio  # 0.8
    max_ratio = args.max_ratio  # 1.2

    x = np.linspace(min_ratio, max_ratio, n)
    # y = np.linspace(min_ratio, max_ratio, n)

    loss_best = loss
    # X, Y = np.meshgrid(x, y)
    Z_loss = np.empty(n)
    Z_top1 = np.empty(n)
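    # Sweep a single multiplicative ratio around the best clipping point and record the calibration loss and top-1 accuracy at each ratio.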
    for i, x_ in enumerate(tqdm(x)):
        # set clip value to qwrappers
        scales_ratio = x_
        mq.set_clipping((best_point_values * scales_ratio), inf_model.device)

        if np.isclose(scales_ratio, 1.0):
            print(best_point_values * scales_ratio)
        # evaluate with clipping
        loss = inf_model.evaluate_calibration()
        Z_loss[i] = loss.item()
        Z_top1[i] = inf_model.validate()

        str1 = "[x, loss, top1] = [{}, {}, {}]".format(x[i], Z_loss[i],
                                                       Z_top1[i])
        print(str1)

    # pdb.set_trace()
    # best_point = np.concatenate([1.0, loss_best.cpu().numpy()])
    best_point_ratio = [1.0, loss_best.cpu().numpy()]
    print("best_point_ratio: " + str(best_point_ratio))
    # best_point = [best_point_values, loss_best.cpu().numpy()]
    # print("best point: " + str(best_point))
    print("best point values: " + str(best_point_values))

    f_name = "loss_generation_lapq_{}_W{}A{}.pkl".format(
        args.arch, 'ALL', None)
    dir_fullname = os.path.join(os.getcwd(), args.experiment)
    if not os.path.exists(dir_fullname):
        os.makedirs(dir_fullname)
    f = open(os.path.join(dir_fullname, f_name), 'wb')
    data = {
        'X': x,
        'Z_loss': Z_loss,
        'Z_top1': Z_top1,
        'best_point_ratio': best_point_ratio,
        'best_point': best_point_values
    }
    pickle.dump(data, f)
    f.close()
    print("Data saved to {}".format(f_name))
Example #13
    def __init__(self, arch, use_custom_resnet, use_custom_inception,
                 pretrained, dataset, gpu_ids, datapath, batch_size, shuffle,
                 workers, print_freq, cal_batch_size, cal_set_size, args):
        self.arch = arch
        self.use_custom_resnet = use_custom_resnet
        self.pretrained = pretrained
        self.dataset = dataset
        self.gpu_ids = gpu_ids
        self.datapath = datapath
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.workers = workers
        self.print_freq = print_freq
        self.cal_batch_size = cal_batch_size
        self.cal_set_size = cal_set_size  # TODO: pass it as cmd line argument

        # create model
        if 'resnet' in arch and use_custom_resnet:
            model = custom_resnet(arch=arch,
                                  pretrained=pretrained,
                                  depth=arch2depth(arch),
                                  dataset=dataset)
        elif 'inception_v3' in arch and use_custom_inception:
            model = custom_inception(pretrained=pretrained)
        else:
            print("=> using pre-trained model '{}'".format(arch))
            model = models.__dict__[arch](pretrained=pretrained)

        self.device = torch.device('cuda:{}'.format(gpu_ids[0]))

        torch.cuda.set_device(gpu_ids[0])
        model = model.to(self.device)

        # optionally resume from a checkpoint
        if args.resume:
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, self.device)
                args.start_epoch = checkpoint['epoch']
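                # normalize_module_name is assumed to strip wrapper prefixes such as DataParallel's 'module.' so the keys match this (unwrapped) model.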
                checkpoint['state_dict'] = {
                    normalize_module_name(k): v
                    for k, v in checkpoint['state_dict'].items()
                }
                model.load_state_dict(checkpoint['state_dict'], strict=False)
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))

        if len(gpu_ids) > 1:
            # DataParallel will divide and allocate batch_size to all available GPUs
            if arch.startswith('alexnet') or arch.startswith('vgg'):
                model.features = torch.nn.DataParallel(model.features, gpu_ids)
            else:
                model = torch.nn.DataParallel(model, gpu_ids)

        self.model = model

        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        # define loss function (criterion) and optimizer
        self.criterion = torch.nn.CrossEntropyLoss().to(self.device)

        val_data = get_dataset(
            dataset,
            'val',
            get_transform(dataset,
                          augment=False,
                          scale_size=299 if 'inception' in arch else None,
                          input_size=299 if 'inception' in arch else None),
            datasets_path=datapath)
        self.val_loader = torch.utils.data.DataLoader(val_data,
                                                      batch_size=batch_size,
                                                      shuffle=shuffle,
                                                      num_workers=workers,
                                                      pin_memory=True)

        self.cal_loader = torch.utils.data.DataLoader(
            val_data,
            batch_size=self.cal_batch_size,
            shuffle=shuffle,
            num_workers=workers,
            pin_memory=True)
Example #14
def main_worker(args):
    global best_prec1, dtype
    best_prec1 = 0
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)

    args.distributed = args.local_rank >= 0 or args.world_size > 1
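    # i.e. distributed when launched with a non-negative local_rank or a world_size greater than one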

    if not os.path.exists(save_path) and not (args.distributed
                                              and args.local_rank > 0):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '',
                  dummy=args.distributed and args.local_rank > 0)

    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    if not os.path.isfile(args.evaluate):
        parser.error('invalid checkpoint: {}'.format(args.evaluate))
    checkpoint = torch.load(args.evaluate, map_location="cpu")
    # Override configuration with checkpoint info
    args.model = checkpoint.get('model', args.model)
    args.model_config = checkpoint.get('config', args.model_config)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)
    logging.info("creating model %s", args.model)

    # create model
    model = models.__dict__[args.model]
    model_config = {'dataset': args.dataset}

    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)
    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # load checkpoint
    model.load_state_dict(checkpoint['state_dict'])
    logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                 checkpoint['epoch'])

    if args.absorb_bn:
        search_absorbe_bn(model, verbose=True)

    # define loss function (criterion) and optimizer
    loss_params = {}
    if args.label_smoothing > 0:
        loss_params['smooth_eps'] = args.label_smoothing
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)(**loss_params)
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    # Batch-norm should always be done in float
    if 'half' in args.dtype:
        FilterModules(model, module=is_bn).to(dtype=torch.float)

    trainer = Trainer(model,
                      criterion,
                      device_ids=args.device_ids,
                      device=args.device,
                      dtype=dtype,
                      mixup=args.mixup,
                      print_freq=args.print_freq)

    # Evaluation Data loading code
    val_data = DataRegime(getattr(model, 'data_eval_regime', None),
                          defaults={
                              'datasets_path': args.datasets_dir,
                              'name': args.dataset,
                              'split': 'val',
                              'augment': args.augment,
                              'input_size': args.input_size,
                              'batch_size': args.batch_size,
                              'shuffle': False,
                              'duplicates': args.duplicates,
                              'autoaugment': args.autoaugment,
                              'cutout': {
                                  'holes': 1,
                                  'length': 16
                              } if args.cutout else None,
                              'num_workers': args.workers,
                              'pin_memory': True,
                              'drop_last': False
                          })

    results = trainer.validate(val_data.get_loader(),
                               duplicates=val_data.get('duplicates'),
                               average_output=args.avg_out)
    logging.info(results)
    return results