def __init__(self,
             model,
             lr=0.1,
             hook_enabled=True,
             factor_decay=0.95,
             damping=0.001,
             kl_clip=0.001,
             fac_update_freq=10,
             kfac_update_freq=100,
             batch_averaged=True,
             diag_blocks=1,
             diag_warmup=0,
             distribute_layer_factors=None,
             sparse=False,
             sparse_ratio=0.01,
             exclude_parts=''):
    """Validate the hyper-parameters and set up the K-FAC state for `model`.

    Args:
        model: object exposing ``parameters()``; it is also handed to
            ``self._register_modules``.
        lr: learning rate, must be ``>= 0``.
        hook_enabled: stored on ``self.hook_enabled`` before the modules
            are registered.
        factor_decay: must satisfy ``0 < factor_decay <= 1``.
        damping: must be ``> 0``.
        kl_clip: must be ``> 0``.
        fac_update_freq: must be ``> 0``.
        kfac_update_freq: must be ``> 0``; a warning is printed when it is
            not a multiple of `fac_update_freq`.
        batch_averaged: stored as-is on ``self.batch_averaged``.
        diag_blocks: must be ``> 0``; a warning is printed when it is not 1.
        diag_warmup: must be ``>= 0``.
        distribute_layer_factors: when ``None`` it is derived as
            ``hvd.size() > len(self.modules)``; otherwise stored as given.
        sparse: stored as-is on ``self.sparse``.
        sparse_ratio: stored as-is on ``self.sparse_ratio``.
        exclude_parts: string searched for the substrings
            ``'CommunicateInverse'``, ``'ComputeInverse'``,
            ``'CommunicateFactor'`` and ``'ComputeFactor'``; each one found
            sets the matching ``self.exclude_*`` flag, e.g.
            ``'CommunicateInverse,ComputeInverse,CommunicateFactor,ComputeFactor'``
            sets all four.

    Raises:
        ValueError: if any of the range checks above fails.
    """
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 < factor_decay <= 1:
        raise ValueError(
            "Invalid factor decay rate: {}".format(factor_decay))
    if not 0.0 < damping:
        raise ValueError("Invalid damping: {}".format(damping))
    if not 0.0 < kl_clip:
        raise ValueError("Invalid clipping value: {}".format(kl_clip))
    if not 0 < fac_update_freq:
        raise ValueError(
            "Invalid factor update frequency: {}".format(fac_update_freq))
    if not 0 < kfac_update_freq:
        raise ValueError(
            "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
    if not 0 == kfac_update_freq % fac_update_freq:
        print(
            "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
        )
    if not 0 < diag_blocks:
        raise ValueError(
            "Invalid diagonal block approx count: {}".format(diag_blocks))
    # This slot used to re-check `diag_blocks` (unreachable after the check
    # above), which left `diag_warmup` without any validation.
    if not 0 <= diag_warmup:
        raise ValueError(
            "Invalid diagonal block approx warmup: {}".format(diag_warmup))
    if not 1 == diag_blocks:
        print(
            "WARNING: diag_blocks > 1 is experimental and may give poor results."
        )

    # For compatibility with `KFACParamScheduler`
    defaults = dict(lr=lr,
                    damping=damping,
                    fac_update_freq=fac_update_freq,
                    kfac_update_freq=kfac_update_freq)

    super(KFAC, self).__init__(model.parameters(), defaults)

    self.computeA = ComputeA()
    self.computeG = ComputeG()
    self.known_modules = {'Linear', 'Conv2d'}
    self.modules = []
    self.module_names = []

    # register hooks for known modules; `hook_enabled` is assigned first so
    # it is already available while `_register_modules` runs
    self.hook_enabled = hook_enabled
    self._register_modules(model)

    # tcmm communicator
    self.communicator = tcmm.Communicator(hvd.rank(), hvd.size(), 1)

    self.steps = 0

    # Dictionaries keyed by `module` storing the factors and inverse factors
    self.m_a, self.m_g = {}, {}
    self.m_A, self.m_G = {}, {}
    self.m_inv_A, self.m_inv_G = {}, {}
    self.module_ranks = None

    self.sparse = sparse
    self.sparse_ratio = sparse_ratio
    self.residualsA, self.residualsG = {}, {}

    self.factor_decay = factor_decay
    self.kl_clip = kl_clip
    self.fac_update_freq = fac_update_freq
    self.kfac_update_freq = kfac_update_freq
    self.diag_blocks = diag_blocks
    self.diag_warmup = diag_warmup
    self.batch_averaged = batch_averaged

    # Each flag is True when the corresponding name occurs in `exclude_parts`
    self.exclude_communicate_inverse = 'CommunicateInverse' in exclude_parts
    self.exclude_compute_inverse = 'ComputeInverse' in exclude_parts
    self.exclude_communicate_factor = 'CommunicateFactor' in exclude_parts
    self.exclude_compute_factor = 'ComputeFactor' in exclude_parts

    # Compute ideal value for `distribute_layer_factors` based on
    # registered module count
    if distribute_layer_factors is None:
        self.distribute_layer_factors = hvd.size() > len(self.modules)
    else:
        self.distribute_layer_factors = distribute_layer_factors

    self.eps = 1e-10  # for numerical stability
    self.rank_iter = cycle(list(range(hvd.size())))
def __init__(self,
             model,
             lr=0.1,
             factor_decay=0.95,
             damping=0.001,
             kl_clip=0.001,
             fac_update_freq=10,
             kfac_update_freq=100,
             batch_averaged=True,
             diag_blocks=1,
             diag_warmup=0,
             distribute_layer_factors=None,
             gradient_clip="agc"):
    """Validate the hyper-parameters and set up the K-FAC state for `model`.

    Args:
        model: object exposing ``parameters()``; it is also handed to
            ``self._register_modules``.
        lr: learning rate, must be ``>= 0``.
        factor_decay: must satisfy ``0 <= factor_decay <= 1``.
        damping: must be ``> 0``.
        kl_clip: must be ``> 0``.
        fac_update_freq: must be ``> 0``.
        kfac_update_freq: must be ``> 0``; a warning is printed when it is
            not a multiple of `fac_update_freq`.
        batch_averaged: stored as-is on ``self.batch_averaged``.
        diag_blocks: must be ``> 0``; a warning is printed when it is not 1.
        diag_warmup: must be ``>= 0``; ``self.have_cleared_Q`` starts as
            ``True`` only when it equals 0.
        distribute_layer_factors: when ``None`` it is derived as
            ``hvd.size() > len(self.modules)``; otherwise stored as given.
        gradient_clip: stored both in the optimizer defaults and on
            ``self.gradient_clip``.

    Raises:
        ValueError: if any of the range checks above fails.
    """
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    # NOTE(review): this lower bound is inclusive (0 is accepted) —
    # confirm that is intended.
    if not 0.0 <= factor_decay <= 1:
        raise ValueError(
            "Invalid factor decay rate: {}".format(factor_decay))
    if not 0.0 < damping:
        raise ValueError("Invalid damping: {}".format(damping))
    if not 0.0 < kl_clip:
        raise ValueError("Invalid clipping value: {}".format(kl_clip))
    if not 0 < fac_update_freq:
        raise ValueError(
            "Invalid factor update frequency: {}".format(fac_update_freq))
    if not 0 < kfac_update_freq:
        raise ValueError(
            "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
    if not 0 == kfac_update_freq % fac_update_freq:
        print(
            "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
        )
    if not 0 < diag_blocks:
        raise ValueError(
            "Invalid diagonal block approx count: {}".format(diag_blocks))
    # This slot used to re-check `diag_blocks` (unreachable after the check
    # above), which left `diag_warmup` without any validation.
    if not 0 <= diag_warmup:
        raise ValueError(
            "Invalid diagonal block approx warmup: {}".format(diag_warmup))
    if not 1 == diag_blocks:
        print(
            "WARNING: diag_blocks > 1 is experimental and may give poor results."
        )

    # For compatibility with `KFACParamScheduler`
    # defaults – (dict): a dict containing default values of optimization
    # options (used when a parameter group doesn’t specify them).
    defaults = dict(lr=lr,
                    damping=damping,
                    fac_update_freq=fac_update_freq,
                    kfac_update_freq=kfac_update_freq,
                    gradient_clip=gradient_clip)

    super(KFAC, self).__init__(model.parameters(), defaults)

    self.computeA = ComputeA()
    self.computeG = ComputeG()
    self.known_modules = {'Linear', 'Conv2d', 'BertLayerNorm0'}
    self.modules = []
    self._register_modules(model)

    self.steps = 0
    self.gradient_clip = gradient_clip

    # Dictionaries keyed by `module` storing the factors and
    # eigendecompositions
    self.m_a, self.m_g = {}, {}
    self.m_A, self.m_G = {}, {}
    self.m_QA, self.m_QG = {}, {}
    self.m_dA, self.m_dG = {}, {}

    self.factor_decay = factor_decay
    self.kl_clip = kl_clip
    self.fac_update_freq = fac_update_freq
    self.kfac_update_freq = kfac_update_freq
    self.diag_blocks = diag_blocks
    self.diag_warmup = diag_warmup
    self.batch_averaged = batch_averaged

    # World size is pinned to 1 here instead of being read from hvd.size()
    self.hvd_size = 1

    # Compute ideal value for `distribute_layer_factors` based on
    # registered module count
    # NOTE(review): this still queries hvd.size() although `self.hvd_size`
    # is hard-coded above — confirm which of the two is intended.
    if distribute_layer_factors is None:
        self.distribute_layer_factors = hvd.size() > len(self.modules)
    else:
        self.distribute_layer_factors = distribute_layer_factors

    self.have_cleared_Q = self.diag_warmup == 0
    self.eps = 1e-10  # for numerical stability
    self.rank_iter = cycle(list(range(self.hvd_size)))
    self.T_all = 0
def __init__(self,
             model,
             lr=0.1,
             factor_decay=0.95,
             damping=0.001,
             kl_clip=0.001,
             fac_update_freq=10,
             kfac_update_freq=100,
             batch_averaged=True,
             diag_blocks=1,
             diag_warmup=0,
             distribute_layer_factors=None,
             sparse=False,
             sparse_ratio=0.01,
             exclude_parts=''):
    """Validate the hyper-parameters and set up the K-FAC state for `model`.

    Args:
        model: object exposing ``parameters()``; it is also handed to
            ``self._register_modules``.
        lr: learning rate, must be ``>= 0``.
        factor_decay: must satisfy ``0 < factor_decay <= 1``.
        damping: must be ``> 0``.
        kl_clip: must be ``> 0``.
        fac_update_freq: must be ``> 0``.
        kfac_update_freq: must be ``> 0``; a warning is printed when it is
            not a multiple of `fac_update_freq`.
        batch_averaged: stored as-is on ``self.batch_averaged``.
        diag_blocks: must be ``> 0``; a warning is printed when it is not 1.
        diag_warmup: must be ``>= 0``; ``self.have_cleared_Q`` starts as
            ``True`` only when it equals 0.
        distribute_layer_factors: when ``None`` it is derived as
            ``hvd.size() > len(self.modules)``; otherwise stored as given.
        sparse: stored as-is on ``self.sparse``.
        sparse_ratio: stored as-is on ``self.sparse_ratio``.
        exclude_parts: accepted but not read anywhere in this constructor.

    Raises:
        ValueError: if any of the range checks above fails.
    """
    if not 0.0 <= lr:
        raise ValueError("Invalid learning rate: {}".format(lr))
    if not 0.0 < factor_decay <= 1:
        raise ValueError(
            "Invalid factor decay rate: {}".format(factor_decay))
    if not 0.0 < damping:
        raise ValueError("Invalid damping: {}".format(damping))
    if not 0.0 < kl_clip:
        raise ValueError("Invalid clipping value: {}".format(kl_clip))
    if not 0 < fac_update_freq:
        raise ValueError(
            "Invalid factor update frequency: {}".format(fac_update_freq))
    if not 0 < kfac_update_freq:
        raise ValueError(
            "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
    if not 0 == kfac_update_freq % fac_update_freq:
        print(
            "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
        )
    if not 0 < diag_blocks:
        raise ValueError(
            "Invalid diagonal block approx count: {}".format(diag_blocks))
    # This slot used to re-check `diag_blocks` (unreachable after the check
    # above), which left `diag_warmup` without any validation.
    if not 0 <= diag_warmup:
        raise ValueError(
            "Invalid diagonal block approx warmup: {}".format(diag_warmup))
    if not 1 == diag_blocks:
        print(
            "WARNING: diag_blocks > 1 is experimental and may give poor results."
        )

    # For compatibility with `KFACParamScheduler`
    defaults = dict(lr=lr,
                    damping=damping,
                    fac_update_freq=fac_update_freq,
                    kfac_update_freq=kfac_update_freq)

    super(KFAC, self).__init__(model.parameters(), defaults)

    self.computeA = ComputeA()
    self.computeG = ComputeG()
    self.known_modules = {'Linear', 'Conv2d'}
    self.modules = []
    self.module_names = []
    self.name_module_map = {}
    self.module_name_map = {}
    self._register_modules(model)

    # One merged all-reduce helper per direction; both are built from the
    # module names collected on `self.module_names`.
    self.fw_merged_comm = MergedCommAllReduce(self.module_names,
                                              prefix='forward',
                                              merge=True,
                                              single_layer=False)
    self.bw_merged_comm = MergedCommAllReduce(self.module_names,
                                              prefix='backward',
                                              merge=True,
                                              single_layer=False)

    self.steps = 0

    # Dictionaries keyed by `module` storing the factors and
    # eigendecompositions
    self.m_a, self.m_g = {}, {}
    self.m_A, self.m_G = {}, {}
    self.m_QA, self.m_QG = {}, {}
    self.m_dA, self.m_dG = {}, {}
    self.m_dA_ranks = {}
    self.m_dG_ranks = {}
    self.module_ranks = None

    self.sparse = sparse
    self.sparse_ratio = sparse_ratio
    self.residualsA, self.residualsG = {}, {}

    self.factor_decay = factor_decay
    self.kl_clip = kl_clip
    self.fac_update_freq = fac_update_freq
    self.kfac_update_freq = kfac_update_freq
    self.diag_blocks = diag_blocks
    self.diag_warmup = diag_warmup
    self.batch_averaged = batch_averaged

    # Compute ideal value for `distribute_layer_factors` based on
    # registered module count
    if distribute_layer_factors is None:
        self.distribute_layer_factors = hvd.size() > len(self.modules)
    else:
        self.distribute_layer_factors = distribute_layer_factors

    self.have_cleared_Q = self.diag_warmup == 0
    self.eps = 1e-10  # for numerical stability
    self.rank_iter = cycle(list(range(hvd.size())))