def test_perturb_model2():
    """Check ``perturb_model`` with a custom noise generator across switch states.

    A 2x2 bias-free linear layer with all-ones weights is passed to
    ``perturb_model`` together with a user-supplied
    ``noise_generate_function``.  The test toggles the two switches, compares
    every forward pass with hard-coded reference outputs, then runs one
    backward pass plus an Adam step and compares the resulting weights with a
    hard-coded reference.
    """
    seed = 0

    # Hard-coded reference tensors used by the assertions below.
    expected_weight_clean = t.ones([2, 2])
    expected_weight_stepped = t.tensor(
        [[1.1530995369, 0.9696571231], [0.7811210752, 1.0558431149]]
    )
    sample = t.ones([1, 2])
    expected_out_clean = t.full([1, 2], 2.0)
    expected_out_noisy = t.tensor([[2.1247568130, 1.8389642239]])
    expected_out_noisy_2 = t.tensor([[1.8739618063, 1.9643428326]])

    with t.no_grad():
        net = nn.Linear(2, 2, bias=False)
        optimizer = t.optim.Adam(net.parameters(), lr=1e-3)
        # In-place overwrite of the layer's weight, done with autograd disabled.
        net.weight.fill_(1)

    p_switch = Switch()
    r_switch = Switch()
    # Seed after the layer has been built, right before the noise is set up,
    # so the generated noise is reproducible.
    t.manual_seed(seed)

    def scaled_normal_noise(shape, device, std_dev):
        # Sample from NormalNoiseGen for ``device`` and scale by ``std_dev``.
        return NormalNoiseGen(shape)(device) * std_dev

    cancel = perturb_model(
        net,
        p_switch,
        r_switch,
        noise_generate_function=scaled_normal_noise,
        debug_backward=True,
    )

    p_switch.on()
    r_switch.on()
    # p-on, r-on: the output is expected to carry noise.
    assert _t_eq_eps(expected_out_noisy, net(sample))

    p_switch.off()
    # p-off, r-on: output and weights are expected to match the clean values.
    # (original author's note: "will adjust noise parameters")
    assert _t_eq_eps(expected_out_clean, net(sample))
    assert _t_eq_eps(net.weight, expected_weight_clean)

    p_switch.on()
    r_switch.off()
    # p-on, r-off: the same noisy output as the first pass is expected.
    assert _t_eq_eps(expected_out_noisy, net(sample))

    r_switch.on()
    # p-on, r-on again: a different noisy output is expected this time.
    action = net(sample)
    assert _t_eq_eps(expected_out_noisy_2, action)

    # One training step, then compare the weights with the reference.
    loss = t.sum(action - t.ones_like(action))
    loss.backward()
    optimizer.step()
    assert _t_eq_eps(net.weight, expected_weight_stepped)

    # Call the handle returned by perturb_model once the checks are done.
    cancel()
# Example: perturbing a linear model with parameter-space noise, sampling an
# action from it, visualizing the graph, and running one optimizer step.
from machin.utils.helper_classes import Switch
from machin.frame.noise.param_space_noise import perturb_model
from machin.utils.visualize import visualize_graph
import torch as t

dims = 5

# Seed first so the model initialization is reproducible.
t.manual_seed(0)
model = t.nn.Linear(dims, dims)
optim = t.optim.Adam(model.parameters(), lr=1e-3)

# Two switches handed to perturb_model; rst_func is the callable it returns.
p_switch = Switch()
r_switch = Switch()
rst_func = perturb_model(model, p_switch, r_switch)

r_switch.on()
# turn off/on the perturbation switch to see the difference
p_switch.on()

# do some sampling
action = model(t.ones([dims]))

# Visualize will not show any leaf noise tensors
# because they are created in t.no_grad() context.
visualize_graph(action, exit_after_vis=False)

# do some training
loss = t.sum(action - t.ones([dims]))
loss.backward()
# rst_func is invoked after backward() and before the optimizer step.
rst_func()
optim.step()
print(model.weight)