Exemple #1
0
def dict_config(model, optimizer, sched_dict, scheduler=None):
    """Build policies from a schedule dictionary and register them on a scheduler.

    The component collections ('pruners', 'regularizers', 'quantizers',
    'extensions' and, later, 'lr_schedulers') are created through
    ``__factory``. Each entry of ``sched_dict['policies']`` is then turned into
    a policy object and handed to ``add_policy_to_scheduler``.

    Args:
        model: Passed to ``__factory`` and, when ``scheduler`` is None, to
            ``distiller.CompressionScheduler``.
        optimizer: Forwarded to the 'quantizers' and 'lr_schedulers' factories.
        sched_dict: Schedule definition. Must be JSON-serializable (it is
            dumped to the debug log) and contain a 'policies' iterable whose
            items each hold one of the keys 'pruner', 'regularizer',
            'quantizer', 'lr_scheduler' or 'extension'.
        scheduler: Scheduler to populate; a new
            ``distiller.CompressionScheduler(model)`` is created when None.

    Returns:
        The scheduler that the policies were added to.

    Raises:
        ValueError: If more than one quantizer is defined, or a policy entry
            names none of the known policy kinds.
        AssertionError: If a policy refers to an instance name that is missing
            from the matching component collection.
        Exception: Any other error raised while processing the policies is
            printed together with the current policy definition and re-raised.
    """
    app_cfg_logger.debug('Schedule contents:\n' +
                         json.dumps(sched_dict, indent=2))

    if scheduler is None:
        scheduler = distiller.CompressionScheduler(model)

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    quantizers = __factory('quantizers',
                           model,
                           sched_dict,
                           optimizer=optimizer)
    if len(quantizers) > 1:
        raise ValueError("\nError: Multiple Quantizers not supported")
    extensions = __factory('extensions', model, sched_dict)

    # Bound up-front so the generic handler below can always report it, even
    # when the failure happens before the first loop iteration (e.g. a missing
    # 'policies' key). Otherwise an UnboundLocalError would mask the real error.
    policy_def = None
    try:
        lr_policies = []
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print(
                        '\n\nFatal Error: a policy is defined with a null pruner'
                    )
                    print(
                        'Here\'s the policy definition for your reference:\n{}'
                        .format(json.dumps(policy_def, indent=1)))
                    raise
                assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(
                    instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'regularizer')
                assert instance_name in regularizers, "Regularizer {} was not defined in the list of regularizers".format(
                    instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(
                        regularizer, **args)

            elif 'quantizer' in policy_def:
                instance_name, args = __policy_params(policy_def, 'quantizer')
                assert instance_name in quantizers, "Quantizer {} was not defined in the list of quantizers".format(
                    instance_name)
                quantizer = quantizers[instance_name]
                policy = distiller.QuantizationPolicy(quantizer)

            elif 'lr_scheduler' in policy_def:
                # LR schedulers take an optimizer in their CTOR, so postpone handling until we're certain
                # a quantization policy was initialized (if exists)
                lr_policies.append(policy_def)
                continue

            elif 'extension' in policy_def:
                instance_name, args = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, "Extension {} was not defined in the list of extensions".format(
                    instance_name)
                extension = extensions[instance_name]
                policy = extension

            else:
                # Use a '{}' placeholder: the previous '%s' was never
                # substituted by str.format, so the policy was not shown.
                raise ValueError(
                    "\nFATAL Parsing error while parsing the pruning schedule - unknown policy [{}]"
                    .format(policy_def))

            add_policy_to_scheduler(policy, policy_def, scheduler)

        # Any changes to the optimizer caused by a quantizer have occurred by now, so safe to create LR schedulers
        lr_schedulers = __factory('lr_schedulers',
                                  model,
                                  sched_dict,
                                  optimizer=optimizer)
        for policy_def in lr_policies:
            instance_name, args = __policy_params(policy_def, 'lr_scheduler')
            assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
                instance_name)
            lr_scheduler = lr_schedulers[instance_name]
            policy = distiller.LRPolicy(lr_scheduler)
            add_policy_to_scheduler(policy, policy_def, scheduler)

    except AssertionError:
        # propagate the assertion information
        raise
    except Exception as exception:
        # Show the policy being processed when the failure occurred (prints
        # "null" if the failure happened before any policy was read).
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        raise
    return scheduler
Exemple #2
0
def dictConfig(model, optimizer, schedule, sched_dict, logger):
    """Build policies from a schedule dictionary and add them to ``schedule``.

    The component collections ('pruners', 'regularizers', 'lr_schedulers' and
    'extensions') are created through ``__factory``. Each entry of
    ``sched_dict['policies']`` is then turned into a policy object and
    registered with ``schedule.add_policy``.

    Args:
        model: Passed to every ``__factory`` call.
        optimizer: Forwarded to the 'lr_schedulers' factory.
        schedule: Object exposing ``add_policy``; receives every policy.
        sched_dict: Schedule definition. Must be JSON-serializable (it is
            dumped to the debug log) and contain a 'policies' iterable whose
            items each hold one of the keys 'pruner', 'regularizer',
            'lr_scheduler' or 'extension', plus either 'epochs' or all of
            'starting_epoch', 'ending_epoch' and 'frequency'.
        logger: Logger used to dump ``sched_dict`` at debug level.

    Returns:
        The ``schedule`` object that was passed in.

    Raises:
        AssertionError: If a policy refers to an instance name that is missing
            from the matching component collection.
        SystemExit: With exit code 1, after printing a diagnostic, for an
            unknown policy kind, a pruner policy whose parameters raise
            TypeError, or any other exception while processing the policies.
    """
    logger.debug(json.dumps(sched_dict, indent=1))

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    lr_schedulers = __factory('lr_schedulers',
                              model,
                              sched_dict,
                              optimizer=optimizer)
    extensions = __factory('extensions', model, sched_dict)

    # Bound up-front so the generic handler below can always report it, even
    # when the failure happens before the first loop iteration (e.g. a missing
    # 'policies' key). Otherwise an UnboundLocalError would escape instead of
    # the intended diagnostic + exit.
    policy_def = None
    try:
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print(
                        '\n\nFatal Error: a policy is defined with a null pruner'
                    )
                    print(
                        'Here\'s the policy definition for your reference:\n{}'
                        .format(json.dumps(policy_def, indent=1)))
                    # SystemExit is raised directly rather than through the
                    # site-provided exit() helper, which is intended for the
                    # interactive shell and may be absent (e.g. python -S).
                    raise SystemExit(1)
                assert instance_name in pruners, "Pruner {} was not defined in the list of pruners".format(
                    instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'regularizer')
                assert instance_name in regularizers, "Regularizer {} was not defined in the list of regularizers".format(
                    instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(
                        regularizer, **args)

            elif 'lr_scheduler' in policy_def:
                instance_name, args = __policy_params(policy_def,
                                                      'lr_scheduler')
                assert instance_name in lr_schedulers, "LR-scheduler {} was not defined in the list of lr-schedulers".format(
                    instance_name)
                lr_scheduler = lr_schedulers[instance_name]
                policy = distiller.LRPolicy(lr_scheduler)

            elif 'extension' in policy_def:
                instance_name, args = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, "Extension {} was not defined in the list of extensions".format(
                    instance_name)
                extension = extensions[instance_name]
                policy = extension

            else:
                # Wrap in a 1-tuple so the operand is always formatted as a
                # single value, whatever container type policy_def is.
                print(
                    "\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]"
                    % (policy_def,))
                raise SystemExit(1)

            if 'epochs' in policy_def:
                schedule.add_policy(policy, epochs=policy_def['epochs'])
            else:
                schedule.add_policy(
                    policy,
                    starting_epoch=policy_def['starting_epoch'],
                    ending_epoch=policy_def['ending_epoch'],
                    frequency=policy_def['frequency'])
    except AssertionError:
        # propagate the assertion information
        raise
    except Exception as exception:
        # SystemExit derives from BaseException, so the exits raised above are
        # not intercepted here. Prints "null" for the policy if the failure
        # happened before any policy was read.
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        raise SystemExit(1)

    return schedule