def save_chkpt(_):
    """Persist model, optimizer and trainer state to ``dirname / "test.pt"``.

    The single positional argument (whatever the caller passes, presumably the
    engine) is ignored. The checkpoint is a list
    ``[model.state_dict(), opt.state_dict(), trainer.state_dict()]``, and the
    path it was written to is appended to ``chkpt``. When ``debug`` is truthy,
    the current iteration and a representation of the saved RNG states are
    printed.
    """
    if debug:
        print(trainer.state.iteration, "save_chkpt")
    checkpoint_path = dirname / "test.pt"
    from ignite.engine.deterministic import _repr_rng_state

    # Trainer state is captured first; it carries the RNG states under "rng_states".
    engine_state = trainer.state_dict()
    if debug:
        print("->", _repr_rng_state(engine_state["rng_states"]))
    payload = [model.state_dict(), opt.state_dict(), engine_state]
    torch.save(payload, checkpoint_path)
    chkpt.append(checkpoint_path)
def proc_fn(e, b):
    """Run one training step of ``model`` on batch ``b`` (the argument ``e`` is unused).

    The batch is moved to ``device``, the loss is the sum of the model output,
    and ``opt`` takes one step. When ``debug`` is truthy, the iteration, epoch,
    batch shape, output norm and the RNG-state snapshot taken *before* the
    forward pass are printed.
    """
    from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

    # Snapshot taken before the forward pass, which may consume RNG
    # (e.g. if the model contains dropout layers).
    rng_snapshot = _repr_rng_state(_get_rng_states())
    model.train()
    opt.zero_grad()
    output = model(b.to(device))
    loss = output.sum()
    loss.backward()
    opt.step()
    if not debug:
        return
    print(
        trainer.state.iteration,
        trainer.state.epoch,
        "proc_fn - b.shape",
        b.shape,
        torch.norm(output).item(),
        rng_snapshot,
    )
def _train(save_iter=None, save_epoch=None, sd=None):
    """Train a small conv net with a ``DeterministicEngine`` and record diagnostics.

    A checkpoint (model, optimizer and trainer state) is written once, either
    when iteration ``save_iter`` completes or when epoch ``save_epoch``
    completes (``save_iter`` takes precedence if both are given). For up to five
    iterations right after ``save_iter``, the batch mean/std and the weight and
    gradient norms are recorded (see ``log_event_filter``).

    Args:
        save_iter: iteration at which to save the checkpoint.
        save_epoch: epoch at whose completion to save the checkpoint; only used
            when ``save_iter`` is None.
        sd: optional path to a checkpoint previously written by this function.
            If given, model, optimizer and trainer state are restored from it
            before training resumes.

    Returns:
        dict with keys:
            ``"sd"``: list of checkpoint paths written.
            ``"data"``: rows ``[iteration, batch mean, batch std]``.
            ``"grads"`` / ``"weights"``: rows
            ``[iteration, total norm, per-parameter norms...]``.

    Raises:
        ValueError: if neither ``save_iter`` nor ``save_epoch`` is given.
    """
    if save_iter is None and save_epoch is None:
        # Without this guard `ev` below would be unbound and
        # `log_event_filter` would divide by None.
        raise ValueError("Either save_iter or save_epoch must be provided")

    from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

    w_norms = []
    grad_norms = []
    data = []
    chkpt = []

    # Seed before building layers so the initial weights are reproducible.
    manual_seed(12)
    arch = [
        nn.Conv2d(3, 10, 3),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    ]
    if with_dropout:
        arch.insert(2, nn.Dropout2d())
        arch.insert(-2, nn.Dropout())
    model = nn.Sequential(*arch).to(device)
    opt = SGD(model.parameters(), lr=0.001)

    def proc_fn(e, b):
        # Snapshot RNG state before the forward pass, which consumes RNG when
        # dropout layers are present.
        s = _repr_rng_state(_get_rng_states())
        model.train()
        opt.zero_grad()
        y = model(b.to(device))
        y.sum().backward()
        opt.step()
        if debug:
            print(
                trainer.state.iteration,
                trainer.state.epoch,
                "proc_fn - b.shape",
                b.shape,
                torch.norm(y).item(),
                s,
            )

    trainer = DeterministicEngine(proc_fn)

    if save_iter is not None:
        ev = Events.ITERATION_COMPLETED(once=save_iter)
    else:
        ev = Events.EPOCH_COMPLETED(once=save_epoch)
        # Express the save point in iterations so the logging filter can use it.
        # NOTE(review): assumes each epoch has exactly data_size // batch_size
        # iterations (no partial last batch) — confirm against the data loader.
        save_iter = save_epoch * (data_size // batch_size)

    @trainer.on(ev)
    def save_chkpt(_):
        if debug:
            print(trainer.state.iteration, "save_chkpt")
        fp = dirname / "test.pt"
        tsd = trainer.state_dict()
        if debug:
            print("->", _repr_rng_state(tsd["rng_states"]))
        torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
        chkpt.append(fp)

    def log_event_filter(_, event):
        # True for save_iter + 1 .. save_iter + 5, but only while
        # event < 2 * save_iter (so fewer than five when save_iter <= 5).
        return event // save_iter == 1 and 1 <= event % save_iter <= 5

    @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
    def write_data_grads_weights(e):
        x = e.state.batch
        i = e.state.iteration
        data.append([i, x.mean().item(), x.std().item()])
        weights = [torch.norm(p).item() for p in model.parameters()]
        grads = [torch.norm(p.grad).item() for p in model.parameters()]
        # Start sums at 0.0 so the totals are floats even for an empty list.
        w_norms.append([i, sum(weights, 0.0)] + weights)
        grad_norms.append([i, sum(grads, 0.0)] + grads)

    if sd is not None:
        # NOTE(review): torch.load unpickles the file. That is fine for a
        # checkpoint written by this function, but never point it at untrusted
        # data. On torch versions where `weights_only` defaults to True, the RNG
        # states in sd[2] may fail to load if they contain non-tensor objects
        # (e.g. numpy state) — confirm.
        sd = torch.load(sd)
        model.load_state_dict(sd[0])
        opt.load_state_dict(sd[1])
        if debug:
            print("-->", _repr_rng_state(sd[2]["rng_states"]))
        trainer.load_state_dict(sd[2])

    manual_seed(32)
    trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
    return {"sd": chkpt, "data": data, "grads": grad_norms, "weights": w_norms}