示例#1
0
def test_perturb_model2():
    """Check perturbed forward passes, switch handling and one Adam step.

    The model is a bias-free 2x2 ``nn.Linear`` whose weight is set to all
    ones, so with an all-ones input the unperturbed output is ``[[2.0, 2.0]]``.
    The perturbed outputs and the post-step weight are hard-coded golden
    values. They only hold for ``t.manual_seed(0)`` and the exact order of
    forward calls and ``Switch.on()`` / ``Switch.off()`` calls below, so do
    not reorder statements.
    """
    with t.no_grad():
        seed = 0
        model = nn.Linear(2, 2, bias=False)
        optim = t.optim.Adam(model.parameters(), 1e-3)
        # In-place write to a parameter that requires grad. It runs under
        # no_grad so autograd accepts the in-place op on a leaf tensor.
        model.weight.fill_(1)
        weight_no_noise = t.ones([2, 2])
        # Expected model.weight after optim.step() at the end of the test.
        # NOTE(review): each row sums to output_with_noise minus 2e-3
        # (= 2 * lr), not to 2 - 2e-3. It looks like the weights that
        # produced output_with_noise, each moved by one Adam step. This
        # suggests model.weight still holds perturbed values when stepped.
        # Confirm against perturb_model.
        weight_stepped = t.tensor([[1.1530995369, 0.9696571231],
                                   [0.7811210752, 1.0558431149]])
        model_input = t.ones([1, 2])
        output_no_noise = t.full([1, 2], 2.0)
        output_with_noise = t.tensor([[2.1247568130, 1.8389642239]])
        output_with_noise2 = t.tensor([[1.8739618063, 1.9643428326]])

    # Everything from here on runs with grad enabled (needed for backward()).
    p_switch = Switch()
    r_switch = Switch()
    # Seed after model construction. The golden noisy values depend on the
    # RNG state from this point onward.
    t.manual_seed(seed)

    def gen_func(shape, device, std_dev):
        # Noise callback handed to perturb_model: samples a NormalNoiseGen of
        # the requested shape on `device` and scales it by `std_dev`.
        gen = NormalNoiseGen(shape)
        return gen(device) * std_dev

    # perturb_model returns a callable (`cancel`) that is invoked at the end
    # of the test; presumably it removes the perturbation hooks (confirm).
    # debug_backward=True: semantics defined by perturb_model (TODO confirm).
    cancel = perturb_model(model,
                           p_switch,
                           r_switch,
                           noise_generate_function=gen_func,
                           debug_backward=True)

    p_switch.on()
    r_switch.on()
    # p-on, r-on: perturbed forward pass -> first noisy output.
    assert _t_eq_eps(output_with_noise, model(model_input))
    p_switch.off()
    # p-off, r-on: output and weight must be exactly the noise-free values.
    # Original author's note: this call will adjust noise parameters.
    assert _t_eq_eps(output_no_noise, model(model_input))
    assert _t_eq_eps(model.weight, weight_no_noise)
    p_switch.on()
    r_switch.off()
    # p-on, r-off: the expected output equals the first noisy output, so the
    # test expects no new noise to be applied while r_switch is off.
    assert _t_eq_eps(output_with_noise, model(model_input))
    r_switch.on()
    # p-on, r-on: the expected output (output_with_noise2) differs from the
    # first noisy output, i.e. a different perturbation is now in effect.
    action = model(model_input)
    assert _t_eq_eps(output_with_noise2, action)

    # Backpropagate through the perturbed output and take one Adam step.
    loss = (action - t.ones_like(action)).sum()
    loss.backward()
    optim.step()
    assert _t_eq_eps(model.weight, weight_stepped)

    # Undo the perturbation set up by perturb_model. Note: this is skipped
    # if any assertion above fails.
    cancel()
示例#2
0
文件: tes_pm.py 项目: mrshenli/machin
from machin.utils.helper_classes import Switch
from machin.frame.noise.param_space_noise import perturb_model
from machin.utils.visualize import visualize_graph
import torch as t

# Demo script: perturb a Linear layer's parameters, visualize the graph of one
# perturbed forward pass, then take a single optimizer step.
dims = 5

# Fixed seed so the Linear initialisation (and, presumably, any noise drawn
# by perturb_model) is reproducible between runs.
t.manual_seed(0)
model = t.nn.Linear(dims, dims)
optim = t.optim.Adam(model.parameters(), 1e-3)
# p_switch toggles the perturbation (see below). The meaning of r_switch is
# defined by perturb_model; presumably it controls noise re-sampling (confirm).
p_switch, r_switch = Switch(), Switch()
# perturb_model returns a callable, invoked below between backward() and
# optim.step().
rst_func = perturb_model(model, p_switch, r_switch)
r_switch.on()

# turn off/on the perturbation switch to see the difference
p_switch.on()

# do some sampling (a 1-D input of length `dims`)
action = model(t.ones([dims]))

# Visualize will not show any leaf noise tensors
# because they are created in t.no_grad() context.
visualize_graph(action, exit_after_vis=False)

# do some training
loss = (action - t.ones([dims])).sum()
loss.backward()
# NOTE(review): rst_func() runs after gradients are computed but before the
# optimizer step. The sibling test calls the equivalent `cancel()` only after
# optim.step(); confirm this ordering is intended.
rst_func()
optim.step()
# Inspect the weight after one Adam step.
print(model.weight)