def get_torch_lr_scheduler(optimizer, lr_scheduler_params):
    """Build a torch learning-rate scheduler from a configuration dict.

    Every returned scheduler carries an extra ``step_type`` attribute set to
    either ``'epoch'`` or ``'iter'`` (presumably read by the training loop to
    decide how often to call ``scheduler.step()`` -- confirm against caller).

    Args:
        optimizer: The torch optimizer whose learning rate is scheduled.
        lr_scheduler_params: Mapping with a ``'type'`` key and, depending on
            the type, a ``'kwargs'`` key:

            * ``'constant'``: factor 1 at every step; ``step_type='epoch'``.
            * ``'multi_step'``: ``MultiStepLR(optimizer, **kwargs)``;
              ``'kwargs'`` is required; ``step_type='epoch'``.
            * ``'cyclic'``: ``CyclicLR`` with defaults ``base_lr=0.001``,
              ``max_lr=0.1``, ``step_size_up=50``, ``step_size_down=100``,
              ``mode='triangular'``; an optional ``'kwargs'`` mapping
              overrides or extends these defaults; ``step_type='iter'``.
            * ``'cosine_annealing'``: ``CosineAnnealingLR(optimizer,
              **kwargs)``; ``'kwargs'`` is required; ``step_type='iter'``.
            * ``'experiment'``: multiplicative factor ``0.95 ** epoch``;
              ``step_type='epoch'``.

    Returns:
        The configured scheduler, or ``None`` when ``'type'`` is not one of
        the values listed above.

    Raises:
        KeyError: If ``'type'`` is missing, or ``'kwargs'`` is missing for a
            type that requires it.
    """
    # Imported lazily so that torch is only needed when a scheduler is built.
    from torch.optim.lr_scheduler import (
        LambdaLR,
        MultiStepLR,
        CosineAnnealingLR,
        CyclicLR,
    )

    scheduler_type = lr_scheduler_params['type']

    if scheduler_type == 'constant':
        scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1)
        scheduler.step_type = 'epoch'
        return scheduler

    if scheduler_type == 'multi_step':
        scheduler = MultiStepLR(optimizer, **lr_scheduler_params['kwargs'])
        scheduler.step_type = 'epoch'
        return scheduler

    if scheduler_type == 'cyclic':
        # Historical hard-coded values are kept as defaults so existing
        # configs without 'kwargs' behave exactly as before; any supplied
        # kwargs take precedence.
        cyclic_kwargs = {
            'base_lr': 0.001,
            'max_lr': 0.1,
            'step_size_up': 50,
            'step_size_down': 100,
            'mode': 'triangular',
        }
        cyclic_kwargs.update(lr_scheduler_params.get('kwargs') or {})
        scheduler = CyclicLR(optimizer, **cyclic_kwargs)
        scheduler.step_type = 'iter'
        return scheduler

    if scheduler_type == 'cosine_annealing':
        scheduler = CosineAnnealingLR(optimizer, **lr_scheduler_params['kwargs'])
        scheduler.step_type = 'iter'
        return scheduler

    if scheduler_type == 'experiment':
        def get_lr_factor(epoch):
            # Exponential decay: 5% reduction per scheduler step.
            return 0.95**epoch

        scheduler = LambdaLR(optimizer, lr_lambda=get_lr_factor)
        scheduler.step_type = 'epoch'
        return scheduler

    # Unrecognised type: preserved legacy behaviour is to return None rather
    # than raise, since callers may rely on it.
    return None