def dict_config(model, optimizer, sched_dict, scheduler=None):
    """Populate a compression scheduler from a schedule dictionary.

    Builds the pruners, regularizers, quantizers, extensions and LR schedulers
    described by ``sched_dict`` via ``__factory``, then registers one policy
    per entry of ``sched_dict['policies']`` through
    ``add_policy_to_scheduler``.

    Args:
        model: Model handed to ``__factory`` and, when ``scheduler`` is None,
            to ``distiller.CompressionScheduler``.
        optimizer: Optimizer handed to the quantizer and LR-scheduler
            factories.
        sched_dict: Schedule definition. It is dumped as JSON to the debug
            log and must provide a ``'policies'`` iterable.
        scheduler: Existing scheduler to add the policies to. When None, a
            new ``distiller.CompressionScheduler(model)`` is created.

    Returns:
        The scheduler the policies were added to.

    Raises:
        ValueError: If more than one quantizer is defined, or a policy entry
            contains none of the recognized policy keys.
        AssertionError: If a policy names an instance that is not present in
            the corresponding factory result.
        Exception: Any other failure is reported on stdout and re-raised
            unchanged.
    """
    app_cfg_logger.debug('Schedule contents:\n' + json.dumps(sched_dict, indent=2))

    if scheduler is None:
        scheduler = distiller.CompressionScheduler(model)

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    quantizers = __factory('quantizers', model, sched_dict, optimizer=optimizer)
    if len(quantizers) > 1:
        raise ValueError("\nError: Multiple Quantizers not supported")
    extensions = __factory('extensions', model, sched_dict)

    # Bound up front so the generic error handler below can always report it,
    # even when the failure happens before the first loop iteration (for
    # example a missing 'policies' key).
    policy_def = None
    try:
        lr_policies = []
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print('\n\nFatal Error: a policy is defined with a null pruner')
                    print('Here\'s the policy definition for your reference:\n{}'
                          .format(json.dumps(policy_def, indent=1)))
                    raise
                assert instance_name in pruners, \
                    "Pruner {} was not defined in the list of pruners".format(instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def, 'regularizer')
                assert instance_name in regularizers, \
                    "Regularizer {} was not defined in the list of regularizers".format(instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(regularizer, **args)

            elif 'quantizer' in policy_def:
                instance_name, _ = __policy_params(policy_def, 'quantizer')
                assert instance_name in quantizers, \
                    "Quantizer {} was not defined in the list of quantizers".format(instance_name)
                quantizer = quantizers[instance_name]
                policy = distiller.QuantizationPolicy(quantizer)

            elif 'lr_scheduler' in policy_def:
                # LR schedulers take an optimizer in their constructor, so
                # postpone handling until we're certain a quantization policy
                # was initialized (if one exists).
                lr_policies.append(policy_def)
                continue

            elif 'extension' in policy_def:
                instance_name, _ = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, \
                    "Extension {} was not defined in the list of extensions".format(instance_name)
                # An extension instance is registered as the policy itself.
                policy = extensions[instance_name]

            else:
                # The placeholder must be '{}' for str.format to interpolate
                # the offending policy definition into the message.
                raise ValueError(
                    "\nFATAL Parsing error while parsing the pruning schedule - unknown policy [{}]"
                    .format(policy_def))

            add_policy_to_scheduler(policy, policy_def, scheduler)

        # Any changes to the optimizer caused by a quantizer have occurred by
        # now, so it is safe to create the LR schedulers.
        lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
        for policy_def in lr_policies:
            instance_name, _ = __policy_params(policy_def, 'lr_scheduler')
            assert instance_name in lr_schedulers, \
                "LR-scheduler {} was not defined in the list of lr-schedulers".format(instance_name)
            lr_scheduler = lr_schedulers[instance_name]
            policy = distiller.LRPolicy(lr_scheduler)
            add_policy_to_scheduler(policy, policy_def, scheduler)

    except AssertionError:
        # Propagate the assertion information untouched.
        raise
    except Exception as exception:
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        raise
    return scheduler
def dictConfig(model, optimizer, schedule, sched_dict, logger):
    """Add the policies described by a schedule dictionary to ``schedule``.

    Builds the pruners, regularizers, LR schedulers and extensions described
    by ``sched_dict`` via ``__factory``, then calls ``schedule.add_policy``
    once per entry of ``sched_dict['policies']``.

    Args:
        model: Model handed to ``__factory``.
        optimizer: Optimizer handed to the LR-scheduler factory.
        schedule: Object whose ``add_policy`` method receives each policy,
            either with ``epochs=...`` or with ``starting_epoch``,
            ``ending_epoch`` and ``frequency`` keyword arguments.
        sched_dict: Schedule definition. It is dumped as JSON to the debug
            log and must provide a ``'policies'`` iterable.
        logger: Logger used for the debug dump of ``sched_dict``.

    Returns:
        The ``schedule`` object that was passed in.

    Raises:
        AssertionError: If a policy names an instance that is not present in
            the corresponding factory result.
        SystemExit: With status 1 after printing a diagnostic, when a pruner
            policy is null, a policy entry has no recognized key, or any
            other exception occurs while parsing the policies.
    """
    logger.debug(json.dumps(sched_dict, indent=1))

    pruners = __factory('pruners', model, sched_dict)
    regularizers = __factory('regularizers', model, sched_dict)
    lr_schedulers = __factory('lr_schedulers', model, sched_dict, optimizer=optimizer)
    extensions = __factory('extensions', model, sched_dict)

    # Bound up front so the generic error handler below can always report it,
    # even when the failure happens before the first loop iteration (for
    # example a missing 'policies' key).
    policy_def = None
    try:
        for policy_def in sched_dict['policies']:
            policy = None
            if 'pruner' in policy_def:
                try:
                    instance_name, args = __policy_params(policy_def, 'pruner')
                except TypeError:
                    print('\n\nFatal Error: a policy is defined with a null pruner')
                    print('Here\'s the policy definition for your reference:\n{}'
                          .format(json.dumps(policy_def, indent=1)))
                    # SystemExit is raised directly: the exit() helper is
                    # injected by the site module and is not always available.
                    raise SystemExit(1)
                assert instance_name in pruners, \
                    "Pruner {} was not defined in the list of pruners".format(instance_name)
                pruner = pruners[instance_name]
                policy = distiller.PruningPolicy(pruner, args)

            elif 'regularizer' in policy_def:
                instance_name, args = __policy_params(policy_def, 'regularizer')
                assert instance_name in regularizers, \
                    "Regularizer {} was not defined in the list of regularizers".format(instance_name)
                regularizer = regularizers[instance_name]
                if args is None:
                    policy = distiller.RegularizationPolicy(regularizer)
                else:
                    policy = distiller.RegularizationPolicy(regularizer, **args)

            elif 'lr_scheduler' in policy_def:
                instance_name, _ = __policy_params(policy_def, 'lr_scheduler')
                assert instance_name in lr_schedulers, \
                    "LR-scheduler {} was not defined in the list of lr-schedulers".format(instance_name)
                lr_scheduler = lr_schedulers[instance_name]
                policy = distiller.LRPolicy(lr_scheduler)

            elif 'extension' in policy_def:
                instance_name, _ = __policy_params(policy_def, 'extension')
                assert instance_name in extensions, \
                    "Extension {} was not defined in the list of extensions".format(instance_name)
                # An extension instance is registered as the policy itself.
                policy = extensions[instance_name]

            else:
                # Wrap in a 1-tuple so the value is always formatted as a
                # single argument, whatever its type.
                print("\nFATAL Parsing error while parsing the pruning schedule - unknown policy [%s]"
                      % (policy_def,))
                raise SystemExit(1)

            if 'epochs' in policy_def:
                schedule.add_policy(policy, epochs=policy_def['epochs'])
            else:
                schedule.add_policy(policy,
                                    starting_epoch=policy_def['starting_epoch'],
                                    ending_epoch=policy_def['ending_epoch'],
                                    frequency=policy_def['frequency'])

    except AssertionError:
        # Propagate the assertion information untouched.
        raise
    except Exception as exception:
        print("\nFATAL Parsing error!\n%s" % json.dumps(policy_def, indent=1))
        print("Exception: %s %s" % (type(exception), exception))
        raise SystemExit(1)
    return schedule