예제 #1
0
파일: alias.py 프로젝트: ksachdeva/optax
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule):
    """Build a scaling transform from the negated learning rate.

    Args:
        learning_rate: Either a plain value or a callable. A callable is
            invoked with a single ``count`` argument.

    Returns:
        ``transform.scale(-learning_rate)`` when ``learning_rate`` is not
        callable; otherwise ``transform.scale_by_schedule`` applied to a
        function that returns ``-learning_rate(count)``.
    """
    # Constant case first: negate once and hand the value straight to scale().
    if not callable(learning_rate):
        return transform.scale(-learning_rate)

    def negated_schedule(count):
        # Negate the schedule's output each time it is evaluated.
        # NOTE(review): presumably the sign flip makes updates descend the
        # gradient; confirm against the callers.
        return -learning_rate(count)

    return transform.scale_by_schedule(negated_schedule)
예제 #2
0
파일: alias.py 프로젝트: n2cholas/optax
def _scale_by_learning_rate(learning_rate: ScalarOrSchedule, flip_sign=True):
  """Build a scaling transform from the learning rate, optionally negated.

  Args:
    learning_rate: Either a plain value or a callable. A callable is invoked
      with a single ``count`` argument.
    flip_sign: When truthy, the learning rate is multiplied by -1; otherwise
      it is multiplied by 1.

  Returns:
    ``transform.scale(sign * learning_rate)`` when ``learning_rate`` is not
    callable; otherwise ``transform.scale_by_schedule`` applied to a function
    that returns ``sign * learning_rate(count)``.
  """
  # Multiplier applied to whichever form of learning rate was supplied.
  sign = 1 if not flip_sign else -1

  # Constant case first: scale once by the signed value.
  if not callable(learning_rate):
    return transform.scale(sign * learning_rate)

  def signed_schedule(count):
    # Apply the multiplier to the schedule's output on every evaluation.
    return sign * learning_rate(count)

  return transform.scale_by_schedule(signed_schedule)