def parse(schedule_str):
    """Parse a fixed-step learning rate schedule string.

    Delegates to ``LearningRateSchedulerFixedStep.parse_schedule_str`` and
    converts its ``ValueError`` into ``argparse.ArgumentTypeError`` so that,
    when used as an argparse ``type=`` callable, argparse reports a clean
    usage error instead of a traceback.

    :param schedule_str: Schedule of the form
        ``rate1:num_updates1[,rate2:num_updates2,...]``, e.g. ``"0.5:16,0.25:8"``.
    :return: The parsed schedule as returned by ``parse_schedule_str``
        (e.g. ``[(0.5, 16), (0.25, 8)]`` for the example above).
    :raises argparse.ArgumentTypeError: If the string is malformed.
    """
    try:
        schedule = LearningRateSchedulerFixedStep.parse_schedule_str(schedule_str)
    except ValueError as err:
        # Chain the original error so the underlying parse failure stays
        # visible as __cause__ when debugging.
        raise argparse.ArgumentTypeError(
            "Learning rate schedule string should have form rate1:num_updates1[,rate2:num_updates2,...]") from err
    return schedule
def test_fixed_step_lr_scheduler():
    """Verify parsing and step-wise behaviour of LearningRateSchedulerFixedStep.

    The schedule "0.5:16,0.25:8" must parse into two (rate, num_updates)
    pairs. The scheduler must then report 0.5 for updates 1..16 and 0.25 from
    update 16 onward through update 24. A (non-improving) evaluation result is
    reported every ``checkpoint_interval`` updates.
    """
    # Parsing.
    parsed = LearningRateSchedulerFixedStep.parse_schedule_str("0.5:16,0.25:8")
    assert parsed == [(0.5, 16), (0.25, 8)]

    checkpoint_interval = 2
    lr_scheduler = LearningRateSchedulerFixedStep(parsed, checkpoint_interval)

    def step_through(start, count, expected_rate):
        """Advance ``count`` updates after ``start``, checking the rate each time.

        Returns the index of the last update performed.
        """
        for step in range(start + 1, start + count + 1):
            assert lr_scheduler(step) == expected_rate
            if step % checkpoint_interval == 0:
                lr_scheduler.new_evaluation_result(False)
        return start + count

    last_update = step_through(0, 16, 0.5)
    # Once the first segment is exhausted, the same update index already
    # reports the next rate.
    assert lr_scheduler(last_update) == 0.25
    step_through(last_update, 8, 0.25)