Ejemplo n.º 1
0
 def on_epoch_start(self, state):
     """Record each param group's weight decay and, when decoupling, zero it.

     The optimizer is fetched with ``state.get_key(key="optimizer",
     inner_key=self.optimizer_key)``.

     If ``self.decouple_weight_decay`` is truthy, each group's current
     ``weight_decay`` (0.0 when the group has none) is stored in
     ``self._optimizer_wd``. The group's value is then set to 0.0. An
     instance without that attribute is treated as decoupling, which keeps
     the previous unconditional behaviour.

     Otherwise the optimizer is left untouched and ``self._optimizer_wd`` is
     filled with 0.0 for every group.
     """
     optimizer = state.get_key(key="optimizer",
                               inner_key=self.optimizer_key)
     param_groups = optimizer.param_groups
     # The companion ``on_epoch_end`` hook restores weight decay only when
     # ``decouple_weight_decay`` is set. Zeroing it unconditionally here
     # would therefore erase the optimizer's weight decay permanently when
     # decoupling is disabled.
     if getattr(self, "decouple_weight_decay", True):
         self._optimizer_wd = [
             group.get("weight_decay", 0.0) for group in param_groups
         ]
         for i in range(len(param_groups)):
             safitty.set(param_groups, i, "weight_decay", value=0.0)
     else:
         self._optimizer_wd = [0.0] * len(param_groups)
Ejemplo n.º 2
0
 def on_epoch_end(self, state):
     """Write the weight decay values saved at epoch start back.

     Does nothing unless ``self.decouple_weight_decay`` is truthy. In that
     case it fetches the optimizer via ``state.get_key`` and copies each entry
     of ``self._optimizer_wd`` into the ``weight_decay`` of the param group at
     the same index.
     """
     if not self.decouple_weight_decay:
         return
     target = state.get_key(key="optimizer", inner_key=self.optimizer_key)
     for idx, saved_wd in enumerate(self._optimizer_wd):
         safitty.set(target.param_groups, idx, "weight_decay", value=saved_wd)
Ejemplo n.º 3
0
def test_safe_set(config):
    """Check ``safitty.set`` on a list index and on new/existing dict keys.

    Works on a deep copy so the ``config`` fixture itself is not mutated.
    """
    cfg = copy.deepcopy(config)

    def length_at(*path):
        return safitty.get(cfg, *path, transform=len)

    # Setting index 8 of a 5-item list grows it to 9 items.
    assert length_at("numbers") == 5
    safitty.set(cfg, "numbers", 8, value=42)
    assert length_at("numbers") == 9
    assert safitty.get(cfg, "numbers", 8) == 42

    # Nested key under "numbers2" (presumably absent from the fixture);
    # ``set`` is expected to return something truthy.
    assert safitty.set(cfg, "numbers2", "inner", value=[])
    assert length_at("numbers2", "inner") == 0

    # An existing top-level key is replaced outright.
    assert safitty.set(cfg, "numbers", value=[])
    assert length_at("numbers") == 0
Ejemplo n.º 4
0
 def on_epoch_start(self, state):
     """Record one weight decay value per param group at epoch start.

     The optimizer comes from ``state.get_key`` with ``self.optimizer_key``.

     * With ``self.decouple_weight_decay`` truthy, each group's current
       ``weight_decay`` (0.0 if missing) is stored in ``self._optimizer_wd``.
       The group's value is then overwritten with 0.0.
     * Otherwise ``self._optimizer_wd`` holds 0.0 for every group and the
       optimizer is left untouched.
     """
     groups = state.get_key(
         key="optimizer", inner_key=self.optimizer_key
     ).param_groups
     if not self.decouple_weight_decay:
         self._optimizer_wd = [0.0] * len(groups)
         return
     self._optimizer_wd = [g.get("weight_decay", 0.0) for g in groups]
     for idx, _group in enumerate(groups):
         safitty.set(groups, idx, "weight_decay", value=0.0)
Ejemplo n.º 5
0
def set_optimizer_momentum(optimizer: Optimizer, value: float, index: int = 0):
    """
    Set momentum of ``index``'th param group of optimizer to ``value``.

    If the group has ``betas`` (Adam-style), the first beta is replaced by
    ``value`` and all remaining betas are kept. Otherwise, if the group has
    ``momentum`` (SGD-style), that value is replaced. When ``safitty.get``
    returns None for both keys, nothing is changed.

    Args:
        optimizer: PyTorch optimizer
        value (float): new value of momentum
        index (int, optional): integer index of optimizer's param groups,
            default is 0
    """
    param_groups = optimizer.param_groups
    betas = safitty.get(param_groups, index, "betas")
    momentum = safitty.get(param_groups, index, "momentum")
    if betas is not None:
        # Unpack loosely so optimizers carrying more than two betas are
        # supported; a strict two-way unpack would raise ValueError for them.
        _, *other_betas = betas
        safitty.set(param_groups,
                    index,
                    "betas",
                    value=(value, *other_betas))
    elif momentum is not None:
        safitty.set(param_groups, index, "momentum", value=value)
Ejemplo n.º 6
0
def test_safe_set_2(transforms):
    """Check ``safitty.set`` on a list of dicts, with and without ``on_none``.

    Operates on a deep copy so the ``transforms`` fixture stays intact.
    """
    items = copy.deepcopy(transforms)

    # Without a strategy argument, the value ends up stored.
    safitty.set(items, 2, "name", value="BatchNorm2d")
    assert safitty.get(items, 2, "name") == "BatchNorm2d"

    # ``on_none`` does not overwrite item 1's existing ``params`` ...
    safitty.set(items, 1, "params", value="add", strategy="on_none")
    kept = safitty.get(items, 1, "params")
    assert kept is not None
    assert kept != "add"

    # ... but does write item 0's (presumably None/absent in the fixture).
    safitty.set(items, 0, "params", value="subtract", strategy="on_none")
    filled = safitty.get(items, 0, "params")
    assert filled is not None
    assert filled == "subtract"
Ejemplo n.º 7
0
def test_safe_set_strategies(config):
    """Exercise the ``existing_key`` and ``missing_key`` strategies of ``safitty.set``.

    Also checks that a key-less ``set`` returns ``value`` without rebinding
    the caller's object, and that a deep missing path gets written.
    """
    cfg = copy.deepcopy(config)

    # existing_key: no write for an absent key, write for a present one.
    safitty.set(cfg, "words", "quadre", value="four", strategy="existing_key")
    assert safitty.get(cfg, "words", "quadre") is None

    safitty.set(cfg, "words", "one", value="four", strategy="existing_key")
    one = safitty.get(cfg, "words", "one")
    assert one is not None
    assert one == "four"

    # missing_key: no overwrite of a present key, write for an absent one.
    safitty.set(cfg, "words", "one", value="five", strategy="missing_key")
    assert safitty.get(cfg, "words", "one") != "five"

    safitty.set(cfg, "words", "five", value="five", strategy="missing_key")
    assert safitty.get(cfg, "words", "five") == "five"

    # With no keys the call returns ``value``; it cannot rebind ``cfg``.
    assert safitty.set(cfg, value="hi") == "hi"
    assert cfg != "hi"

    # A path through list index 40 and an empty-string key is written.
    safitty.set(cfg, "numbers", 40, "hi", "", value="привет")
    greeting = safitty.get(cfg, "numbers", 40, "hi", "")
    assert greeting is not None
    assert greeting == "привет"