def _test(epoch_length=None):
    """Resume a ``DeterministicEngine`` at every epoch boundary over tensor data.

    For each ``resume_epoch`` in ``1 .. max_epochs - 1`` a fresh engine is
    loaded with a state dict pointing at that epoch and run over ``data``.
    Every batch is passed to a ``BatchChecker`` whose counter starts at
    ``resume_epoch * epoch_length``.

    Args:
        epoch_length: number of iterations per epoch; defaults to the number
            of items in ``data`` (``num_iters``) when ``None``.
    """
    max_epochs = 10
    num_iters = 21

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs

    for resume_epoch in range(1, max_epochs):
        first_expected_index = resume_epoch * epoch_length
        batch_checker = BatchChecker(data, init_counter=first_expected_index)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        resumed = DeterministicEngine(verify_batch)
        resumed.load_state_dict(
            {
                "epoch": resume_epoch,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            }
        )
        resumed.run(data)

        # The resumed run must finish exactly where a full run would.
        assert resumed.state.epoch == max_epochs
        assert resumed.state.iteration == expected_iterations
def _test(epoch_length=None):
    """Resume a ``DeterministicEngine`` from mid-run iterations over tensor data.

    Resume points are iterations ``2, 6, 10, ...`` below
    ``min(num_iters, epoch_length) * max_epochs``. For each one a fresh
    engine is loaded with a state dict pointing at that iteration and run
    over ``data``; every batch goes through a ``BatchChecker`` whose counter
    starts at the resume iteration.

    Args:
        epoch_length: number of iterations per epoch; defaults to the number
            of items in ``data`` (``num_iters``) when ``None``.
    """
    max_epochs = 5
    num_iters = 21

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs
    resume_stop = min(num_iters * max_epochs, expected_iterations)

    for resume_iteration in range(2, resume_stop, 4):
        batch_checker = BatchChecker(data, init_counter=resume_iteration)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)

        @engine.on(Events.EPOCH_COMPLETED)
        def check_iteration(_):
            # Engine iteration count and checker counter must stay in lockstep.
            assert engine.state.iteration == batch_checker.counter

        engine.load_state_dict(
            {
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            }
        )
        engine.run(data)

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == expected_iterations
def _test(epoch_length=None):
    """Resume from mid-run iterations when data comes from an infinite random iterator.

    For each resume point (iterations ``1, 8, 15, ...``) a reference run
    records every batch after seeding torch with 24. A second engine is then
    loaded with a state dict pointing at the resume iteration, torch is
    re-seeded with 24, and each produced batch is checked against the
    recorded ones via ``BatchChecker``.

    Args:
        epoch_length: number of iterations per epoch; defaults to
            ``num_iters`` when ``None``.
    """
    max_epochs = 3
    batch_size = 4
    num_iters = 17

    def infinite_data_iterator():
        # Endless stream of random batches drawn from the global torch RNG.
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    resume_stop = min(num_iters * max_epochs, epoch_length * max_epochs)

    for resume_iteration in range(1, resume_stop, 7):
        # --- reference run: record everything the engine sees ---
        seen_batchs = []

        def record_batch(_, batch):
            seen_batchs.append(batch)

        reference_engine = DeterministicEngine(record_batch)
        torch.manual_seed(24)
        reference_engine.run(
            infinite_data_iterator(),
            max_epochs=max_epochs,
            epoch_length=epoch_length,
        )

        # --- resumed run: compare against the recording ---
        batch_checker = BatchChecker(seen_batchs, init_counter=resume_iteration)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)
        engine.load_state_dict(
            {
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            }
        )
        torch.manual_seed(24)
        engine.run(infinite_data_iterator())

        assert engine.state.epoch == max_epochs
        assert (
            engine.state.iteration == epoch_length * max_epochs
        ), f"{resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
def _test(epoch_length=None):
    """Resume at epoch boundaries when data comes from an infinite random iterator.

    For each ``resume_epoch`` in ``1 .. max_epochs - 1`` a reference run
    records every batch after seeding torch with 121. A second engine is
    then loaded with a state dict pointing at that epoch, torch is re-seeded
    with 121, and each produced batch is checked against the recorded ones
    via ``BatchChecker`` (counter starting at ``resume_epoch * epoch_length``).

    Args:
        epoch_length: number of iterations per epoch; defaults to
            ``num_iters`` when ``None``.
    """
    max_epochs = 5
    batch_size = 4
    num_iters = 21

    def infinite_data_iterator():
        # Endless stream of random batches drawn from the global torch RNG.
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs

    for resume_epoch in range(1, max_epochs):
        # --- reference run: record everything the engine sees ---
        seen_batchs = []

        def record_batch(_, batch):
            # NOTE: if there is a random op when using the data batch
            # (e.g. torch.rand(1)), we can not resume correctly.
            seen_batchs.append(batch)

        reference_engine = DeterministicEngine(record_batch)
        torch.manual_seed(121)
        reference_engine.run(
            infinite_data_iterator(),
            max_epochs=max_epochs,
            epoch_length=epoch_length,
        )

        # --- resumed run: compare against the recording ---
        first_expected_index = resume_epoch * epoch_length
        batch_checker = BatchChecker(seen_batchs, init_counter=first_expected_index)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        resumed = DeterministicEngine(verify_batch)
        resumed.load_state_dict(
            {
                "epoch": resume_epoch,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            }
        )
        torch.manual_seed(121)
        resumed.run(infinite_data_iterator())

        assert resumed.state.epoch == max_epochs
        assert resumed.state.iteration == expected_iterations
def test_concepts_snippet_resume():
    """Check the "resume" concepts snippet: a resumed run replays the same batches.

    A trainer is run for 2 epochs of 5 iterations and every batch is
    recorded. The trainer is then loaded with a state dict pointing at
    epoch 1 and run again under the same seed; the batches it sees must
    equal the batches of the original run from index 5 onwards.
    """
    # Commented imports required in the snippet
    # import torch
    # from torch.utils.data import DataLoader
    # from ignite.engine import DeterministicEngine
    # from ignite.utils import manual_seed

    seen_batches = []
    manual_seed(seed=15)

    def random_train_data_loader(size):
        data = torch.arange(0, size)
        return DataLoader(data, batch_size=4, shuffle=True)

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())
        seen_batches.append(batch)

    trainer = DeterministicEngine(print_train_data)

    print("Original Run")
    manual_seed(56)
    trainer.run(random_train_data_loader(40), max_epochs=2, epoch_length=5)

    original_batches = list(seen_batches)
    # Rebinding is visible to print_train_data because it reads the
    # enclosing variable at call time.
    seen_batches = []

    print("Resumed Run")
    trainer.load_state_dict({
        "epoch": 1,
        "epoch_length": 5,
        "max_epochs": 2,
        "rng_states": None
    })
    manual_seed(56)
    trainer.run(random_train_data_loader(40))

    resumed_batches = list(seen_batches)
    expected_batches = original_batches[5:]

    # zip() silently stops at the shorter sequence, so without these checks
    # the test would pass even if the resumed run produced no batches.
    assert resumed_batches, "resumed run produced no batches"
    assert len(resumed_batches) == len(expected_batches), (
        f"{len(resumed_batches)} vs {len(expected_batches)}"
    )
    for b1, b2 in zip(expected_batches, resumed_batches):
        assert (b1 == b2).all()
def _train(save_iter=None, save_epoch=None, sd=None):
    """Train a small conv net for 5 epochs, checkpoint once, and log norms.

    Relies on names from the enclosing scope that are not defined here
    (``with_dropout``, ``device``, ``debug``, ``dirname``, ``data_size``,
    ``batch_size``, ``random_train_data_loader``).

    Args:
        save_iter: iteration after which a checkpoint is saved (once).
        save_epoch: epoch after which a checkpoint is saved (once); only
            used when ``save_iter`` is ``None``. In that case ``save_iter``
            is derived as ``save_epoch * (data_size // batch_size)`` for the
            logging filter.
        sd: optional path to a checkpoint written by a previous call; when
            given, model, optimizer and trainer state are restored from it
            before running.

    Returns:
        dict with keys ``"sd"`` (list of saved checkpoint paths), ``"data"``
        (``[iteration, batch mean, batch std]`` rows), ``"grads"`` and
        ``"weights"`` (``[iteration, sum of norms, *per-parameter norms]``
        rows).

    Raises:
        ValueError: if neither ``save_iter`` nor ``save_epoch`` is given.
    """
    from ignite.engine.deterministic import _get_rng_states, _repr_rng_state

    w_norms = []
    grad_norms = []
    data = []
    chkpt = []

    # Seed before building the model so parameter init is reproducible.
    manual_seed(12)
    arch = [
        nn.Conv2d(3, 10, 3),
        nn.ReLU(),
        nn.Conv2d(10, 10, 3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(10, 5),
        nn.ReLU(),
        nn.Linear(5, 2),
    ]
    if with_dropout:
        arch.insert(2, nn.Dropout2d())
        arch.insert(-2, nn.Dropout())

    model = nn.Sequential(*arch).to(device)
    opt = SGD(model.parameters(), lr=0.001)

    def proc_fn(e, b):
        # RNG state is captured before the forward pass, for debug output only.
        s = _repr_rng_state(_get_rng_states())
        model.train()
        opt.zero_grad()
        y = model(b.to(device))
        y.sum().backward()
        opt.step()
        if debug:
            print(trainer.state.iteration, trainer.state.epoch,
                  "proc_fn - b.shape", b.shape, torch.norm(y).item(), s)

    trainer = DeterministicEngine(proc_fn)

    if save_iter is not None:
        ev = Events.ITERATION_COMPLETED(once=save_iter)
    elif save_epoch is not None:
        ev = Events.EPOCH_COMPLETED(once=save_epoch)
        save_iter = save_epoch * (data_size // batch_size)
    else:
        # Previously this fell through and failed later with an unbound `ev`.
        raise ValueError("Either save_iter or save_epoch must be provided")

    @trainer.on(ev)
    def save_chkpt(_):
        if debug:
            print(trainer.state.iteration, "save_chkpt")
        fp = dirname / "test.pt"
        tsd = trainer.state_dict()
        if debug:
            print("->", _repr_rng_state(tsd["rng_states"]))
        torch.save([model.state_dict(), opt.state_dict(), tsd], fp)
        chkpt.append(fp)

    def log_event_filter(_, event):
        # True for iterations save_iter + 1 .. save_iter + 5 (while still
        # below 2 * save_iter), i.e. right after the checkpoint iteration.
        return (event // save_iter == 1) and 1 <= (event % save_iter) <= 5

    @trainer.on(Events.ITERATION_COMPLETED(event_filter=log_event_filter))
    def write_data_grads_weights(e):
        x = e.state.batch
        i = e.state.iteration
        data.append([i, x.mean().item(), x.std().item()])

        params = list(model.parameters())
        weight_norms = [torch.norm(p).item() for p in params]
        gradient_norms = [torch.norm(p.grad).item() for p in params]
        w_norms.append([i, sum(weight_norms, 0.0)] + weight_norms)
        grad_norms.append([i, sum(gradient_norms, 0.0)] + gradient_norms)

    if sd is not None:
        # Checkpoint layout: [model state, optimizer state, trainer state].
        sd = torch.load(sd)
        model.load_state_dict(sd[0])
        opt.load_state_dict(sd[1])
        if debug:
            print("-->", _repr_rng_state(sd[2]["rng_states"]))
        trainer.load_state_dict(sd[2])

    manual_seed(32)
    trainer.run(random_train_data_loader(size=data_size), max_epochs=5)
    return {
        "sd": chkpt,
        "data": data,
        "grads": grad_norms,
        "weights": w_norms
    }
def _test(epoch_length=None):
    """Resume from mid-run iterations when data comes from a ``DataLoader``.

    For each resume iteration (``2, 15, 28, ...``) and each ``num_workers``
    in ``[0, 2]``: a reference run records every batch after seeding torch
    with 12; a fresh sampler/loader pair is then built, a second engine is
    loaded with a state dict pointing at the resume iteration, torch is
    re-seeded with 12, and every batch is checked against the recording via
    ``BatchChecker``.

    Uses ``device``, ``sampler_type`` and ``_setup_sampler`` from the
    enclosing scope.

    Args:
        epoch_length: number of iterations per epoch; defaults to
            ``num_iters`` when ``None``.
    """
    max_epochs = 3
    total_batch_size = 4
    num_iters = 17

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

    if epoch_length is None:
        epoch_length = num_iters

    def build_sampler_and_loader(worker_count):
        # A fresh sampler/loader pair is created for every run.
        new_sampler, per_loader_batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
        loader = DataLoader(
            data,
            batch_size=per_loader_batch_size,
            num_workers=worker_count,
            pin_memory="cuda" in torch.device(device).type,
            sampler=new_sampler,
            drop_last=True,
            shuffle=new_sampler is None,
        )
        return new_sampler, loader

    def attach_sampler_epoch_sync(target_engine, target_sampler):
        # Only the "distributed" sampler type gets set_epoch() called.
        if sampler_type != "distributed":
            return

        @target_engine.on(Events.EPOCH_STARTED)
        def _set_sampler_epoch(engine):
            target_sampler.set_epoch(engine.state.epoch)

    resume_stop = min(num_iters * max_epochs, epoch_length * max_epochs)

    for resume_iteration in range(2, resume_stop, 13):
        for num_workers in [0, 2]:
            # --- reference run: record everything the engine sees ---
            sampler, orig_dataloader = build_sampler_and_loader(num_workers)
            seen_batchs = []

            def record_batch(_, batch):
                moved = batch.to(device)  # result unused; only the transfer is exercised
                seen_batchs.append(batch)

            reference_engine = DeterministicEngine(record_batch)
            attach_sampler_epoch_sync(reference_engine, sampler)

            torch.manual_seed(12)
            reference_engine.run(orig_dataloader, max_epochs=max_epochs, epoch_length=epoch_length)

            # --- resumed run: compare against the recording ---
            batch_checker = BatchChecker(seen_batchs, init_counter=resume_iteration)
            sampler, resume_dataloader = build_sampler_and_loader(num_workers)

            def verify_batch(_, batch):
                moved = batch.to(device)  # result unused; only the transfer is exercised
                cfg_msg = f"{num_workers} {resume_iteration}"
                assert batch_checker.check(
                    batch
                ), f"{cfg_msg} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(verify_batch)
            attach_sampler_epoch_sync(engine, sampler)

            engine.load_state_dict(
                {
                    "iteration": resume_iteration,
                    "max_epochs": max_epochs,
                    "epoch_length": epoch_length,
                    "rng_states": None,
                }
            )
            torch.manual_seed(12)
            engine.run(resume_dataloader)

            assert engine.state.epoch == max_epochs
            assert (
                engine.state.iteration == epoch_length * max_epochs
            ), f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
def _test(epoch_length=None):
    """Resume at epoch boundaries when data comes from a ``DataLoader``.

    For each ``resume_epoch`` in ``1, 3`` and each ``num_workers`` in
    ``[0, 4]``: a reference run records every batch after seeding torch
    with 87; a fresh sampler/loader pair is then built, a second engine is
    loaded with a state dict pointing at ``resume_epoch``, torch is
    re-seeded with 87, and every batch is checked against the recording via
    ``BatchChecker`` (counter starting at ``resume_epoch * epoch_length``).

    Uses ``device``, ``sampler_type`` and ``_setup_sampler`` from the
    enclosing scope.

    Args:
        epoch_length: number of iterations per epoch; defaults to
            ``num_iters`` when ``None``.
    """
    max_epochs = 5
    total_batch_size = 4
    num_iters = 21

    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters * total_batch_size, ))

    if epoch_length is None:
        epoch_length = num_iters

    # Normalise through torch.device so this works whether `device` is a
    # string or a torch.device (`"cuda" in device` raises TypeError for the
    # latter).
    pin_memory = "cuda" in torch.device(device).type

    for resume_epoch in range(1, max_epochs, 2):
        for num_workers in [0, 4]:
            # --- reference run: record everything the engine sees ---
            sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
            orig_dataloader = DataLoader(
                data,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=pin_memory,
                sampler=sampler,
                drop_last=True,
                shuffle=sampler is None,
            )

            seen_batchs = []

            def update_fn(_, batch):
                # The moved tensor is not used; only the transfer is exercised.
                batch.to(device)
                seen_batchs.append(batch)

            engine = DeterministicEngine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch - 1)

            torch.manual_seed(87)
            engine.run(
                orig_dataloader,
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            # --- resumed run: compare against the recording ---
            batch_checker = BatchChecker(seen_batchs, init_counter=resume_epoch * epoch_length)

            sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
            resume_dataloader = DataLoader(
                data,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory=pin_memory,
                sampler=sampler,
                drop_last=True,
                shuffle=sampler is None,
            )

            def update_fn(_, batch):
                # The moved tensor is not used; only the transfer is exercised.
                batch.to(device)
                assert batch_checker.check(
                    batch
                ), f"{num_workers} {resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch - 1)

            resume_state_dict = dict(
                epoch=resume_epoch,
                max_epochs=max_epochs,
                epoch_length=epoch_length,
                rng_states=None,
            )
            engine.load_state_dict(resume_state_dict)
            torch.manual_seed(87)
            engine.run(resume_dataloader)

            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs