import torch

import optim  # project-local module providing the LA, EWA and DummySwitchable wrappers
# RAdam is assumed to be importable from the project's own optimizer implementations.


def build_optimizer(optimizer_config, parameters):
    # Base optimizer selected by the config.
    if optimizer_config.type == 'sgd':
        optimizer = torch.optim.SGD(
            parameters,
            optimizer_config.lr,
            momentum=optimizer_config.sgd.momentum,
            weight_decay=optimizer_config.weight_decay,
            nesterov=True)
    elif optimizer_config.type == 'rmsprop':
        optimizer = torch.optim.RMSprop(
            parameters,
            optimizer_config.lr,
            momentum=optimizer_config.rmsprop.momentum,
            weight_decay=optimizer_config.weight_decay)
    elif optimizer_config.type == 'radam':
        optimizer = RAdam(
            parameters,
            optimizer_config.lr,
            weight_decay=optimizer_config.weight_decay)
    else:
        raise AssertionError('invalid OPT {}'.format(optimizer_config.type))

    # Optional Lookahead wrapper around the base optimizer.
    if optimizer_config.lookahead is not None:
        optimizer = optim.LA(
            optimizer,
            optimizer_config.lookahead.lr,
            num_steps=optimizer_config.lookahead.steps)

    # Optional exponential weight averaging; otherwise wrap in a no-op
    # switchable so callers always get the same interface back.
    if optimizer_config.ewa is not None:
        optimizer = optim.EWA(
            optimizer,
            optimizer_config.ewa.momentum,
            num_steps=optimizer_config.ewa.steps)
    else:
        optimizer = optim.DummySwitchable(optimizer)

    return optimizer
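# A hypothetical usage sketch (not from the original source): build_optimizer only
# needs a config object exposing the attributes read above, so SimpleNamespace is
# used here as a stand-in for whatever config system the project actually uses.
from types import SimpleNamespace

import torch.nn as nn

sgd_config = SimpleNamespace(
    type='sgd',
    lr=0.1,
    weight_decay=1e-4,
    sgd=SimpleNamespace(momentum=0.9),
    lookahead=SimpleNamespace(lr=0.5, steps=5),
    ewa=SimpleNamespace(momentum=0.9, steps=10),
)
model = nn.Linear(1, 1)
optimizer = build_optimizer(sgd_config, model.parameters())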
# Smoke test for the LA/EWA wrappers: wrap Adam, train a few steps,
# dump the state dict to a file, then rebuild the stack and reload it.
from pprint import pprint as print  # deliberately shadows the built-in print

import torch
import torch.nn as nn

import optim

m = nn.Linear(1, 1)
opt = torch.optim.Adam(m.parameters())
opt = optim.LA(opt, 0.5, 5)
opt = optim.EWA(opt, 0.9, 10)
print(opt.state_dict())

opt.train()
m.train()
for _ in range(10):
    m(torch.ones(10, 1)).mean().backward()
    opt.step()

# print('>' * 100)
with open('s1.txt', 'w') as f:
    print(opt.state_dict(), stream=f)

# Rebuild the same optimizer stack and restore its state.
state_dict = opt.state_dict()
opt = torch.optim.Adam(m.parameters())
opt = optim.LA(opt, 0.5, 5)
opt = optim.EWA(opt, 0.9, 10)
opt.load_state_dict(state_dict)
# print('>' * 100)