Example #1
    def __init__(self,
                 model,
                 lr=0.1,
                 hook_enabled=True,
                 factor_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 fac_update_freq=10,
                 kfac_update_freq=100,
                 batch_averaged=True,
                 diag_blocks=1,
                 diag_warmup=0,
                 distribute_layer_factors=None,
                 sparse=False,
                 sparse_ratio=0.01,
                 exclude_parts=''):
        #exclude_parts='CommunicateInverse,ComputeInverse,CommunicateFactor,ComputeFactor'):
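        """Distributed K-FAC gradient preconditioner (Horovod variant).

        Argument semantics below are inferred from the validation checks and
        assignments in this constructor, not from upstream documentation:

        Args:
            model: torch.nn.Module whose Linear/Conv2d layers are preconditioned.
            lr: learning rate (>= 0).
            hook_enabled: whether the factor-capturing module hooks are active.
            factor_decay: running-average decay for the factors, in (0, 1].
            damping: damping added before inversion (> 0).
            kl_clip: clipping bound on the preconditioned update (> 0).
            fac_update_freq: iterations between factor (covariance) updates.
            kfac_update_freq: iterations between inverse/preconditioner updates;
                should be a multiple of fac_update_freq.
            batch_averaged: whether gradients are averaged over the batch.
            diag_blocks: number of blocks in the block-diagonal approximation.
            diag_warmup: warmup period before the block-diagonal approximation.
            distribute_layer_factors: distribute per-layer factor work across
                workers; auto-chosen from hvd.size() when None.
            sparse, sparse_ratio: sparsify factor communication, with residuals.
            exclude_parts: comma-separated pipeline stages to skip, e.g. any of
                'CommunicateInverse,ComputeInverse,CommunicateFactor,ComputeFactor'.
        """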

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < factor_decay <= 1:
            raise ValueError(
                "Invalid factor decay rate: {}".format(factor_decay))
        if not 0.0 < damping:
            raise ValueError("Invalid damping: {}".format(damping))
        if not 0.0 < kl_clip:
            raise ValueError("Invalid clipping value: {}".format(kl_clip))
        if not 0 < fac_update_freq:
            raise ValueError(
                "Invalid factor update frequency: {}".format(fac_update_freq))
        if not 0 < kfac_update_freq:
            raise ValueError(
                "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
        if kfac_update_freq % fac_update_freq != 0:
            print(
                "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
            )
        if diag_blocks <= 0:
            raise ValueError(
                "Invalid diagonal block approx count: {}".format(diag_blocks))
        if diag_blocks != 1:
            print(
                "WARNING: diag_blocks > 1 is experimental and may give poor results."
            )

        # For compatibility with `KFACParamScheduler`
        defaults = dict(lr=lr,
                        damping=damping,
                        fac_update_freq=fac_update_freq,
                        kfac_update_freq=kfac_update_freq)

        super(KFAC, self).__init__(model.parameters(), defaults)

        self.computeA = ComputeA()
        self.computeG = ComputeG()
        self.known_modules = {'Linear', 'Conv2d'}
        self.modules = []
        self.module_names = []
        # register hooks for known modules
        self.hook_enabled = hook_enabled
        self._register_modules(model)

        # tcmm communicator
        self.communicator = tcmm.Communicator(hvd.rank(), hvd.size(), 1)

        self.steps = 0

        # Dictionaries keyed by `module` storing the factors and inverse factors
        self.m_a, self.m_g = {}, {}
        self.m_A, self.m_G = {}, {}
        self.m_inv_A, self.m_inv_G = {}, {}
        self.module_ranks = None

        self.sparse = sparse
        self.sparse_ratio = sparse_ratio
        self.residualsA, self.residualsG = {}, {}

        self.factor_decay = factor_decay
        self.kl_clip = kl_clip
        self.fac_update_freq = fac_update_freq
        self.kfac_update_freq = kfac_update_freq
        self.diag_blocks = diag_blocks
        self.diag_warmup = diag_warmup
        self.batch_averaged = batch_averaged

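        # Flags parsed from `exclude_parts`: each one skips a stage of the
        # K-FAC pipeline (factor compute/communication, inverse
        # compute/communication), presumably for ablation and profiling.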
        self.exclude_communicate_inverse = 'CommunicateInverse' in exclude_parts
        self.exclude_compute_inverse = 'ComputeInverse' in exclude_parts
        self.exclude_communicate_factor = 'CommunicateFactor' in exclude_parts
        self.exclude_compute_factor = 'ComputeFactor' in exclude_parts

        # Compute ideal value for `distribute_layer_factors` based on
        # registered module count
        if distribute_layer_factors is None:
            self.distribute_layer_factors = hvd.size() > len(self.modules)
        else:
            self.distribute_layer_factors = distribute_layer_factors

        self.eps = 1e-10  # for numerical stability
        self.rank_iter = cycle(list(range(hvd.size())))
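
A minimal usage sketch for this variant, assuming Horovod and the tcmm extension are importable and that the class exposes a step() method that preconditions gradients in place (typical for K-FAC preconditioners, but not shown in this excerpt):

import torch
import torch.nn.functional as F
import horovod.torch as hvd

hvd.init()  # required: the constructor calls hvd.rank() and hvd.size()

model = torch.nn.Sequential(torch.nn.Linear(784, 256),
                            torch.nn.ReLU(),
                            torch.nn.Linear(256, 10))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFAC(model, lr=0.1, fac_update_freq=10, kfac_update_freq=100)

# Synthetic data stands in for a real loader.
loader = [(torch.randn(64, 784), torch.randint(0, 10, (64,)))
          for _ in range(20)]
for x, y in loader:
    optimizer.zero_grad()
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    preconditioner.step()  # assumed: rewrites .grad with preconditioned values
    optimizer.step()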
Example #2
    def __init__(self,
                 model,
                 lr=0.1,
                 factor_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 fac_update_freq=10,
                 kfac_update_freq=100,
                 batch_averaged=True,
                 diag_blocks=1,
                 diag_warmup=0,
                 distribute_layer_factors=None,
                 gradient_clip="agc"):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= factor_decay <= 1:
            raise ValueError(
                "Invalid factor decay rate: {}".format(factor_decay))
        if not 0.0 < damping:
            raise ValueError("Invalid damping: {}".format(damping))
        if not 0.0 < kl_clip:
            raise ValueError("Invalid clipping value: {}".format(kl_clip))
        if not 0 < fac_update_freq:
            raise ValueError(
                "Invalid factor update frequency: {}".format(fac_update_freq))
        if not 0 < kfac_update_freq:
            raise ValueError(
                "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
        if kfac_update_freq % fac_update_freq != 0:
            print(
                "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
            )
        if diag_blocks <= 0:
            raise ValueError(
                "Invalid diagonal block approx count: {}".format(diag_blocks))
        if diag_blocks != 1:
            print(
                "WARNING: diag_blocks > 1 is experimental and may give poor results."
            )

        # For compatibility with `KFACParamScheduler`
        #   defaults (dict): default values of optimization options, used when
        #   a parameter group does not specify them.
        defaults = dict(lr=lr,
                        damping=damping,
                        fac_update_freq=fac_update_freq,
                        kfac_update_freq=kfac_update_freq,
                        gradient_clip=gradient_clip)

        super(KFAC, self).__init__(model.parameters(), defaults)

        self.computeA = ComputeA()
        self.computeG = ComputeG()
        self.known_modules = {'Linear', 'Conv2d', 'BertLayerNorm0'}
        self.modules = []
        self._register_modules(model)

        self.steps = 0
        self.gradient_clip = gradient_clip

        # Dictionaries keyed by `module` storing the factors and
        # eigendecompositions
        self.m_a, self.m_g = {}, {}
        self.m_A, self.m_G = {}, {}
        self.m_QA, self.m_QG = {}, {}
        self.m_dA, self.m_dG = {}, {}

        self.factor_decay = factor_decay
        self.kl_clip = kl_clip
        self.fac_update_freq = fac_update_freq
        self.kfac_update_freq = kfac_update_freq
        self.diag_blocks = diag_blocks
        self.diag_warmup = diag_warmup
        self.batch_averaged = batch_averaged
        self.hvd_size = 1  # single-process variant; use hvd.size() under Horovod

        # Compute ideal value for `distribute_layer_factors` based on
        # registered module count
        if distribute_layer_factors is None:
            self.distribute_layer_factors = hvd.size() > len(self.modules)
        else:
            self.distribute_layer_factors = distribute_layer_factors

        self.have_cleared_Q = (self.diag_warmup == 0)
        self.eps = 1e-10  # for numerical stability
        self.rank_iter = cycle(list(range(self.hvd_size)))
        self.T_all = 0
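
This variant hard-codes hvd_size = 1, so it is meant for single-process runs; note that the distribute_layer_factors heuristic still calls hvd.size(), so passing that flag explicitly avoids needing an initialized Horovod context. A minimal construction sketch under those assumptions (the effect of gradient_clip="agc" lives in code not shown in this excerpt):

import torch

model = torch.nn.Linear(32, 8)  # toy stand-in; BERT-style layers are the target
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
preconditioner = KFAC(model, lr=0.1, gradient_clip="agc",
                      distribute_layer_factors=False)  # skip the hvd.size() call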
Example #3
    def __init__(self,
                 model,
                 lr=0.1,
                 factor_decay=0.95,
                 damping=0.001,
                 kl_clip=0.001,
                 fac_update_freq=10,
                 kfac_update_freq=100,
                 batch_averaged=True,
                 diag_blocks=1,
                 diag_warmup=0,
                 distribute_layer_factors=None,
                 sparse=False,
                 sparse_ratio=0.01,
                 exclude_parts=''):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < factor_decay <= 1:
            raise ValueError(
                "Invalid factor decay rate: {}".format(factor_decay))
        if not 0.0 < damping:
            raise ValueError("Invalid damping: {}".format(damping))
        if not 0.0 < kl_clip:
            raise ValueError("Invalid clipping value: {}".format(kl_clip))
        if not 0 < fac_update_freq:
            raise ValueError(
                "Invalid factor update frequency: {}".format(fac_update_freq))
        if not 0 < kfac_update_freq:
            raise ValueError(
                "Invalid K-FAC update frequency: {}".format(kfac_update_freq))
        if kfac_update_freq % fac_update_freq != 0:
            print(
                "WARNING: it is suggested that kfac_update_freq be a multiple of fac_update_freq"
            )
        if diag_blocks <= 0:
            raise ValueError(
                "Invalid diagonal block approx count: {}".format(diag_blocks))
        if diag_blocks != 1:
            print(
                "WARNING: diag_blocks > 1 is experimental and may give poor results."
            )

        # For compatibility with `KFACParamScheduler`
        defaults = dict(lr=lr,
                        damping=damping,
                        fac_update_freq=fac_update_freq,
                        kfac_update_freq=kfac_update_freq)

        super(KFAC, self).__init__(model.parameters(), defaults)

        self.computeA = ComputeA()
        self.computeG = ComputeG()
        self.known_modules = {'Linear', 'Conv2d'}
        self.modules = []
        self.module_names = []
        self.name_module_map = {}
        self.module_name_map = {}
        #self.fw_factor_handles = []
        #self.bw_factor_handles = []
        self._register_modules(model)
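        # MergedCommAllReduce (assumed helper): batches the per-layer
        # all-reduces of the forward (A) and backward (G) factors into merged
        # messages, presumably to amortize communication latency across layers.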
        self.fw_merged_comm = MergedCommAllReduce(self.module_names,
                                                  prefix='forward',
                                                  merge=True,
                                                  single_layer=False)
        self.bw_merged_comm = MergedCommAllReduce(self.module_names,
                                                  prefix='backward',
                                                  merge=True,
                                                  single_layer=False)

        self.steps = 0

        # Dictionaries keyed by `module` storing the factors and
        # eigendecompositions
        self.m_a, self.m_g = {}, {}
        self.m_A, self.m_G = {}, {}
        self.m_QA, self.m_QG = {}, {}
        self.m_dA, self.m_dG = {}, {}
        self.m_dA_ranks = {}
        self.m_dG_ranks = {}
        self.module_ranks = None

        self.sparse = sparse
        self.sparse_ratio = sparse_ratio
        self.residualsA, self.residualsG = {}, {}

        self.factor_decay = factor_decay
        self.kl_clip = kl_clip
        self.fac_update_freq = fac_update_freq
        self.kfac_update_freq = kfac_update_freq
        self.diag_blocks = diag_blocks
        self.diag_warmup = diag_warmup
        self.batch_averaged = batch_averaged

        # Compute ideal value for `distribute_layer_factors` based on
        # registered module count
        if distribute_layer_factors is None:
            self.distribute_layer_factors = hvd.size() > len(self.modules)
        else:
            self.distribute_layer_factors = distribute_layer_factors

        self.have_cleared_Q = (self.diag_warmup == 0)
        self.eps = 1e-10  # for numerical stability
        self.rank_iter = cycle(list(range(hvd.size())))
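
The rank_iter cycle suggests per-layer work is handed out round-robin across workers; a toy illustration of that assignment pattern, using only itertools (world_size and the layer names here are hypothetical stand-ins for hvd.size() and the registered modules):

from itertools import cycle

world_size = 4  # stand-in for hvd.size()
rank_iter = cycle(range(world_size))
layers = ['conv1', 'conv2', 'fc1', 'fc2', 'fc3']

# Each layer's factor work lands on the next rank in the cycle,
# wrapping around once every rank has been assigned a layer.
assignment = {name: next(rank_iter) for name in layers}
print(assignment)  # {'conv1': 0, 'conv2': 1, 'fc1': 2, 'fc2': 3, 'fc3': 0}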