Example #1
def dict_config(model, optimizer, sched_dict, scheduler=None):
    app_cfg_logger.debug('Schedule contents:\n' +
                         json.dumps(sched_dict, indent=2))

    if scheduler is None:
        scheduler = distiller.CompressionScheduler(model)

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    quantizers = __factory('quantizers',
                           model,
                           sched_dict,
                           optimizer=optimizer)
    if len(quantizers) > 1:
        raise ValueError("\nError: Multiple Quantizers not supported")
    extensions = __factory('extensions', model, sched_dict)

    try:
        lr_policies = []
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print(
                        '\n\nFatal Error: a policy is defined with a null pruner'
                    )
                    print(
                        'Here\'s the policy definition for your reference:\n{}'
                        .format(json.dumps(policy_def, indent=1)))
                    raise
                assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(
                    instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'regularizer')
                assert instance_name in regularizers, "Regularizer {} was not defined in the list of regularizers".format(
                    instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(
                        regularizer, **args)

            elif 'quantizer' in policy_def:
                instance_name, args = __policy_params(policy_def, 'quantizer')
                assert instance_name in quantizers, "Quantizer {} was not defined in the list of quantizers".format(
                    instance_name)
                quantizer = quantizers[instance_name]
                policy = distiller.QuantizationPolicy(quantizer)

            elif 'lr_scheduler' in policy_def:
                # LR schedulers take an optimizer in their constructor, so postpone handling them
                # until we're certain that any quantization policy has been initialized
                lr_policies.append(policy_def)
                continue

            elif 'extension' in policy_def:
                instance_name, args = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, "Extension {} was not defined in the list of extensions".format(
                    instance_name)
                extension = extensions[instance_name]
                policy = extension

            else:
                raise ValueError(
                    "\nFATAL Parsing error while parsing the pruning schedule - unknown policy [{}]"
                    .format(policy_def))

            add_policy_to_scheduler(policy, policy_def, scheduler)

        # Any changes to the optimizer caused by a quantizer have occurred by now, so it is safe to create the LR schedulers
        lr_schedulers = __factory('lr_schedulers',
                                  model,
                                  sched_dict,
                                  optimizer=optimizer)
        for policy_def in lr_policies:
            instance_name, args = __policy_params(policy_def, 'lr_scheduler')
            assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
                instance_name)
            lr_scheduler = lr_schedulers[instance_name]
            policy = distiller.LRPolicy(lr_scheduler)
            add_policy_to_scheduler(policy, policy_def, scheduler)

    except AssertionError:
        # propagate the assertion information
        raise
    except Exception as exception:
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        raise
    return scheduler
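
A minimal sketch of how a parsed schedule might reach dict_config; the YAML file name and the surrounding model/optimizer objects are placeholders, only dict_config and yaml.safe_load are taken as given here.

# Hypothetical driver: load a compression schedule from YAML and build a scheduler.
# 'schedule.yaml', `model` and `optimizer` stand in for the caller's own objects.
import yaml

with open('schedule.yaml') as stream:
    sched_dict = yaml.safe_load(stream)

compression_scheduler = dict_config(model, optimizer, sched_dict)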
Example #2
def dictConfig(model, optimizer, schedule, sched_dict, logger):
    logger.debug(json.dumps(sched_dict, indent=1))

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    lr_schedulers = __factory('lr_schedulers',
                              model,
                              sched_dict,
                              optimizer=optimizer)
    extensions = __factory('extensions', model, sched_dict)

    try:
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print(
                        '\n\nFatal Error: a policy is defined with a null pruner'
                    )
                    print(
                        'Here\'s the policy definition for your reference:\n{}'
                        .format(json.dumps(policy_def, indent=1)))
                    exit(1)
                assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(
                    instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'regularizer')
                assert instance_name in regularizers, "Regularizer {} was not defined in the list of regularizers".format(
                    instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(
                        regularizer, **args)

            elif 'lr_scheduler' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'lr_scheduler')
                assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
                    instance_name)
                lr_scheduler = lr_schedulers[instance_name]
                policy = distiller.LRPolicy(lr_scheduler)

            elif 'extension' in policy_def:
                instance_name, args = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, "Extension {} was not defined in the list of extensions".format(
                    instance_name)
                extension = extensions[instance_name]
                policy = extension

            else:
                print(
                    "\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]"
                    % policy_def)
                exit(1)

            if 'epochs' in policy_def:
                schedule.add_policy(policy, epochs=policy_def['epochs'])
            else:
                schedule.add_policy(
                    policy,
                    starting_epoch=policy_def['starting_epoch'],
                    ending_epoch=policy_def['ending_epoch'],
                    frequency=policy_def['frequency'])
    except AssertionError:
        # propagate the assertion information
        raise
    except Exception as exception:
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        exit(1)

    return schedule
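
For orientation, the dictionary that both parsers above iterate over has a 'policies' list whose entries carry exactly the keys checked in the code ('pruner', 'regularizer', 'lr_scheduler', 'extension', plus either 'epochs' or 'starting_epoch'/'ending_epoch'/'frequency'). A rough sketch follows; the top-level section bodies and the 'instance_name' key are assumptions, since __factory and __policy_params are not shown here.

# Illustrative schedule dict; only the 'policies' layout is grounded in the parsing
# code above, the 'pruners'/'lr_schedulers' bodies are assumptions.
sched_dict = {
    'pruners': {
        'my_pruner': {},        # instantiated by __factory('pruners', model, sched_dict)
    },
    'lr_schedulers': {
        'my_lr_sched': {},      # instantiated by __factory('lr_schedulers', ...)
    },
    'policies': [
        {'pruner': {'instance_name': 'my_pruner'},
         'epochs': [0, 2, 4]},
        {'lr_scheduler': {'instance_name': 'my_lr_sched'},
         'starting_epoch': 0, 'ending_epoch': 50, 'frequency': 1},
    ],
}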
Example #3
def objective(space):
    global model
    global count
    global global_min_score
    
    # Explore a new model
    model = create_model(False, args.dataset, args.arch, device_ids=args.gpus)
    count += 1
    # Objective function: F(acc, sparsity) = (1 - acc) + alpha * (1 - sparsity)
    accuracy = 0
    alpha = 0.3  # Hyperparameter: the relative weight of the sparsity term in the objective
    latency = 0.0
    sparsity = 0.0
    # Training hyperparameters

    if args.resume:
        model, compression_scheduler, start_epoch = apputils.load_checkpoint(
            model, chkpt_file=args.resume)
        print('resume mode: {}'.format(args.resume))

    print(global_min_score)
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    """
    distiller/distiller/config.py
        # Element-wise sparsity
        sparsity_levels = {net_param: sparsity_level}
        pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels)
        policy = distiller.PruningPolicy(pruner, pruner_args=None)
        scheduler = distiller.CompressionScheduler(model)
        scheduler.add_policy(policy, epochs=[0, 2, 4])
        # Local search:
        add multiple pruners, one per layer
    """
    sparsity_levels = {}
    for key, value in space.items():
        sparsity_levels[key] = value
    #print(sparsity_levels)

    pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels) # for SparsityLevelParameterPruner
    # pruner = distiller.pruning.SensitivityPruner(name='sensitivity', sensitivities=sparsity_levels) # for SensitivityPruner
    policy = distiller.PruningPolicy(pruner, pruner_args=None)
    lrpolicy = distiller.LRPolicy(torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1))
    compression_scheduler = distiller.CompressionScheduler(model)
    compression_scheduler.add_policy(policy, epochs=[PrunerEpoch])
    # compression_scheduler.add_policy(policy, starting_epoch=0, ending_epoch=38, frequency=2)
    compression_scheduler.add_policy(lrpolicy, starting_epoch=0, ending_epoch=50, frequency=1)
    """
    distiller/example/classifier_compression/compress_classifier.py
    For each epoch:
        compression_scheduler.on_epoch_begin(epoch)
        train()
        save_checkpoint()
        compression_scheduler.on_epoch_end(epoch)

    train():
        For each training step:
            compression_scheduler.on_minibatch_begin(epoch)
            output = model(input)
            loss = criterion(output, target)
            compression_scheduler.before_backward_pass(epoch)
            loss.backward()
            optimizer.step()
            compression_scheduler.on_minibatch_end(epoch)
    """
    
    local_min_score = 2.
    for i in range(args.epochs):
        compression_scheduler.on_epoch_begin(i)
        train_accuracy = train(i, criterion, optimizer, compression_scheduler)
        val_accuracy = validate() # Validate hyperparameter setting
        t, sparsity = distiller.weights_sparsity_tbl_summary(model, return_total_sparsity=True)
        compression_scheduler.on_epoch_end(i, optimizer)
        apputils.save_checkpoint(i, args.arch, model, optimizer, compression_scheduler, train_accuracy, False,
                                         'hyperopt', './')
        print('Epoch: {}, train_acc: {:.4f}, val_acc: {:.4f}, sparsity: {:.4f}'.format(i, train_accuracy, val_accuracy, sparsity))
        
        score = (1 - (val_accuracy / 100.)) + (alpha * (1 - sparsity / 100.))  # objective function
        if score < global_min_score:
            global_min_score = score
            apputils.save_checkpoint(i, args.arch, model, optimizer, compression_scheduler, train_accuracy, True, 'best', './')

        if score < local_min_score:
            local_min_score = score

        if PrunerConstraint and i >= PrunerEpoch and (sparsity < Expected_Sparsity_Level_Low or sparsity > Expected_Sparsity_Level_High):
            break 

    test_accuracy = test() # Validate hyperparameter setting

    print('{} trials: score: {:.4f}, train_acc:{:.4f}, val_acc:{:.4f}, test_acc:{:.4f}, sparsity:{:.4f}'.format(count, 
                                      local_min_score, 
                                      train_accuracy, 
                                      val_accuracy, 
                                      test_accuracy,
                                      sparsity))

    return local_min_score
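
The example never shows the code that drives objective(); a hyperopt search over per-layer sparsity levels might look roughly like the sketch below. The weight names in the search space are placeholders (in practice they would come from model.named_parameters()), and the bounds and max_evals are arbitrary choices.

# Hypothetical hyperopt driver for the objective() above.
from hyperopt import fmin, hp, tpe

space = {
    'module.conv1.weight': hp.uniform('module.conv1.weight', 0.0, 0.9),
    'module.fc.weight': hp.uniform('module.fc.weight', 0.0, 0.9),
}

best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=50)
print('best sparsity levels found:', best)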
Example #4
def objective(space):
    global model
    global count
    global best_dict
    
    # Explore a new model
    model = create_model(False, args.dataset, args.arch, device_ids=args.gpus)
    if args.resume:
        model, _, _ = apputils.load_checkpoint(
            model, chkpt_file=args.resume)
    
    count += 1
    print('{} trial starting...'.format(count))
    # Objective: combine validation accuracy and sparsity (see the score computation below)
    accuracy = 0
    # alpha = 0.2  # Hyperparameter: the relative weight of the sparsity term
    alpha = 1.0  # Hyperparameter: the relative weight of the sparsity term
    sparsity = 0.0
    # Training hyperparameters
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    """
    distiller/distiller/config.py
        # Element-wise sparsity
        sparsity_levels = {net_param: sparsity_level}
        pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels)
        policy = distiller.PruningPolicy(pruner, pruner_args=None)
        scheduler = distiller.CompressionScheduler(model)
        scheduler.add_policy(policy, epochs=[0, 2, 4])
        # Local search:
        add multiple pruners, one per layer
    """
    sparsity_levels = {}
    for key, value in space.items():
        sparsity_levels[key] = value
    pruner = distiller.pruning.SparsityLevelParameterPruner(name='sensitivity', levels=sparsity_levels)
    policy = distiller.PruningPolicy(pruner, pruner_args=None)
    lrpolicy = distiller.LRPolicy(torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1))
    compression_scheduler = distiller.CompressionScheduler(model)
    #compression_scheduler.add_policy(policy, epochs=[90])
    compression_scheduler.add_policy(policy, epochs=[0])
    compression_scheduler.add_policy(lrpolicy, starting_epoch=0, ending_epoch=90, frequency=1)
    """
    distiller/example/classifier_compression/compress_classifier.py
    For each epoch:
        compression_scheduler.on_epoch_begin(epoch)
        train()
        save_checkpoint()
        compression_scheduler.on_epoch_end(epoch)

    train():
        For each training step:
            compression_scheduler.on_minibatch_begin(epoch)
            output = model(input)
            loss = criterion(output, target)
            compression_scheduler.before_backward_pass(epoch)
            loss.backward()
            optimizer.step()
            compression_scheduler.on_minibatch_end(epoch)
    """
    for i in range(args.epochs):
        compression_scheduler.on_epoch_begin(i)
        train_accuracy = train(i, criterion, optimizer, compression_scheduler)
        val_accuracy = validate() # Validate hyperparameter setting
        t, sparsity = distiller.weights_sparsity_tbl_summary(model, return_total_sparsity=True)
        compression_scheduler.on_epoch_end(i, optimizer)
        apputils.save_checkpoint(i, args.arch, model, optimizer, compression_scheduler, train_accuracy, False,
                                         'hyperopt', './')
        print('{} epochs => train acc:{:.2f}%,  val acc:{:.2f}%'.format(i, train_accuracy, val_accuracy))
        
    test_accuracy = validate(test_loader)  # Evaluate on the held-out test set
    # score = (1-(val_accuracy/100.)) + (alpha * (1-sparsity/100.))  # alternative objective function

    # Objective function (lower is better);
    # expected ranges: accuracy 90~98%, sparsity 50~80%
    score = -((val_accuracy/100.)**2-0.9**2 + alpha * ((sparsity/100.)**2-0.5**2)) 
    print('{} trials: score: {:.2f}\ttrain acc:{:.2f}%\tval acc:{:.2f}%\ttest acc:{:.2f}%\tsparsity:{:.2f}%'.format(count, 
                                      score, 
                                      train_accuracy, 
                                      val_accuracy, 
                                      test_accuracy,
                                      sparsity))
    if score < best_dict['score']:
        best_dict['trial'] = count
        best_dict['score'] = score
        best_dict['tr_acc'] = train_accuracy        
        best_dict['v_acc'] = val_accuracy
        best_dict['te_acc'] = test_accuracy
        best_dict['sparsity'] = sparsity
        best_dict['model_best'] = copy.deepcopy(model)

    return score
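
As a quick check of the sign convention in this last score (lower, i.e. more negative, is better), plugging in hypothetical values val_accuracy = 95 and sparsity = 70 with alpha = 1.0 gives:

# Hypothetical numbers, only to illustrate the score's sign convention:
# accuracy above 90% and sparsity above 50% both push the score below zero.
val_accuracy, sparsity, alpha = 95.0, 70.0, 1.0
score = -((val_accuracy / 100.) ** 2 - 0.9 ** 2 + alpha * ((sparsity / 100.) ** 2 - 0.5 ** 2))
print(round(score, 4))  # -0.3325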