예제 #1
0
def build_optimizer(optimizer_config, parameters):
    """Create the optimizer described by ``optimizer_config``.

    The base optimizer is chosen by ``optimizer_config.type``:

    * ``'sgd'``     -- ``torch.optim.SGD`` with ``nesterov=True`` and momentum
      taken from ``optimizer_config.sgd.momentum``;
    * ``'rmsprop'`` -- ``torch.optim.RMSprop`` with momentum taken from
      ``optimizer_config.rmsprop.momentum``;
    * ``'radam'``   -- ``RAdam``.

    Every variant receives ``optimizer_config.lr`` as its learning rate and
    ``optimizer_config.weight_decay``. The result is then wrapped in
    ``optim.LA`` when ``optimizer_config.lookahead`` is not ``None``. Finally
    it is wrapped in ``optim.EWA`` (when ``optimizer_config.ewa`` is not
    ``None``) or in ``optim.DummySwitchable``, so the returned object is
    always one of those two wrappers.

    Args:
        optimizer_config: config object exposing the attributes named above.
        parameters: passed unchanged as the first argument of the base
            optimizer.

    Returns:
        The wrapped optimizer.

    Raises:
        AssertionError: if ``optimizer_config.type`` is not one of the
            supported names.
    """
    # Builders are lambdas so that only the selected branch reads its
    # type-specific sub-config (e.g. ``optimizer_config.sgd``).
    builders = (
        ('sgd', lambda: torch.optim.SGD(
            parameters,
            optimizer_config.lr,
            momentum=optimizer_config.sgd.momentum,
            weight_decay=optimizer_config.weight_decay,
            nesterov=True,
        )),
        ('rmsprop', lambda: torch.optim.RMSprop(
            parameters,
            optimizer_config.lr,
            momentum=optimizer_config.rmsprop.momentum,
            weight_decay=optimizer_config.weight_decay,
        )),
        ('radam', lambda: RAdam(
            parameters,
            optimizer_config.lr,
            weight_decay=optimizer_config.weight_decay,
        )),
    )

    kind = optimizer_config.type
    for name, build in builders:
        if kind == name:
            optimizer = build()
            break
    else:
        raise AssertionError('invalid OPT {}'.format(kind))

    lookahead_cfg = optimizer_config.lookahead
    if lookahead_cfg is not None:
        optimizer = optim.LA(optimizer,
                             lookahead_cfg.lr,
                             num_steps=lookahead_cfg.steps)

    ewa_cfg = optimizer_config.ewa
    if ewa_cfg is None:
        return optim.DummySwitchable(optimizer)
    return optim.EWA(optimizer, ewa_cfg.momentum, num_steps=ewa_cfg.steps)
예제 #2
0
from pprint import pprint as print

import torch
import torch.nn as nn

import optim

# State-dict round-trip check for the Adam -> optim.LA -> optim.EWA stack:
# train a tiny model for a few steps, dump the optimizer state to s1.txt,
# then load that state into a freshly built stack.
# NOTE: ``print`` here is ``pprint.pprint`` (aliased in the imports), which
# is why it accepts ``stream=``.


def _make_optimizer(module):
    """Build the Adam -> LA -> EWA optimizer stack over ``module``'s parameters."""
    optimizer = torch.optim.Adam(module.parameters())
    # NOTE(review): positional args presumably mirror build_optimizer's calls,
    # LA(opt, lookahead_lr, num_steps) and EWA(opt, momentum, num_steps).
    optimizer = optim.LA(optimizer, 0.5, 5)
    return optim.EWA(optimizer, 0.9, 10)


m = nn.Linear(1, 1)
opt = _make_optimizer(m)

print(opt.state_dict())

# NOTE(review): presumably puts the EWA wrapper in training mode; confirm
# against optim.EWA.
opt.train()
m.train()
for _ in range(10):
    # backward() accumulates into .grad, so clear it every step; without
    # this, step k was fed the sum of the first k gradients.
    m.zero_grad()
    m(torch.ones(10, 1)).mean().backward()
    opt.step()

with open('s1.txt', 'w', encoding='utf-8') as f:
    print(opt.state_dict(), stream=f)

state_dict = opt.state_dict()
opt = _make_optimizer(m)
opt.load_state_dict(state_dict)