Example #1
def flops_counter(args):
    # model speed up
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)

    if args.pruner != 'AutoCompressPruner':
        if args.model == 'LeNet':
            model = LeNet().to(device)
        elif args.model == 'vgg16':
            model = VGG(depth=16).to(device)
        elif args.model == 'resnet18':
            model = models.resnet18(pretrained=False,
                                    num_classes=10).to(device)
        elif args.model == 'mobilenet_v2':
            model = models.mobilenet_v2(pretrained=False).to(device)

        def evaluator(model):
            return test(model, device, criterion, val_loader)

        model.load_state_dict(
            torch.load(
                os.path.join(args.experiment_data_dir,
                             'model_fine_tuned.pth')))
        masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

        dummy_input = get_dummy_input(args, device)

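        # ModelSpeedup reads the mask file and rewrites the network in place,
        # replacing masked channels/layers with physically smaller ones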
        m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
        m_speedup.speedup_model()
        evaluation_result = evaluator(model)
        print('Evaluation result (speed up model): %s' % evaluation_result)

        with open(os.path.join(args.experiment_data_dir,
                               'performance.json')) as f:
            result = json.load(f)

        result['speedup'] = evaluation_result
        with open(os.path.join(args.experiment_data_dir, 'performance.json'),
                  'w+') as f:
            json.dump(result, f)

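        # note: the sped-up state_dict has the compact layer shapes, so it can only
        # be reloaded into a model that has been sped up the same way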
        torch.save(
            model.state_dict(),
            os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
        print('Speed up model saved to %s' % args.experiment_data_dir)
    else:
        model = torch.load(
            os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))
        model.eval()
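        # count_flops_params runs a forward pass on a dummy (1, 3, 32, 32) input to
        # tally FLOPs and parameter counts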
        flops, params = count_flops_params(model, (1, 3, 32, 32))
        with open(os.path.join(args.experiment_data_dir, 'flops.json'),
                  'w+') as f:
            json.dump({'FLOPS': int(flops), 'params': int(params)}, f)
Example #2
def model_inference(config):
    model_trained = './experiment_data/resnet_bn/model_fine_tuned_first.pth'
    rn50 = resnet50()
    m_paras = torch.load(model_trained)
    # strip the mask tensors from the fine-tuned checkpoint
    m_new = collections.OrderedDict()
    mask = dict()
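    # rebuild the state dict: strip wrapper prefixes, drop mask entries, and remap
    # each Taylor mask onto the BN layer it actually covers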
    for key in m_paras:
        if 'weight_mask_b' in key: continue
        if 'weight_mask' in key:
            if 'module_added' not in key:
                mask[key.replace('.weight_mask', '')] = dict()
                mask[key.replace('.weight_mask', '')]['weight'] = m_paras[key]
                mask[key.replace('.weight_mask', '')]['bias'] = m_paras[key]
            else:
                mask[key.replace('.relu1.module_added.weight_mask',
                                 '.bn3')] = {}
                mask[key.replace('.relu1.module_added.weight_mask',
                                 '.bn3')]['weight'] = m_paras[key]
                mask[key.replace('.relu1.module_added.weight_mask',
                                 '.bn3')]['bias'] = m_paras[key]
                if '0.relu1' in key:
                    mask[key.replace('relu1.module_added.weight_mask',
                                     'downsample.1')] = {}
                    mask[key.replace('relu1.module_added.weight_mask',
                                     'downsample.1')]['weight'] = m_paras[key]
                    mask[key.replace('relu1.module_added.weight_mask',
                                     'downsample.1')]['bias'] = m_paras[key]
            continue
        if 'module_added' in key:
            continue
        elif 'module' in key:
            m_new[key.replace('module.', '')] = m_paras[key]
        else:
            m_new[key] = m_paras[key]
    for key in mask:
        # multiply the pruning masks into the corresponding weights and biases
        m_new[key + '.weight'] = m_new[key + '.weight'].data.mul(
            mask[key]['weight'])
        m_new[key + '.bias'] = m_new[key + '.bias'].data.mul(mask[key]['bias'])
    rn50.load_state_dict(m_new)
    rn50.cuda()
    rn50.eval()
    torch.save(mask, 'taylor_mask.pth')
    mask_file = './taylor_mask.pth'
    dummy_input = torch.randn(64, 3, 224, 224).cuda()
    use_mask_out = use_speedup_out = None
    rn = rn50
    rn_mask_out = rn(dummy_input)
    model = rn50
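    # use_mask, use_speedup and compare_results are presumably module-level flags
    # defined elsewhere in the original script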
    if use_mask:
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_masked_taylor_1700.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)

        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        elapsed_t = time.time() - start
        print('elapsed time when use mask: ', elapsed_t)
        _logger.info(
            'for batch size 64 and with 32 runs, the elapsed time is {}'.
            format(elapsed_t))
    print('before speed up===================')
    flops, paras = count_flops_params(model, (1, 3, 224, 224))
    _logger.info(
        'flops and parameters before speedup are {} FLOPS and {} params'.format(
            flops, paras))
    if use_speedup:
        dummy_input = dummy_input.cuda()
        m_speedup = ModelSpeedup(model, dummy_input, mask_file, 'cuda')
        m_speedup.speedup_model()
        print('==' * 20)
        print('Start inference')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_taylor_1700.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        elapsed_t1 = time.time() - start
        print('elapsed time when use speedup: ', elapsed_t1)
        _logger.info(
            'elapsed time with batch_size 64 and in 32 runs is {}'.format(
                elapsed_t1))
    #print('After speedup model is ',model)
    _logger.info('model structure after speedup is ====')
    _logger.info(model)
    print('=================')
    print('After speedup')
    flops, paras = count_flops_params(model, (1, 3, 224, 224))
    _logger.info(
        'After speedup flops are {} and number of parameters are {}'.format(
            flops, paras))
    if compare_results:
        print(rn_mask_out)
        print('another is', use_speedup_out)
        if torch.allclose(rn_mask_out, use_speedup_out, atol=1e-6):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
    # start the accuracy check
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        start = time.time()
        evaluate(model,
                 criterion,
                 data_loader_test,
                 device="cuda",
                 print_freq=20)
        print('elapsed time is ', time.time() - start)
Example #3
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, val_loader, criterion = get_data(args)
    model, optimizer = get_trained_model_optimizer(args, device, train_loader,
                                                   val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     callback=callback)

    def evaluator(model):
        return test(model, device, criterion, val_loader)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']

    config_list = [{'sparsity': args.sparsity, 'op_types': op_types}]
    dummy_input = get_dummy_input(args, device)
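    # dummy_input is needed by pruners such as AutoCompressPruner that trace the
    # model in order to run speed-up between pruning iterations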

    if args.pruner == 'L1FilterPruner':
        pruner = L1FilterPruner(model, config_list)
    elif args.pruner == 'L2FilterPruner':
        pruner = L2FilterPruner(model, config_list)
    elif args.pruner == 'ActivationMeanRankFilterPruner':
        pruner = ActivationMeanRankFilterPruner(model, config_list)
    elif args.pruner == 'ActivationAPoZRankFilterPruner':
        pruner = ActivationAPoZRankFilterPruner(model, config_list)
    elif args.pruner == 'NetAdaptPruner':
        pruner = NetAdaptPruner(model,
                                config_list,
                                short_term_fine_tuner=short_term_fine_tuner,
                                evaluator=evaluator,
                                base_algo=args.base_algo,
                                experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'ADMMPruner':
        # users are free to change the config here
        if args.model == 'LeNet':
            if args.base_algo in ['l1', 'l2']:
                config_list = [{
                    'sparsity': 0.8,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_types': ['Conv2d'],
                    'op_names': ['conv2']
                }]
            elif args.base_algo == 'level':
                config_list = [{
                    'sparsity': 0.8,
                    'op_names': ['conv1']
                }, {
                    'sparsity': 0.92,
                    'op_names': ['conv2']
                }, {
                    'sparsity': 0.991,
                    'op_names': ['fc1']
                }, {
                    'sparsity': 0.93,
                    'op_names': ['fc2']
                }]
        else:
            raise ValueError('Example only implemented for LeNet.')
        pruner = ADMMPruner(model,
                            config_list,
                            trainer=trainer,
                            num_iterations=2,
                            training_epochs=2)
    elif args.pruner == 'SimulatedAnnealingPruner':
        pruner = SimulatedAnnealingPruner(
            model,
            config_list,
            evaluator=evaluator,
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            experiment_data_dir=args.experiment_data_dir)
    elif args.pruner == 'AutoCompressPruner':
        pruner = AutoCompressPruner(
            model,
            config_list,
            trainer=trainer,
            evaluator=evaluator,
            dummy_input=dummy_input,
            num_iterations=3,
            optimize_mode='maximize',
            base_algo=args.base_algo,
            cool_down_rate=args.cool_down_rate,
            admm_num_iterations=30,
            admm_training_epochs=5,
            experiment_data_dir=args.experiment_data_dir)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model;
    # AutoCompressPruner.compress(), however, returns the pruned model directly
    model = pruner.compress()
    evaluation_result = evaluator(model)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    # model speed up
    if args.speed_up:
        if args.pruner != 'AutoCompressPruner':
            if args.model == 'LeNet':
                model = LeNet().to(device)
            elif args.model == 'vgg16':
                model = VGG(depth=16).to(device)
            elif args.model == 'resnet18':
                model = ResNet18().to(device)
            elif args.model == 'resnet50':
                model = ResNet50().to(device)
            elif args.model == 'mobilenet_v2':
                model = models.mobilenet_v2(pretrained=False).to(device)

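            # reload the exported masked weights into a fresh model, then let
            # ModelSpeedup shrink it in place using the saved mask file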
            model.load_state_dict(
                torch.load(
                    os.path.join(args.experiment_data_dir,
                                 'model_masked.pth')))
            masks_file = os.path.join(args.experiment_data_dir, 'mask.pth')

            m_speedup = ModelSpeedup(model, dummy_input, masks_file, device)
            m_speedup.speedup_model()
            evaluation_result = evaluator(model)
            print('Evaluation result (speed up model): %s' % evaluation_result)
            result['performance']['speedup'] = evaluation_result

            torch.save(
                model.state_dict(),
                os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
            print('Speed up model saved to %s' % args.experiment_data_dir)
        flops, params = count_flops_params(model, get_input_size(args.dataset))
        result['flops']['speedup'] = flops
        result['params']['speedup'] = params

    best_acc = 0  # defined up front so the final report is valid even when fine-tuning is skipped
    if args.fine_tune:
        if args.dataset == 'mnist':
            optimizer = torch.optim.Adadelta(model.parameters(), lr=1)
            scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
        elif args.dataset == 'cifar10' and args.model == 'vgg16':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        elif args.dataset == 'cifar10' and args.model in ('resnet18',
                                                          'resnet50'):
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.1,
                                        momentum=0.9,
                                        weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.5),
                                        int(args.fine_tune_epochs * 0.75)
                                    ],
                                    gamma=0.1)
        for epoch in range(args.fine_tune_epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch)
            scheduler.step()
            acc = evaluator(model)
            if acc > best_acc:
                best_acc = acc
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

    print('Evaluation result (fine tuned): %s' % best_acc)
    print('Fine-tuned model saved to %s' % args.experiment_data_dir)
    result['performance']['finetuned'] = best_acc

    with open(os.path.join(args.experiment_data_dir, 'result.json'),
              'w+') as f:
        json.dump(result, f)
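
For orientation, here is a minimal sketch of the command-line wiring that main(args) above appears to expect; the flag names and defaults are illustrative assumptions, not taken from the source project:

# Hypothetical argparse wiring for main(args); flag names and defaults are assumptions.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='vgg16')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--pruner', type=str, default='L1FilterPruner')
    parser.add_argument('--base-algo', type=str, default='l1')
    parser.add_argument('--sparsity', type=float, default=0.5)
    parser.add_argument('--cool-down-rate', type=float, default=0.9)
    parser.add_argument('--speed-up', action='store_true')
    parser.add_argument('--fine-tune', action='store_true')
    parser.add_argument('--fine-tune-epochs', type=int, default=5)
    parser.add_argument('--save-model', action='store_true')
    parser.add_argument('--experiment-data-dir', type=str,
                        default='./experiment_data')
    main(parser.parse_args())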
Example #4
def model_inference(config):
    masks_file = './speedup_test/mask_new.pth'
    shape_mask = './speedup_test/mask_new.pth'
    org_mask = './speedup_test/mask.pth'
    rn50 = models.resnet50()
    m_paras = torch.load('./speedup_test/model_fine_tuned.pth')
    # strip the mask tensors from the fine-tuned checkpoint
    m_new = collections.OrderedDict()
    for key in m_paras:
        if 'mask' in key: continue
        if 'module' in key:
            m_new[key.replace('module.', '')] = m_paras[key]
        else:
            m_new[key] = m_paras[key]
    rn50.load_state_dict(m_new)
    rn50.cuda()
    rn50.eval()

    dummy_input = torch.randn(64, 3, 224, 224).cuda()
    use_mask_out = use_speedup_out = None
    rn = rn50
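    # apply_compression_results multiplies the exported masks into the weights,
    # producing a masked reference output to compare against the sped-up model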
    apply_compression_results(rn, org_mask, 'cuda')
    rn_mask_out = rn(dummy_input)
    model = rn50
    # must run use_mask before use_speedup because use_speedup modifies the model in place
    if use_mask:
        apply_compression_results(model, masks_file, 'cuda')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_masked.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)

        start = time.time()
        for _ in range(32):
            use_mask_out = model(dummy_input)
        print('elapsed time when use mask: ', time.time() - start)
    print('Model is ', model)
    print('before speed up===================')
    #for para in model.state_dict():
    #    print(para)
    #    print(model.state_dict()[para])
    #    print(model.state_dict()[para].shape)
    flops, paras = count_flops_params(model, (1, 3, 224, 224))
    print(
        'flops and parameters before speedup are {} FLOPS and {} params'.format(
            flops, paras))
    if use_speedup:
        dummy_input = dummy_input.cuda()
        m_speedup = ModelSpeedup(model, dummy_input, shape_mask, 'cuda')
        m_speedup.speedup_model()
        print('==' * 20)
        print('Start inference')
        torch.onnx.export(model,
                          dummy_input,
                          'resnet_fpgm.onnx',
                          export_params=True,
                          opset_version=12,
                          do_constant_folding=True,
                          input_names=['inputs'],
                          output_names=['proba'],
                          dynamic_axes={
                              'inputs': [0],
                              'mask': [0]
                          },
                          keep_initializers_as_inputs=True)
        start = time.time()
        for _ in range(32):
            use_speedup_out = model(dummy_input)
        print('elapsed time when use speedup: ', time.time() - start)
    print('After speedup model is ', model)
    print('=================')
    print('After speedup')
    flops, paras = count_flops_params(model, (1, 3, 224, 224))
    print(
        'flops and parameters after speedup are {} FLOPS and {} params'.format(
            flops, paras))
    #for para in model.state_dict():
    #    print(para)
    #    print(model.state_dict()[para])
    #    print(model.state_dict()[para].shape)
    if compare_results:
        print(rn_mask_out)
        print('another is', use_speedup_out)
        if torch.allclose(rn_mask_out, use_speedup_out, atol=1e-6):
            print('the outputs from use_mask and use_speedup are the same')
        else:
            raise RuntimeError(
                'the outputs from use_mask and use_speedup are different')
    # start the accuracy check
    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        start = time.time()
        evaluate(model,
                 criterion,
                 data_loader_test,
                 device="cuda",
                 print_freq=20)
        print('elapsed time is ', time.time() - start)
Example #5
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    model = resnet50()
    model.load_state_dict(torch.load(args.pretrained_model_dir))
    inited = init_distributed(True)  # use NCCL for communication
    print('world size (number of GPUs) is ', get_world_size())
    distributed = (get_world_size() > 1) and inited
    paral = get_world_size()
    #device = torch.device('cuda',args.local_rank) if distributed else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #args.rank = get_rank()
    device = set_device(args.cuda, args.local_rank)
    #write to tensorboard
    logger = Logger("logs/" + str(args.local_rank))
    print(distributed)
    print('local rank is {}'.format(args.local_rank))
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir,
                                                   args.batch_size,
                                                   args.test_batch_size)
    print('to distribute ', distributed)
    model = model.to(args.local_rank)
    if distributed:
        model = DDP(model, device_ids=[args.local_rank])  #, output_device=args.local_rank)
    #elif args.mgpu:
    #    model=torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    #model = torch.nn.DataParallel(model).cuda()
    criterion = criterion.to(args.local_rank)
    #for module in model.named_modules():
    #    print('%'*20)
    #    print(module)
    # module types to prune, only "BatchNorm2d" supported for channel pruning with Taylor
    config_step = {
        'bn_statics': args.bn_statistics,
        'fre': args.freq,
        'tot_pru': args.tot_pru,
        'save_path': args.experiment_data_dir
    }
    config_list = [{
        'op_types': ['BatchNorm2d', 'ReLU'],
        'must_names': 'layer',
        'include_names': ['bn1', 'bn2', 'relu1']
    }]
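    # 'must_names' and 'include_names' appear to be custom config keys consumed by
    # this TaylorPruner implementation rather than stock NNI config fields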
    dummy_input = get_dummy_input(args, args.local_rank)
    if args.pruner == 'TaylorPruner':
        pruner = TaylorPruner(device,
                              model,
                              config_list,
                              config_step,
                              dependency_aware=False,
                              optimizer=None)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    model = pruner.compress()
    if args.pruner == 'TaylorPruner':
        params_update = [
            param[1] for param in model.named_parameters()
            if 'mask' not in param[0]
        ]
        not_updates = [
            param[1] for param in model.named_parameters()
            if 'mask' in param[0]
        ]
        #print(params_update)
        params_all = [{
            'params': params_update
        }, {
            'params': not_updates,
            'lr': 0.0
        }]
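        # the mask parameters are handed to the optimizer with lr=0.0 above, so they
        # stay visible to it but are never updated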
    if args.fine_tune:
        if args.dataset in ['imagenet'] and args.model == 'resnet50':
            optimizer = torch.optim.SGD(params_all,
                                        lr=0.01,
                                        momentum=0.5,
                                        weight_decay=1e-8)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.3),
                                        int(args.fine_tune_epochs * 0.6),
                                        int(args.fine_tune_epochs * 0.8)
                                    ],
                                    gamma=0.1)
            pruner.optimizer = optimizer
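            # patch_optimizer hooks calc_contributions into each optimizer step so
            # Taylor importance scores accumulate during fine-tuning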
            pruner.patch_optimizer(pruner.masker.calc_contributions)
            pruner.keep_org_step()
        else:
            raise ValueError('fine-tuning is only configured for resnet50 on imagenet')

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(pruner, args, model, device, train_loader, criterion,
                  optimizer, epoch, logger)

    def trainer(pruner, model, optimizer, criterion, epoch, callback):
        return train(pruner,
                     args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     logger=logger,
                     callback=callback)

    def evaluator(model, step):
        return test(model, device, criterion, val_loader, step, logger)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    #print(model)
    for module in model.named_modules():
        print('DEBUG===name of module is ', module[0])
        print('buffers are: ', [buff for buff in module[1].buffers()])
        print('para are: ', [para for para in module[1].named_parameters()])
    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model, 0)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    if args.local_rank == 0 and args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    def wrapped(module):
        return isinstance(module, BNTaylorPrunerMasker)

    wrap_mask = [
        module for module in model.named_modules() if wrapped(module[1])
    ]
    for idx, mm in enumerate(wrap_mask):
        print('====****' * 10, idx)
        print(mm[0])
        print(mm[1].state_dict().keys())
        print('weight mask is ', mm[1].state_dict()['weight_mask'])
        if 'bias_mask' in mm[1].state_dict():
            print('bias mask is ', mm[1].state_dict()['bias_mask'])

    if args.mgpu: model = torch.nn.DataParallel(model).cuda()
    print('local rank is', args.local_rank)
    best_acc = 0  # defined up front so the final report is valid even when fine-tuning is skipped
    if args.fine_tune:
        for epoch in range(args.fine_tune_epochs):
            print('start fine tune for epoch {}/{}'.format(
                epoch, args.fine_tune_epochs))
            stime = time.time()
            train(pruner, args, model, args.local_rank, train_loader,
                  criterion, optimizer, epoch, logger)
            scheduler.step()
            acc = evaluator(model, epoch)
            print('end fine tune for epoch {}/{} for {} seconds'.format(
                epoch, args.fine_tune_epochs,
                time.time() - stime))
            if acc > best_acc and args.local_rank == 0:
                best_acc = acc
                torch.save(
                    model,
                    os.path.join(args.experiment_data_dir, args.model,
                                 'finetune_model.pt'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

    print('Evaluation result (fine tuned): %s' % best_acc)
    print('Fine-tuned model saved to %s' % args.experiment_data_dir)
    result['performance']['finetuned'] = best_acc

    if args.local_rank == 0:
        with open(os.path.join(args.experiment_data_dir, 'result.json'),
                  'w+') as f:
            json.dump(result, f)
Example #6
def main(args):
    # prepare dataset
    torch.manual_seed(0)
    #device = torch.device('cuda',args.local_rank) if distributed else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = set_device(args.cuda, args.local_rank)
    inited = init_distributed(True)  # use NCCL for communication
    print('world size (number of GPUs) is ', get_world_size())
    distributed = (get_world_size() > 1) and inited
    paral = get_world_size()
    args.rank = get_rank()
    #write to tensorboard
    logger = Logger("logs/" + str(args.rank))
    print(distributed)
    #device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('device is', device)
    print('rank is {} local rank is {}'.format(args.rank, args.local_rank))
    train_loader, val_loader, criterion = get_data(args.dataset, args.data_dir,
                                                   args.batch_size,
                                                   args.test_batch_size)
    model = torchvision.models.resnet50(pretrained=True)
    model = model.cuda()
    print('to distribute ', distributed)
    if distributed:
        model = DDP(model,
                    device_ids=[args.local_rank],
                    output_device=args.local_rank)
    #model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=0.1,
                                momentum=0.9,
                                weight_decay=5e-4)
    scheduler = MultiStepLR(optimizer,
                            milestones=[
                                int(args.pretrain_epochs * 0.5),
                                int(args.pretrain_epochs * 0.75)
                            ],
                            gamma=0.1)

    criterion = criterion.cuda()

    #model, optimizer = get_trained_model_optimizer(args, device, train_loader, val_loader, criterion)

    def short_term_fine_tuner(model, epochs=1):
        for epoch in range(epochs):
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch, logger)

    def trainer(model, optimizer, criterion, epoch, callback):
        return train(args,
                     model,
                     device,
                     train_loader,
                     criterion,
                     optimizer,
                     epoch=epoch,
                     logger=logger,
                     callback=callback)

    def evaluator(model, step):
        return test(model, device, criterion, val_loader, step, logger)

    # used to save the performance of the original & pruned & finetuned models
    result = {'flops': {}, 'params': {}, 'performance': {}}

    flops, params = count_flops_params(model, get_input_size(args.dataset))
    result['flops']['original'] = flops
    result['params']['original'] = params

    evaluation_result = evaluator(model, 0)
    print('Evaluation result (original model): %s' % evaluation_result)
    result['performance']['original'] = evaluation_result

    # module types to prune, only "Conv2d" supported for channel pruning
    if args.base_algo in ['l1', 'l2']:
        op_types = ['Conv2d']
    elif args.base_algo == 'level':
        op_types = ['default']

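    # 'exclude_names' below is presumably a custom key understood by MyPruner
    # (not a stock NNI config field), used to skip the downsample convolutions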
    config_list = [{
        'sparsity': args.sparsity,
        'op_types': op_types,
        'exclude_names': 'downsample'
    }]
    dummy_input = get_dummy_input(args, device)

    if args.pruner == 'FPGMPruner':
        pruner = MyPruner(model, config_list)
    else:
        raise ValueError("Pruner not supported.")

    # Pruner.compress() returns the masked model
    model = pruner.compress()
    evaluation_result = evaluator(model, 0)
    print('Evaluation result (masked model): %s' % evaluation_result)
    result['performance']['pruned'] = evaluation_result

    if args.rank == 0 and args.save_model:
        pruner.export_model(
            os.path.join(args.experiment_data_dir, 'model_masked.pth'),
            os.path.join(args.experiment_data_dir, 'mask.pth'))
        print('Masked model saved to %s' % args.experiment_data_dir)

    def wrapped(module):
        return isinstance(module, BNWrapper) or isinstance(
            module, PrunerModuleWrapper)

    wrap_mask = [
        module for module in model.named_modules() if wrapped(module[1])
    ]
    for mm in wrap_mask:
        print('====****' * 10)
        print(mm[0])
        print(mm[1].state_dict().keys())
        print('weight mask is ', mm[1].state_dict()['weight_mask'])
        if 'bias_mask' in mm[1].state_dict():
            print('bias mask is ', mm[1].state_dict()['bias_mask'])

    best_acc = 0  # defined up front so the final report is valid even when fine-tuning is skipped
    if args.fine_tune:
        if args.dataset in ['imagenet'] and args.model == 'resnet50':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=0.01,
                                        momentum=0.9,
                                        weight_decay=1e-4)
            scheduler = MultiStepLR(optimizer,
                                    milestones=[
                                        int(args.fine_tune_epochs * 0.3),
                                        int(args.fine_tune_epochs * 0.6),
                                        int(args.fine_tune_epochs * 0.8)
                                    ],
                                    gamma=0.1)
        else:
            raise ValueError('fine-tuning is only configured for resnet50 on imagenet')
        for epoch in range(args.fine_tune_epochs):
            print('start fine tune for epoch {}/{}'.format(
                epoch, args.fine_tune_epochs))
            stime = time.time()
            train(args, model, device, train_loader, criterion, optimizer,
                  epoch, logger)
            scheduler.step()
            acc = evaluator(model, epoch)
            print('end fine tune for epoch {}/{} for {} seconds'.format(
                epoch, args.fine_tune_epochs,
                time.time() - stime))
            if acc > best_acc and args.rank == 0:
                best_acc = acc
                torch.save(
                    model,
                    os.path.join(args.experiment_data_dir, args.model,
                                 'finetune_model.pt'))
                torch.save(
                    model.state_dict(),
                    os.path.join(args.experiment_data_dir,
                                 'model_fine_tuned.pth'))

    print('Evaluation result (fine tuned): %s' % best_acc)
    print('Fine-tuned model saved to %s' % args.experiment_data_dir)
    result['performance']['finetuned'] = best_acc

    if args.rank == 0:
        with open(os.path.join(args.experiment_data_dir, 'result.json'),
                  'w+') as f:
            json.dump(result, f)