def on_epoch_start(self, state):
    """Record per-group weight decay and, when decoupling, disable it in the optimizer.

    If ``self.decouple_weight_decay`` is set, the current ``weight_decay`` of
    every param group (0.0 when the key is absent) is stored in
    ``self._optimizer_wd``, and the optimizer's own value is set to 0.0.
    Otherwise the optimizer is left untouched and ``self._optimizer_wd`` holds
    one 0.0 per param group.

    Args:
        state: runner state exposing ``get_key(key=..., inner_key=...)``,
            which returns the optimizer selected by ``self.optimizer_key``.
    """
    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)

    if self.decouple_weight_decay:
        # Remember each group's weight decay, then switch it off in the optimizer.
        self._optimizer_wd = [
            group.get("weight_decay", 0.0) for group in optimizer.param_groups
        ]
        for i, _group in enumerate(optimizer.param_groups):
            safitty.set(optimizer.param_groups, i, "weight_decay", value=0.0)
    else:
        # Do NOT zero the optimizer here: the matching on_epoch_end restores the
        # saved values only when decouple_weight_decay is set, so zeroing now
        # would silently drop weight decay for the rest of training.
        self._optimizer_wd = [0.0] * len(optimizer.param_groups)
def on_epoch_end(self, state):
    """Write the saved per-group weight decay back into the optimizer.

    This is a no-op unless ``self.decouple_weight_decay`` is set. The values
    come from ``self._optimizer_wd`` and are written, by position, to the
    ``weight_decay`` entry of ``optimizer.param_groups``.

    Args:
        state: runner state exposing ``get_key(key=..., inner_key=...)``,
            which returns the optimizer selected by ``self.optimizer_key``.
    """
    if not self.decouple_weight_decay:
        return

    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)
    for group_idx, saved_wd in enumerate(self._optimizer_wd):
        safitty.set(
            optimizer.param_groups, group_idx, "weight_decay", value=saved_wd
        )
def test_safe_set(config):
    """Exercise ``safitty.set`` on list indices, nested paths and whole-value replacement."""
    cfg = copy.deepcopy(config)

    # Writing index 8 of the 5-element "numbers" list grows it to 9 elements.
    assert safitty.get(cfg, "numbers", transform=len) == 5
    safitty.set(cfg, "numbers", 8, value=42)
    assert safitty.get(cfg, "numbers", transform=len) == 9
    assert safitty.get(cfg, "numbers", 8) == 42

    # A nested path "numbers2" -> "inner" can be written; set must return a truthy result.
    assert safitty.set(cfg, "numbers2", "inner", value=[])
    assert safitty.get(cfg, "numbers2", "inner", transform=len) == 0

    # An existing top-level value can be replaced outright.
    assert safitty.set(cfg, "numbers", value=[])
    assert safitty.get(cfg, "numbers", transform=len) == 0
def on_epoch_start(self, state):
    """Epoch-start hook: capture weight decay and optionally switch it off.

    The optimizer is fetched via ``state.get_key`` using ``self.optimizer_key``.
    When ``self.decouple_weight_decay`` is false, ``self._optimizer_wd`` becomes
    a list of 0.0 (one per param group) and the optimizer is not modified.
    When it is true, each group's ``weight_decay`` (default 0.0) is saved to
    ``self._optimizer_wd`` and then set to 0.0 in the optimizer.
    """
    optimizer = state.get_key(key="optimizer", inner_key=self.optimizer_key)

    if not self.decouple_weight_decay:
        self._optimizer_wd = [0.0] * len(optimizer.param_groups)
        return

    self._optimizer_wd = [
        param_group.get("weight_decay", 0.0)
        for param_group in optimizer.param_groups
    ]
    for group_idx, _param_group in enumerate(optimizer.param_groups):
        safitty.set(optimizer.param_groups, group_idx, "weight_decay", value=0.0)
def set_optimizer_momentum(optimizer: Optimizer, value: float, index: int = 0):
    """Set the momentum of one param group of ``optimizer`` to ``value``.

    If the group has ``betas``, the first beta is replaced by ``value`` and
    the second is kept. Otherwise, if it has ``momentum``, that entry is
    overwritten. When ``safitty.get`` finds neither entry (returns None for
    both), nothing is changed.

    Args:
        optimizer: PyTorch optimizer
        value: new momentum value
        index: position of the param group in ``optimizer.param_groups``;
            defaults to 0
    """
    groups = optimizer.param_groups
    betas = safitty.get(groups, index, "betas")
    momentum = safitty.get(groups, index, "momentum")

    if betas is not None:
        # Unpacking (rather than indexing) insists on exactly two betas.
        _first_beta, second_beta = betas
        safitty.set(groups, index, "betas", value=(value, second_beta))
        return

    if momentum is not None:
        safitty.set(groups, index, "momentum", value=value)
def test_safe_set_2(transforms):
    """Exercise ``safitty.set`` on a list of dicts, including the ``on_none`` strategy."""
    items = copy.deepcopy(transforms)

    # Plain overwrite of a nested key inside the third element.
    safitty.set(items, 2, "name", value="BatchNorm2d")
    assert safitty.get(items, 2, "name") == "BatchNorm2d"

    # on_none must leave the already-set "params" of element 1 alone...
    safitty.set(items, 1, "params", value="add", strategy="on_none")
    second_params = safitty.get(items, 1, "params")
    assert second_params is not None
    assert second_params != "add"

    # ...but fills in element 0, whose "params" the fixture is expected to leave unset.
    safitty.set(items, 0, "params", value="subtract", strategy="on_none")
    first_params = safitty.get(items, 0, "params")
    assert first_params is not None
    assert first_params == "subtract"
def test_safe_set_strategies(config):
    """Exercise the ``existing_key``/``missing_key`` strategies and deep path creation."""
    cfg = copy.deepcopy(config)

    # existing_key: an unknown key is not created...
    safitty.set(cfg, "words", "quadre", value="four", strategy="existing_key")
    assert safitty.get(cfg, "words", "quadre") is None
    # ...while a known key is overwritten.
    safitty.set(cfg, "words", "one", value="four", strategy="existing_key")
    word_one = safitty.get(cfg, "words", "one")
    assert word_one is not None
    assert word_one == "four"

    # missing_key: a known key is left unchanged...
    safitty.set(cfg, "words", "one", value="five", strategy="missing_key")
    assert safitty.get(cfg, "words", "one") != "five"
    # ...while an unknown key is created.
    safitty.set(cfg, "words", "five", value="five", strategy="missing_key")
    assert safitty.get(cfg, "words", "five") == "five"

    # With no keys, set just returns the value; the caller's reference is not rebound.
    assert safitty.set(cfg, value="hi") == "hi"
    assert cfg != "hi"

    # A deep mixed index/key path is created on demand.
    deep_path = ("numbers", 40, "hi", "")
    safitty.set(cfg, *deep_path, value="привет")
    assert safitty.get(cfg, *deep_path) is not None
    assert safitty.get(cfg, *deep_path) == "привет"