Example #1
def test():
    net = MobileNet(amc=True)
    from compute_flops import print_model_param_nums, print_model_param_flops
    #x = torch.randn(1,3,224,224)
    #y = net(x)
    #print(y.size())
    print_model_param_nums(net)
    print_model_param_flops(net)
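
Every example on this page calls print_model_param_nums / print_model_param_flops from compute_flops, always in the pattern print_model_param_flops(model, input_res, multiply_adds=...). The module itself is not reproduced here; the sketch below is only a rough illustration of how such a hook-based counter could look (conv and linear layers only, bias and BatchNorm ops ignored; the names and the exact accounting are assumptions, not the actual compute_flops.py).

import torch
import torch.nn as nn


def print_model_param_nums(model):
    # Total parameter count, reported in millions.
    total = sum(p.numel() for p in model.parameters())
    print('  + Number of params: %.2fM' % (total / 1e6))
    return total


def print_model_param_flops(model, input_res=224, multiply_adds=True):
    # Count conv/linear FLOPs by running a dummy 3-channel input through
    # the model (assumed to live on the CPU) with forward hooks attached.
    flops = []

    def conv_hook(module, inputs, output):
        out_c, out_h, out_w = output.shape[1:]
        # multiply-accumulates per output position
        kernel_ops = (module.in_channels // module.groups
                      * module.kernel_size[0] * module.kernel_size[1])
        ops = kernel_ops * (2 if multiply_adds else 1)
        flops.append(ops * out_c * out_h * out_w)

    def linear_hook(module, inputs, output):
        ops = module.in_features * module.out_features
        flops.append(ops * (2 if multiply_adds else 1))

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    model.eval()  # avoid BatchNorm issues with a batch of one
    with torch.no_grad():
        model(torch.randn(1, 3, input_res, input_res))

    for h in handles:
        h.remove()

    total = sum(flops)
    print('  + Number of FLOPs: %.2fG' % (total / 1e9))
    return total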
Example #2
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    n_params = sum(p.numel() for p in net.parameters())/10**6
    print(f'Total params: {n_params:.2f}M')
    print_model_param_flops(net, 32)
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
                         % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt.pth')
        best_acc = acc
Example #3
File: evaluate.py  Project: zeta1999/sensAI
    def print_statistics(self):
        num_params = []
        num_flops = []

        print("\n===== Metrics for grouped model ==========================\n")

        for group_id, model in zip(self.group_info, self.model_list):
            n_params = sum(p.numel() for p in model.parameters()) / 10**6
            num_params.append(n_params)
            print(f'Grouped model for Class {group_id} '
                  f'Total params: {n_params:.2f}M')
            num_flops.append(print_model_param_flops(model, 32))

        print(
            f"Average number of flops: {sum(num_flops) / len(num_flops) / 10**9 :3f} G"
        )
        print(
            f"Average number of param: {sum(num_params) / len(num_params)} M")
Example #4
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=args.gpu_ids)
        if args.swa:
            swa_model = torch.nn.parallel.DistributedDataParallel(
                swa_model, device_ids=args.gpu_ids)
else:
    model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        # model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=args.gpu_ids)

if args.dataset == 'imagenet':
    pruned_flops = print_model_param_flops(model, 224)

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)


def save_checkpoint(state, is_best, epoch, filepath, is_swa):
    if is_swa:
        torch.save(state, os.path.join(filepath, 'swa.pth.tar'))
    else:
        if epoch == 'init':
            filepath = os.path.join(filepath, 'init.pth.tar')
            torch.save(state, filepath)
        elif 'EB' in str(epoch):
Example #5
        output = model(data)
        test_loss += F.cross_entropy(
            output, target, size_average=False).data  # sum up batch loss
        pred = output.data.max(
            1, keepdim=True)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data.view_as(pred)).cpu().numpy().sum()

    test_loss /= len(test_loader.dataset)
    #print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
    #    test_loss, correct, len(test_loader.dataset),
    #    100. * correct / len(test_loader.dataset)))
    return correct / float(len(test_loader.dataset))


acc = test(model)

total_params = print_model_param_nums(model.cpu())
total_flops = print_model_param_flops(model.cpu(), 32)

results = {
    'load': args.load,
    'dataset': args.dataset,
    'model_name': args.model_name,
    'arch': 'mobilenetv1',
    'acc': acc,
    'cfg': model.cfg,
    'total_params': total_params,
    'total_flops': total_flops,
}
print(results)
Example #6
# define loss function (criterion) and optimizer
num_classes = 1000

# Data loading code
train_loader, val_loader = \
    get_data_loader(args.data, train_batch_size=args.batch_size, test_batch_size=args.test_batch_size, workers=args.workers)

## loading pretrained model ##
assert args.load
assert os.path.isfile(args.load)
print("=> loading checkpoint '{}'".format(args.load))
checkpoint = torch.load(args.load)

model = mbnet(cfg=checkpoint['cfg'])
total_flops = print_model_param_flops(model, 224, multiply_adds=False) 
print(total_flops)

if args.use_cuda: 
    model.cuda()

selected_model_keys = [k for k in model.state_dict().keys() if not (k.endswith('.y') or k.endswith('.v') or k.startswith('net_params') or k.startswith('y_params') or k.startswith('v_params'))]
saved_model_keys = checkpoint['state_dict']
from collections import OrderedDict
new_state_dict = OrderedDict()
if len(selected_model_keys) == len(saved_model_keys):

    for k0, k1 in zip(selected_model_keys, saved_model_keys):
        new_state_dict[k0] = checkpoint['state_dict'][k1]   
    
    model_dict = model.state_dict()
Example #7
                    metavar='PATH',
                    help='path to the model (default: none)')
parser.add_argument('--save',
                    default='',
                    type=str,
                    metavar='PATH',
                    help='path to save pruned model (default: none)')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

if not os.path.exists(args.save):
    os.makedirs(args.save)

model = resnet(depth=args.depth, dataset=args.dataset)
total_flops = print_model_param_flops(model, input_res=32)

if args.cuda:
    model.cuda()
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = torch.load(args.model)
        args.start_epoch = checkpoint['epoch']
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}".format(
            args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
Example #8
if args.model:
    if os.path.isfile(args.model):
        print("=> loading checkpoint '{}'".format(args.model))
        checkpoint = fix_robustness_ckpt(torch.load(args.model))
        # args.start_epoch = checkpoint['epoch']
        # best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint, strict=False)
        # print("=> loaded checkpoint '{}' (epoch {}) Prec1: {:f}"
        #       .format(args.model, checkpoint['epoch'], best_prec1))
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))
        exit()

if args.dataset == 'imagenet':
    print('original model param: ', print_model_param_nums(model))
    print('original model flops: ', print_model_param_flops(model, 224, True))
else:
    print('original model param: ', print_model_param_nums(model))
    print('original model flops: ', print_model_param_flops(model, 32, True))

if args.cuda:
    model.cuda()

total = 0
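# Accumulate the channel counts of all BatchNorm2d layers; `bn` is then sized to
# hold their scale factors (presumably for a global channel-pruning threshold).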

for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        total += m.weight.data.shape[0]

bn = torch.zeros(total)
index = 0
Example #9
            model = models.__dict__[args.arch](pretrained=False, cfg=cfg_input)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.gpu_ids, find_unused_parameters=True)
else:
    model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    if args.cuda:
        model.cuda()
    if len(args.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=args.gpu_ids)
        # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.gpu_ids, find_unused_parameters=True)

if args.dataset == 'imagenet':
    pruned_flops = print_model_param_flops(model.cpu(), 224)
    model.cuda()


optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

def save_checkpoint(state, is_best, epoch, filepath):
    if epoch == 'init':
        filepath = os.path.join(filepath, 'init.pth.tar')
        torch.save(state, filepath)
    elif 'EB' in str(epoch):
        filepath = os.path.join(filepath, epoch+'.pth.tar')
        torch.save(state, filepath)
    else:
        filename = os.path.join(filepath, 'ckpt'+str(epoch)+'.pth.tar')
        torch.save(state, filename)
Example #10
def main():
    global args, best_prec1, device
    args = parser.parse_args()

    batch_size = args.batch_size * max(1, args.num_gpus)
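    # Linear learning-rate scaling: lr is scaled with the effective batch size relative to 256.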
    args.lr = args.lr * (batch_size / 256.)
    print(batch_size, args.lr, args.num_gpus)

    num_classes = 1000
    num_training_samples = 1281167
    args.num_batches_per_epoch = num_training_samples // batch_size

    assert os.path.isfile(args.load) and args.load.endswith(".pth.tar")
    args.save = os.path.dirname(args.load)
    training_mode = 'retrain' if args.retrain else 'finetune'
    args.save = os.path.join(args.save, training_mode)

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    args.model_save_path = os.path.join(
        args.save, "epochs_{}_{}".format(args.epochs,
                                         os.path.basename(args.load)))
    args.distributed = args.world_size > 1

    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)

    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)

    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    ##########################################################

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    args.use_cuda = torch.cuda.is_available() and not args.no_cuda

    # Random seed
    random.seed(0)
    torch.manual_seed(0)
    if args.use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'

    if args.evaluate == 1:
        device = 'cuda:0'

    assert os.path.isfile(args.load)
    print("=> loading checkpoint '{}'".format(args.load))
    checkpoint = torch.load(args.load)

    model = mobilenetv2(cfg=checkpoint['cfg'])
    cfg = model.cfg

    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(),
                                          224,
                                          multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model)

    ##### finetune #####
    if not args.retrain:
        model.load_state_dict(checkpoint['state_dict'])

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).to(device)
    else:
        criterion = nn.CrossEntropyLoss().to(device)

    ### all parameter ####
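    # Split parameters so that BatchNorm and bias terms are excluded from weight decay.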
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)

    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ],
                                args.lr,
                                momentum=args.momentum)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            pass

    # Data loading code
    train_loader, val_loader = \
        get_data_loader(args.data, train_batch_size=batch_size, test_batch_size=32, workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                #'m': args.m,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            },
            args.model_save_path)

        print('  + Number of params: %.3fM' % (total_params / 1e6))
        print('  + Number of FLOPs: %.3fG' % (total_flops / 1e9))
            if idx0.size == 1:
                idx0 = np.resize(idx0, (1, ))
            if idx1.size == 1:
                idx1 = np.resize(idx1, (1, ))
            w1 = mm0.weight._data[0][:, idx0.tolist(), :, :]
            w1 = w1[idx1.tolist(), :, :, :]
            params[mm1.weight.name] = w1

        elif isinstance(mm0, nn.Dense):
            if layer_id_in_cfg == len(cfg_mask):
                idx0 = np.squeeze(
                    np.argwhere(np.asarray(cfg_mask[-1].asnumpy())))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1, ))
                params[mm1.weight.name] = mm0.weight._data[0][:, idx0]
                params[mm1.bias.name] = mm0.bias._data[0]
                layer_id_in_cfg += 1
                continue
            params[mm1.weight.name] = mm0.weight._data[0]
            params[mm1.bias.name] = mm0.bias._data[0]

#print(params)
pruned_model = '%s/%s-%s-pruned.params' % (args.save, args.dataset, model_name)
mxnet.ndarray.save(pruned_model, params)
newmodel.collect_params().load(pruned_model, ctx=context)
acc = test(newmodel)

num_parameters, flops = print_model_param_flops(newmodel, input_res=32)

print('\nTest-set accuracy after pruning: ', acc)
    criterion = nn.CrossEntropyLoss().cuda()

# Data loading code
train_loader, val_loader = \
    get_data_loader(args.data, train_batch_size=args.batch_size, test_batch_size=16, workers=args.workers)


## loading pretrained model ##
assert args.load
assert os.path.isfile(args.load)
print("=> loading checkpoint '{}'".format(args.load))
checkpoint = torch.load(args.load)

model = mbnet(cfg=checkpoint['cfg'])
total_params = print_model_param_nums(model)
total_flops = print_model_param_flops(model, 224, multiply_adds=False) 
print(total_params, total_flops)

if args.use_cuda: 
    model.cuda()

selected_model_keys = [k for k in model.state_dict().keys() if not (k.endswith('.y') or k.endswith('.v') or k.startswith('net_params') or k.startswith('y_params') or k.startswith('v_params'))]
saved_model_keys = checkpoint['state_dict']
from collections import OrderedDict
new_state_dict = OrderedDict()
if len(selected_model_keys) == len(saved_model_keys):

    for k0, k1 in zip(selected_model_keys, saved_model_keys):
        new_state_dict[k0] = checkpoint['state_dict'][k1]   
    
    model_dict = model.state_dict()
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def sp_mbnetv2(**kwargs):
    """
    Constructs a MobileNet V2 model
    """
    return SpMobileNetV2(**kwargs)


if __name__ == '__main__':
    net = sp_mbnetv2()
    x = Variable(torch.FloatTensor(2, 3, 224, 224))
    y = net(x)
    print(y.data.shape)

    print_cfg(net.cfg)

    from compute_flops import print_model_param_nums, print_model_param_flops
    total_flops = print_model_param_flops(net.cpu(), 224, multiply_adds=False)
                idx1 = np.squeeze(np.argwhere(np.asarray(end_mask.cpu().numpy())))
                # print('In shape: {:d}, Out shape {:d}.'.format(idx0.size, idx1.size))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1,))
                if idx1.size == 1:
                    idx1 = np.resize(idx1, (1,))
                w1 = m0.weight.data[:, idx0.tolist(), :, :].clone()
                w1 = w1[idx1.tolist(), :, :, :].clone()
                m1.weight.data = w1.clone()
            elif isinstance(m0, nn.Linear):
                idx0 = np.squeeze(np.argwhere(np.asarray(start_mask.cpu().numpy())))
                if idx0.size == 1:
                    idx0 = np.resize(idx0, (1,))
                m1.weight.data = m0.weight.data[:, idx0].clone()
                m1.bias.data = m0.bias.data.clone()
        flop_ramained = compute_flops.print_model_param_flops(model=newmodel.cpu(), input_res=32, multiply_adds=False)

    torch.save({'cfg': cfg, 'state_dict': newmodel.state_dict()}, os.path.join(args.save,str(int(prune_ratio*100))+ 'pruned.pth.tar'))

# print(newmodel)
# model = newmodel
# test(model)
def test(model):
    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
    if args.dataset == 'cifar10':
        test_loader = torch.utils.data.DataLoader(
            datasets.CIFAR10('./data/dataset/cifar10', train=False, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])),
            batch_size=args.test_batch_size, shuffle=True, **kwargs)
    elif args.dataset == 'cifar100':
Example #15
    'cfg': cfg,
    'state_dict': newmodel.state_dict()
}, os.path.join(args.save, 'pruned.pth.tar'))
# print(newmodel)
model = newmodel
print("after pruning")
acc = test(model)

# Calculate Flops and Params
origin_num_parameters = sum(
    [param.nelement() for param in origin_model.parameters()])
num_parameters = sum([param.nelement() for param in newmodel.parameters()])
param_reduction_percent = (
    (origin_num_parameters - num_parameters) / origin_num_parameters) * 100

origin_flops = print_model_param_flops(origin_model.cpu(), input_res=32) / 1e9
new_flops = print_model_param_flops(newmodel.cpu(), input_res=32) / 1e9
flops_reduction_percent = ((origin_flops - new_flops) / origin_flops) * 100

with open(os.path.join(args.save, "prune.txt"), "w") as fp:
    fp.write("Number of parameters Before: \n" + str(origin_num_parameters) +
             "\n" + "\n")
    fp.write("Number of parameters: \n" + str(num_parameters) + "\n" + "\n")
    fp.write("% of reduced parameters: \n" + str(param_reduction_percent) +
             "\n" + "\n" + "\n")

    fp.write("Number of Flops Before: \n" + str(origin_flops) + "G" + "\n" +
             "\n")
    fp.write("Number of Flops: \n" + str(new_flops) + "G" + "\n" + "\n")
    fp.write("% of reduced Flops: \n" + str(flops_reduction_percent) + "\n" +
             "\n")
Example #16
    model = mwr.Model(num_classes,
                      input_size=image_size,
                      cfg=checkpoint['cfg'])
    model_ref = mwr.Model(num_classes,
                          input_size=image_size,
                          cfg=checkpoint['cfg'])
    # model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
    # model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
    model_ref.load_state_dict(checkpoint['state_dict'])
    for m0, m1 in zip(model.modules(), model_ref.modules()):
        if isinstance(m0, models.channel_selection):
            m0.indexes.data = m1.indexes.data.clone()

    # model_base = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
    model_base = mwr.Model(num_classes, input_size=image_size)
    base_flops = print_model_param_flops(model_base, 32)
    pruned_flops = print_model_param_flops(model, 32)
    args.epochs = int(160 * (base_flops / pruned_flops))

if args.cuda:
    model.cuda()

# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay,
                      nesterov=True)

if args.resume:
    if os.path.isfile(args.resume):
    if isinstance(m, SpMbBlock) or (k == 2 and isinstance(m, SpConvBlock)):
        m.reset_yv_()
log.info('acc before splitting')
test(model)

for epoch in range(1, 1 + args.epochs):
    if epoch % 2 == 0:
        for param_group in optimizer_v.param_groups:
            param_group['lr'] *= 0.2
    min_eig_vals, min_eig_vecs = train(epoch)
    #break

########################################
##### select neurons ######
########################################
print_model_param_flops(model.cpu(), 32)
model.to(device)

total = 0
for m in min_eig_vals:
    total += len(m)

cfg_grow = []
cfg_mask = []

block_weigths_norm = []
if args.energy or args.params:
    ## flops ##
    cfg = model.cfg

    params_inc_per_neuron, flops_inc_per_neuron = [], []
def main():
    global best_prec1, log

    batch_size = args.batch_size * max(1, args.num_gpus)
    args.lr = args.lr * (batch_size // 256)
    print(batch_size, args.lr, args.num_gpus)

    num_classes = 1000
    num_training_samples = 1281167
    args.num_batches_per_epoch = num_training_samples // batch_size

    assert args.exp_name
    args.save = os.path.join(args.save, args.exp_name)
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    hyper_str = "run_{}_lr_{}_decay_{}_b_{}_gpu_{}".format(args.epochs, args.lr, \
                                args.lr_mode, batch_size, args.num_gpus)

    ## bn-based pruning base model ##
    if args.sr:
        hyper_str = "{}_sr_grow_{}_s_{}".format(hyper_str, args.m, args.s)
    ## using amc configuration ##
    elif args.amc:
        hyper_str = "{}_amc".format(hyper_str)
    elif args.sp:
        hyper_str = "{}_sp_base_{}".format(hyper_str, args.sp_cfg)
    else:
        hyper_str = "{}_grow_{}".format(hyper_str, args.m)

    args.model_save_path = \
            os.path.join(args.save, 'mbv1_{}.pth.tar'.format(hyper_str))

    #args.logging_file_path = \
    #        os.path.join(args.save, 'mbv1_{}.log'.format(hyper_str))
    #print(args.model_save_path, args.logging_file_path)

    ##########################################################
    ## create file handler which logs even debug messages
    #import logging
    #log = logging.getLogger()
    #log.setLevel(logging.INFO)

    #ch = logging.StreamHandler()
    #fh = logging.FileHandler(args.logging_file_path)

    #formatter = logging.Formatter('%(asctime)s - %(message)s')
    #ch.setFormatter(formatter)
    #fh.setFormatter(formatter)
    #log.addHandler(fh)
    #log.addHandler(ch)
    #########################################################
    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    # Use CUDA
    use_cuda = torch.cuda.is_available()
    args.use_cuda = use_cuda

    # Random seed
    random.seed(0)
    torch.manual_seed(0)
    if use_cuda:
        torch.cuda.manual_seed_all(0)
        device = 'cuda'
        cudnn.benchmark = True
    else:
        device = 'cpu'

    if args.evaluate == 1:
        device = 'cuda:0'

    if args.sp:
        model = mbnet(default=args.sp_cfg)
    else:
        #model = mobilenetv1(amc=args.amc, m=args.m)
        model = mbnet(amc=args.amc, m=args.m)
        print(model.cfg)

    cfg = model.cfg

    total_params = print_model_param_nums(model.cpu())
    total_flops = print_model_param_flops(model.cpu(),
                                          224,
                                          multiply_adds=False)
    print(total_params, total_flops)

    if not args.distributed:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # define loss function (criterion) and optimizer
    if args.label_smoothing:
        criterion = CrossEntropyLabelSmooth(num_classes).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    ### all parameter ####
    no_wd_params, wd_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            if ".bn" in name or '.bias' in name:
                no_wd_params.append(param)
            else:
                wd_params.append(param)
    no_wd_params = nn.ParameterList(no_wd_params)
    wd_params = nn.ParameterList(wd_params)

    optimizer = torch.optim.SGD([
        {
            'params': no_wd_params,
            'weight_decay': 0.
        },
        {
            'params': wd_params,
            'weight_decay': args.weight_decay
        },
    ],
                                args.lr,
                                momentum=args.momentum,
                                nesterov=True)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.model_save_path):
            print("=> loading checkpoint '{}'".format(args.model_save_path))
            checkpoint = torch.load(args.model_save_path)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.model_save_path, checkpoint['epoch']))
        else:
            pass
            #print("=> no checkpoint found at '{}'".format(args.model_save_path))

    # Data loading code
    train_loader, val_loader = \
        get_data_loader(args.data, train_batch_size=batch_size, test_batch_size=32, workers=args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        #adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'cfg': cfg,
                'sr': args.sr,
                'amc': args.amc,
                's': args.s,
                'args': args,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, args.model_save_path)
Example #19
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)

if args.scratch:
    checkpoint = torch.load(args.scratch)
    model = models.__dict__[args.arch](dataset=args.dataset,
                                       depth=args.depth,
                                       cfg=checkpoint['cfg'])

model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)

flops_std = print_model_param_flops(model_ref, 32)
flops_small = print_model_param_flops(model, 32)
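# Scratch training budget: give the smaller model proportionally more epochs than the
# 160-epoch reference so the total training FLOPs stay roughly comparable
# (e.g. half the FLOPs -> 320 epochs).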
args.epochs = int(160 * (flops_std / flops_small))

if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(),
                      lr=args.lr,
                      momentum=args.momentum,
                      weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
Example #20
    checkpoint = torch.load(args.scratch)
    if args.dataset == 'imagenet':
        model = models.__dict__[args.arch](pretrained=False, cfg=checkpoint['cfg'])
        model_ref = models.__dict__[args.arch](pretrained=False, cfg=checkpoint['cfg'])
        model_ref.load_state_dict(checkpoint['state_dict'])
    else:
        model = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
        model_ref = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth, cfg=checkpoint['cfg'])
        model_ref.load_state_dict(checkpoint['state_dict'])
    for m0, m1 in zip(model.modules(), model_ref.modules()):
        if isinstance(m0, models.channel_selection):
            m0.indexes.data = m1.indexes.data.clone()

    if args.dataset == 'imagenet':
        model_base = model
        base_flops = print_model_param_flops(model_base, 224)
        pruned_flops = print_model_param_flops(model, 224)
    else:
        pass
        # model_base = models.__dict__[args.arch](dataset=args.dataset, depth=args.depth)
        # base_flops = print_model_param_flops(model_base, 32)
        # pruned_flops = print_model_param_flops(model, 32)
        # args.epochs = int(160 * (base_flops / pruned_flops))

if args.cuda:
    model.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

if args.resume:
    if os.path.isfile(args.resume):