def main_worker(args, ml_logger):
    global best_acc1

    if args.gpu_ids is not None:
        print("Use GPU: {} for training".format(args.gpu_ids))

    if args.log_stats:
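        # construct the stats tracker, tagged with the weight/activation bit widths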
        from utils.stats_trucker import StatsTrucker as ST
        ST("W{}A{}".format(args.bit_weights, args.bit_act))

    if 'resnet' in args.arch and args.custom_resnet:
        model = custom_resnet(arch=args.arch,
                              pretrained=args.pretrained,
                              depth=arch2depth(args.arch),
                              dataset=args.dataset)
    elif 'inception_v3' in args.arch and args.custom_inception:
        model = custom_inception(pretrained=args.pretrained)
    else:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=args.pretrained)

    device = torch.device('cuda:{}'.format(args.gpu_ids[0]))
    cudnn.benchmark = True

    torch.cuda.set_device(args.gpu_ids[0])
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU
            # best_acc1 = best_acc1.to(device)
            checkpoint['state_dict'] = {
                normalize_module_name(k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if len(args.gpu_ids) > 1:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features,
                                                   args.gpu_ids)
        else:
            model = torch.nn.DataParallel(model, args.gpu_ids)

    default_transform = {
        'train': get_transform(args.dataset, augment=True),
        'eval': get_transform(args.dataset, augment=False)
    }

    val_data = get_dataset(args.dataset, 'val', default_transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    train_data = get_dataset(args.dataset, 'train', default_transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # TODO: replace this call with initialization on a small subset of the training data
    # TODO: enable for activations
    # validate(val_loader, model, criterion, args, device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
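    # decay the learning rate by a factor of 10 (gamma=0.1) every args.lr_step epochs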
    lr_scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    mq = None
    if args.quantize:
        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        all_convs = [
            n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
        ]
        # all_convs = [l for l in all_convs if 'downsample' not in l]
        all_relu = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU)
        ]
        all_relu6 = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)
        ]
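        # skip the first conv and the first/last activations so they remain
        # unquantized (a common choice when quantizing pretrained models)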
        layers = all_relu[1:-1] + all_relu6[1:-1] + all_convs[1:]
        replacement_factory = {
            nn.ReLU: ActivationModuleWrapper,
            nn.ReLU6: ActivationModuleWrapper,
            nn.Conv2d: ParameterModuleWrapper
        }
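        # wrap the selected layers with their quantized counterparts; the
        # OptimizerBridge presumably lets the quantizer register its own
        # parameters with the existing SGD optimizer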
        mq = ModelQuantizer(
            model, args, layers, replacement_factory,
            OptimizerBridge(optimizer,
                            settings={
                                'algo': 'SGD',
                                'dataset': args.dataset
                            }))

        if args.resume:
            # Load quantization parameters from state dict
            mq.load_state_dict(checkpoint['state_dict'])

        mq.log_quantizer_state(ml_logger, -1)

        if args.model_freeze:
            mq.freeze()

    if args.evaluate:
        if args.log_stats:
            mean = []
            var = []
            skew = []
            kurt = []
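            # per-layer standardized moments of the conv weights: skewness is the
            # 3rd standardized moment, kurtosis the 4th (note: all_convs is only
            # populated above when args.quantize is set)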
            for n, p in model.named_parameters():
                if n.replace('.weight', '') in all_convs[1:]:
                    mu = p.mean()
                    std = p.std()
                    mean.append((n, mu.item()))
                    var.append((n, (std**2).item()))
                    skew.append((n, torch.mean(((p - mu) / std)**3).item()))
                    kurt.append((n, torch.mean(((p - mu) / std)**4).item()))
            for i in range(len(mean)):
                ml_logger.log_metric(mean[i][0] + '.mean', mean[i][1])
                ml_logger.log_metric(var[i][0] + '.var', var[i][1])
                ml_logger.log_metric(skew[i][0] + '.skewness', skew[i][1])
                ml_logger.log_metric(kurt[i][0] + '.kurtosis', kurt[i][1])

            ml_logger.log_metric('weight_mean', np.mean([s[1] for s in mean]))
            ml_logger.log_metric('weight_var', np.mean([s[1] for s in var]))
            ml_logger.log_metric('weight_skewness',
                                 np.mean([s[1] for s in skew]))
            ml_logger.log_metric('weight_kurtosis',
                                 np.mean([s[1] for s in kurt]))

        acc = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc)
        if args.log_stats:
            stats = ST().get_stats()
            for s in stats:
                ml_logger.log_metric(s, np.mean(stats[s]))
        return

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, args, device)
    ml_logger.log_metric('Val Acc1', acc1, -1)

    # evaluate with k-means quantization
    # if args.model_freeze:
    # with mq.disable():
    #     acc1_nq = validate(val_loader, model, criterion, args, device)
    #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq, -1)

    for epoch in range(0, args.epochs):
        # train for one epoch
        print('Timestamp Start epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))
        train(train_loader, model, criterion, optimizer, epoch, args, device,
              ml_logger, val_loader, mq)
        print('Timestamp End epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))

        if not args.lr_freeze:
            lr_scheduler.step()

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc1, step='auto')

        # evaluate with k-means quantization
        # if args.model_freeze:
        # with mq.quantization_method('kmeans'):
        #     acc1_kmeans = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 kmeans', acc1_kmeans, epoch)

        # with mq.disable():
        #     acc1_nq = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq,  step='auto')

        if args.quantize:
            mq.log_quantizer_state(ml_logger, epoch)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict() if len(args.gpu_ids) == 1
                else model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)
Example #2
def main_worker(args, ml_logger):
    global best_acc1
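    # timestamp and experiment suffix used below to tag the saved checkpoints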
    datatime_str = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    suf_name = "_" + args.experiment

    if args.gpu_ids is not None:
        print("Use GPU: {} for training".format(args.gpu_ids))

    if args.log_stats:
        from utils.stats_trucker import StatsTrucker as ST
        ST("W{}A{}".format(args.bit_weights, args.bit_act))

    if 'resnet' in args.arch and args.custom_resnet:
        model = custom_resnet(arch=args.arch,
                              pretrained=args.pretrained,
                              depth=arch2depth(args.arch),
                              dataset=args.dataset)
    elif 'inception_v3' in args.arch and args.custom_inception:
        model = custom_inception(pretrained=args.pretrained)
    else:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=args.pretrained)

    device = torch.device('cuda:{}'.format(args.gpu_ids[0]))
    cudnn.benchmark = True

    torch.cuda.set_device(args.gpu_ids[0])
    model = model.to(device)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            # best_acc1 may be from a checkpoint from a different GPU
            # best_acc1 = best_acc1.to(device)
            checkpoint['state_dict'] = {
                normalize_module_name(k): v
                for k, v in checkpoint['state_dict'].items()
            }
            model.load_state_dict(checkpoint['state_dict'], strict=False)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if len(args.gpu_ids) > 1:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features,
                                                   args.gpu_ids)
        else:
            model = torch.nn.DataParallel(model, args.gpu_ids)

    default_transform = {
        'train': get_transform(args.dataset, augment=True),
        'eval': get_transform(args.dataset, augment=False)
    }

    val_data = get_dataset(args.dataset, 'val', default_transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    train_data = get_dataset(args.dataset, 'train', default_transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # TODO: replace this call with initialization on a small subset of the training data
    # TODO: enable for activations
    # validate(val_loader, model, criterion, args, device)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = StepLR(optimizer, step_size=args.lr_step, gamma=0.1)

    mq = None
    if args.quantize:
        if args.bn_folding:
            print(
                "Applying batch-norm folding ahead of post-training quantization"
            )
            from utils.absorb_bn import search_absorbe_bn
            search_absorbe_bn(model)

        all_convs = [
            n for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
        ]
        # all_convs = [l for l in all_convs if 'downsample' not in l]
        all_relu = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU)
        ]
        all_relu6 = [
            n for n, m in model.named_modules() if isinstance(m, nn.ReLU6)
        ]
        layers = all_relu[1:-1] + all_relu6[1:-1] + all_convs[1:]
        replacement_factory = {
            nn.ReLU: ActivationModuleWrapper,
            nn.ReLU6: ActivationModuleWrapper,
            nn.Conv2d: ParameterModuleWrapper
        }
        mq = ModelQuantizer(
            model, args, layers, replacement_factory,
            OptimizerBridge(optimizer,
                            settings={
                                'algo': 'SGD',
                                'dataset': args.dataset
                            }))

        if args.resume:
            # Load quantization parameters from state dict
            mq.load_state_dict(checkpoint['state_dict'])

        mq.log_quantizer_state(ml_logger, -1)

        if args.model_freeze:
            mq.freeze()

    if args.evaluate:
        acc = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc)
        return

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion, args, device)
    ml_logger.log_metric('Val Acc1', acc1, -1)

    # evaluate with k-means quantization
    # if args.model_freeze:
    # with mq.disable():
    #     acc1_nq = validate(val_loader, model, criterion, args, device)
    #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq, -1)

    # Kurtosis regularization on weight tensors
    weight_to_hook = {}
    if args.w_kurtosis:
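        # 'all' selects every conv weight except the first layer's; otherwise an
        # explicit list of parameter names is taken from args.weight_name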
        if args.weight_name[0] == 'all':
            all_convs = [
                n.replace(".wrapped_module", "") + '.weight'
                for n, m in model.named_modules() if isinstance(m, nn.Conv2d)
            ]
            weight_name = all_convs[1:]
            if args.remove_weight_name:
                for rm_name in args.remove_weight_name:
                    weight_name.remove(rm_name)
        else:
            weight_name = args.weight_name
        for name in weight_name:
            curr_param = fine_weight_tensor_by_name(model, name)
            # if not curr_param:
            #     name = 'float_' + name # QAT name
            #     curr_param = fine_weight_tensor_by_name(self.model, name)
            # if curr_param is not None:
            weight_to_hook[name] = curr_param
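    # weight_to_hook is passed to train() below; a minimal sketch of the intended
    # penalty (assumed form; the actual term lives in train()) that drives each
    # hooked tensor's kurtosis toward a target k_t:
    #
    #   k(w)   = torch.mean(((w - w.mean()) / w.std()) ** 4)
    #   loss  += lambda_k * (k(w) - k_t) ** 2, summed over the hooked tensors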

    for epoch in range(0, args.epochs):
        # train for one epoch
        print('Timestamp Start epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))
        train(train_loader, model, criterion, optimizer, epoch, args, device,
              ml_logger, val_loader, mq, weight_to_hook)
        print('Timestamp End epoch: {:%Y-%m-%d %H:%M:%S}'.format(
            datetime.datetime.now()))

        if not args.lr_freeze:
            lr_scheduler.step()

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args, device)
        ml_logger.log_metric('Val Acc1', acc1, step='auto')

        # evaluate with k-means quantization
        # if args.model_freeze:
        # with mq.quantization_method('kmeans'):
        #     acc1_kmeans = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 kmeans', acc1_kmeans, epoch)

        # with mq.disable():
        #     acc1_nq = validate(val_loader, model, criterion, args, device)
        #     ml_logger.log_metric('Val Acc1 fp32', acc1_nq,  step='auto')

        if args.quantize:
            mq.log_quantizer_state(ml_logger, epoch)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict() if len(args.gpu_ids) == 1
                else model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            datatime_str=datatime_str,
            suf_name=suf_name)