def _test(epoch_length=None):
    """Resume a ``DeterministicEngine`` from mid-run iterations over a fixed tensor.

    For each resume point ``2, 6, 10, ...`` below the total number of
    iterations, a fresh engine is loaded with a state dict carrying that
    iteration and is then run over the same data.  Every batch is validated by
    a ``BatchChecker`` whose counter starts at the resume point, and at the end
    of each epoch the engine iteration must equal the checker's counter.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to the number of samples in the data (21).
    """
    max_epochs = 5
    num_iters = 21

    # Seed before generating the data so the tensor is the same on every call.
    torch.manual_seed(0)
    dataset = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs
    resume_limit = min(num_iters * max_epochs, expected_iterations)

    for resume_iteration in range(2, resume_limit, 4):
        batch_checker = BatchChecker(dataset, init_counter=resume_iteration)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)

        @engine.on(Events.EPOCH_COMPLETED)
        def check_iteration(_):
            # The engine and the checker must have advanced in lockstep.
            assert engine.state.iteration == batch_checker.counter

        checkpoint = {
            "iteration": resume_iteration,
            "max_epochs": max_epochs,
            "epoch_length": epoch_length,
            "rng_states": None,
        }
        engine.load_state_dict(checkpoint)
        engine.run(dataset)

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == expected_iterations
def _test(epoch_length=None):
    """Resume a ``DeterministicEngine`` from each epoch boundary over a fixed tensor.

    For every ``resume_epoch`` in ``1 .. max_epochs - 1`` a fresh engine is
    loaded with a state dict carrying that epoch and is then run over the same
    data.  Batches are validated by a ``BatchChecker`` whose counter starts at
    ``resume_epoch * epoch_length``.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to the number of samples in the data (21).
    """
    max_epochs = 10
    num_iters = 21

    # Seed before generating the data so the tensor is the same on every call.
    torch.manual_seed(0)
    dataset = torch.randint(0, 1000, size=(num_iters,))

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs

    for resume_epoch in range(1, max_epochs):
        batch_checker = BatchChecker(dataset, init_counter=resume_epoch * epoch_length)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)

        checkpoint = {
            "epoch": resume_epoch,
            "max_epochs": max_epochs,
            "epoch_length": epoch_length,
            "rng_states": None,
        }
        engine.load_state_dict(checkpoint)
        engine.run(dataset)

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == expected_iterations
def _test(epoch_length=None):
    """Resume from mid-run iterations when the data is a random, endless iterator.

    For each resume point ``1, 8, 15, ...`` the test first performs a complete
    reference run that records every batch, then loads a second engine with a
    state dict carrying the resume iteration and checks that it receives the
    recorded batches from that position onwards.  Both runs are seeded with the
    same value right before ``run``.

    NOTE(review): ``device`` is not defined in this block; presumably it comes
    from the enclosing scope - confirm.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to ``num_iters`` (17).
    """
    max_epochs = 3
    batch_size = 4
    num_iters = 17

    def infinite_data_iterator():
        # Never raises StopIteration: random batches are produced forever.
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    resume_limit = min(num_iters * max_epochs, epoch_length * max_epochs)

    for resume_iteration in range(1, resume_limit, 7):
        # --- reference run: record every batch of a full run ---
        recorded_batches = []

        def record_batch(_, batch):
            recorded_batches.append(batch)

        reference_engine = DeterministicEngine(record_batch)

        torch.manual_seed(24)
        reference_engine.run(
            infinite_data_iterator(),
            max_epochs=max_epochs,
            epoch_length=epoch_length,
        )

        # --- resumed run: must see the recorded batches from the resume point ---
        batch_checker = BatchChecker(recorded_batches, init_counter=resume_iteration)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_iteration} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)

        checkpoint = {
            "iteration": resume_iteration,
            "max_epochs": max_epochs,
            "epoch_length": epoch_length,
            "rng_states": None,
        }
        engine.load_state_dict(checkpoint)

        torch.manual_seed(24)
        engine.run(infinite_data_iterator())

        assert engine.state.epoch == max_epochs
        assert (
            engine.state.iteration == epoch_length * max_epochs
        ), f"{resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
def test_run_finite_iterator_no_epoch_length():
    """Run a ``DeterministicEngine`` over a finite iterator without ``epoch_length``.

    The iterator yields ``unknown_size`` items and is replaced by a fresh one
    every time ``Events.DATALOADER_STOP_ITERATION`` fires.  After 5 epochs the
    engine must have processed ``unknown_size * 5`` iterations, and every batch
    must pass the ``BatchChecker``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871
    unknown_size = 11

    def finite_unk_size_data_iter():
        yield from range(unknown_size)

    bc = BatchChecker(data=list(range(unknown_size)))

    def update_fn(_, batch):
        # The check result used to be returned without being asserted, so a
        # wrong batch could never fail this test.  Assert it explicitly and
        # still return it so the engine output is unchanged.
        matched = bc.check(batch)
        assert matched, f"{bc.counter}: {bc.true_batch} vs {batch}"
        return matched

    engine = DeterministicEngine(update_fn)

    @engine.on(Events.DATALOADER_STOP_ITERATION)
    def restart_iter():
        # Swap in a fresh iterator once the current one is exhausted.
        engine.state.dataloader = finite_unk_size_data_iter()

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == unknown_size * 5
def test_run_finite_iterator_no_epoch_length_2():
    """Run an ``Engine`` over a finite iterator of known size without ``epoch_length``.

    A fresh iterator is assigned to ``engine.state.dataloader`` every
    ``known_size`` completed iterations.  After 5 epochs the engine must have
    processed ``known_size * 5`` iterations, and every batch must pass the
    ``BatchChecker``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871
    known_size = 11

    def finite_size_data_iter(size):
        yield from range(size)

    bc = BatchChecker(data=list(range(known_size)))

    def update_fn(_, batch):
        # The check result used to be returned without being asserted, so a
        # wrong batch could never fail this test.  Assert it explicitly and
        # still return it so the engine output is unchanged.
        matched = bc.check(batch)
        assert matched, f"{bc.counter}: {bc.true_batch} vs {batch}"
        return matched

    engine = Engine(update_fn)

    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        # The size is known, so the replacement is scheduled by iteration count.
        engine.state.dataloader = finite_size_data_iter(known_size)

    data_iter = finite_size_data_iter(known_size)
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == known_size * 5
def _test(epoch_length=None):
    """Resume from each epoch boundary when the data is a random, endless iterator.

    For every ``resume_epoch`` in ``1 .. max_epochs - 1`` the test performs a
    complete reference run that records every batch, then loads a second engine
    with a state dict carrying that epoch and checks that it receives the
    recorded batches starting at ``resume_epoch * epoch_length``.  Both runs
    are seeded with the same value right before ``run``.

    NOTE(review): ``device`` is not defined in this block; presumably it comes
    from the enclosing scope - confirm.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to ``num_iters`` (21).
    """
    max_epochs = 5
    batch_size = 4
    num_iters = 21

    def infinite_data_iterator():
        # Never raises StopIteration: random batches are produced forever.
        while True:
            for _ in range(num_iters):
                yield torch.randint(0, 1000, size=(batch_size,), device=device)

    if epoch_length is None:
        epoch_length = num_iters

    expected_iterations = epoch_length * max_epochs

    for resume_epoch in range(1, max_epochs):
        # --- reference run: record every batch of a full run ---
        recorded_batches = []

        def record_batch(_, batch):
            # NOTE: a random op performed here while consuming the batch
            # (e.g. torch.rand(1)) would prevent resuming correctly.
            recorded_batches.append(batch)

        reference_engine = DeterministicEngine(record_batch)

        torch.manual_seed(121)
        reference_engine.run(
            infinite_data_iterator(),
            max_epochs=max_epochs,
            epoch_length=epoch_length,
        )

        # --- resumed run: must see the recorded batches from the resume epoch ---
        batch_checker = BatchChecker(recorded_batches, init_counter=resume_epoch * epoch_length)

        def verify_batch(_, batch):
            assert batch_checker.check(
                batch
            ), f"{resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

        engine = DeterministicEngine(verify_batch)

        checkpoint = {
            "epoch": resume_epoch,
            "max_epochs": max_epochs,
            "epoch_length": epoch_length,
            "rng_states": None,
        }
        engine.load_state_dict(checkpoint)

        torch.manual_seed(121)
        engine.run(infinite_data_iterator())

        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == expected_iterations
def test_run_once_finite_iterator_no_epoch_length():
    """Run an ``Engine`` once over a finite iterator without ``epoch_length``.

    The run uses the default number of epochs.  Afterwards the engine must
    report one epoch and ``unknown_size`` iterations, the ``Events.COMPLETED``
    handler must have been called exactly once, and every batch must pass the
    ``BatchChecker``.
    """
    # FR: https://github.com/pytorch/ignite/issues/871
    unknown_size = 11

    def finite_unk_size_data_iter():
        yield from range(unknown_size)

    bc = BatchChecker(data=list(range(unknown_size)))

    def update_fn(_, batch):
        # The check result used to be returned without being asserted, so a
        # wrong batch could never fail this test.  Assert it explicitly and
        # still return it so the engine output is unchanged.
        matched = bc.check(batch)
        assert matched, f"{bc.counter}: {bc.true_batch} vs {batch}"
        return matched

    engine = Engine(update_fn)

    completed_handler = MagicMock()
    engine.add_event_handler(Events.COMPLETED, completed_handler)

    data_iter = finite_unk_size_data_iter()
    engine.run(data_iter)

    assert engine.state.epoch == 1
    assert engine.state.iteration == unknown_size
    assert completed_handler.call_count == 1
def _test_as_iter(data, max_epochs, num_iters):
    """Run an ``Engine`` over ``iter(data)`` and verify batches and final counters.

    Args:
        data: sized collection of batches; it is passed to the engine as an
            iterator and its ``len`` is used when ``num_iters`` is ``None``.
        max_epochs: number of epochs to run.
        num_iters: value passed as ``epoch_length``; ``None`` means the
            expected epoch length is ``len(data)``.
    """
    batch_checker = BatchChecker(data)

    def update_fn(_, batch):
        # f-string instead of str.format, matching the other helpers in this file.
        assert batch_checker.check(
            batch
        ), f"{batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

    engine = Engine(update_fn)
    engine.run(iter(data), max_epochs=max_epochs, epoch_length=num_iters)

    # Resolve the default only after the run so the engine receives the
    # caller's original (possibly None) epoch_length.
    if num_iters is None:
        num_iters = len(data)

    assert engine.state.iteration == num_iters * max_epochs
    assert engine.state.epoch == max_epochs
def _test(data, max_epochs, num_iters):
    """Run an ``Engine`` directly over ``data`` and verify batches and final counters.

    Args:
        data: sized collection of batches handed to ``Engine.run`` as is; its
            ``len`` is used when ``num_iters`` is ``None``.
        max_epochs: number of epochs to run.
        num_iters: value passed as ``epoch_length``; ``None`` means the
            expected epoch length is ``len(data)``.
    """
    batch_checker = BatchChecker(data)

    def verify_batch(_, batch):
        assert batch_checker.check(
            batch
        ), f"{batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

    engine = Engine(verify_batch)
    engine.run(data, max_epochs=max_epochs, epoch_length=num_iters)

    # The engine gets the caller's original (possibly None) epoch_length; the
    # default is resolved only for computing the expected totals.
    expected_epoch_length = len(data) if num_iters is None else num_iters

    assert engine.state.iteration == expected_epoch_length * max_epochs
    assert engine.state.epoch == max_epochs
def _test(epoch_length=None):
    """Resume from mid-run iterations when the data comes from a ``DataLoader``.

    For each resume point ``2, 15, 28, ...`` and for ``num_workers`` in
    ``[0, 2]`` the test performs a complete reference run that records every
    batch, then builds a second, identically configured loader, loads a new
    engine with a state dict carrying the resume iteration and checks that it
    receives the recorded batches from that position onwards.  Both runs are
    seeded with the same value right before ``run``.

    NOTE(review): ``device``, ``sampler_type`` and ``_setup_sampler`` are not
    defined in this block; presumably they come from the enclosing scope or
    module - confirm.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to ``num_iters`` (17).
    """
    max_epochs = 3
    total_batch_size = 4
    num_iters = 17

    # Seed before generating the data so the tensor is the same on every call.
    torch.manual_seed(0)
    dataset = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

    if epoch_length is None:
        epoch_length = num_iters

    def build_loader(num_workers):
        # Returns the sampler as well: the "distributed" handlers below call
        # set_epoch on it.
        new_sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
        loader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory="cuda" in torch.device(device).type,
            sampler=new_sampler,
            drop_last=True,
            shuffle=new_sampler is None,
        )
        return new_sampler, loader

    resume_limit = min(num_iters * max_epochs, epoch_length * max_epochs)

    for resume_iteration in range(2, resume_limit, 13):
        for num_workers in [0, 2]:
            # --- reference run: record every batch of a full run ---
            sampler, orig_dataloader = build_loader(num_workers)

            recorded_batches = []

            def record_batch(_, batch):
                # The moved tensor is not used; the original batch is recorded.
                batch.to(device)
                recorded_batches.append(batch)

            reference_engine = DeterministicEngine(record_batch)

            if sampler_type == "distributed":

                @reference_engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch)

            torch.manual_seed(12)
            reference_engine.run(orig_dataloader, max_epochs=max_epochs, epoch_length=epoch_length)

            # --- resumed run: must see the recorded batches from the resume point ---
            batch_checker = BatchChecker(recorded_batches, init_counter=resume_iteration)

            # `sampler` is rebound here; the handler registered below reads it
            # when it is called, so it sees this second sampler.
            sampler, resume_dataloader = build_loader(num_workers)

            def verify_batch(_, batch):
                # The moved tensor is not used; the original batch is checked.
                batch.to(device)
                cfg_msg = f"{num_workers} {resume_iteration}"
                assert batch_checker.check(
                    batch
                ), f"{cfg_msg} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(verify_batch)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch)

            checkpoint = {
                "iteration": resume_iteration,
                "max_epochs": max_epochs,
                "epoch_length": epoch_length,
                "rng_states": None,
            }
            engine.load_state_dict(checkpoint)

            torch.manual_seed(12)
            engine.run(resume_dataloader)

            assert engine.state.epoch == max_epochs
            assert (
                engine.state.iteration == epoch_length * max_epochs
            ), f"{num_workers}, {resume_iteration} | {engine.state.iteration} vs {epoch_length * max_epochs}"
def _test(epoch_length=None):
    """Resume from epoch boundaries when the data comes from a ``DataLoader``.

    For ``resume_epoch`` in ``1, 3`` and for ``num_workers`` in ``[0, 4]`` the
    test performs a complete reference run that records every batch, then
    builds a second, identically configured loader, loads a new engine with a
    state dict carrying the resume epoch and checks that it receives the
    recorded batches starting at ``resume_epoch * epoch_length``.  Both runs
    are seeded with the same value right before ``run``.

    NOTE(review): ``device``, ``sampler_type`` and ``_setup_sampler`` are not
    defined in this block; presumably they come from the enclosing scope or
    module - confirm.

    Args:
        epoch_length: number of iterations per epoch; when ``None`` it falls
            back to ``num_iters`` (21).
    """
    max_epochs = 5
    total_batch_size = 4
    num_iters = 21

    # Seed before generating the data so the tensor is the same on every call.
    torch.manual_seed(0)
    data = torch.randint(0, 1000, size=(num_iters * total_batch_size,))

    if epoch_length is None:
        epoch_length = num_iters

    def _build_loader(num_workers):
        """Create a sampler and a DataLoader for one run; return both."""
        new_sampler, batch_size = _setup_sampler(sampler_type, num_iters, total_batch_size)
        loader = DataLoader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            # Normalise through torch.device so this works for both a string
            # and a torch.device (`"cuda" in device` raises TypeError for the
            # latter); same expression as the resume-from-iteration helper.
            pin_memory="cuda" in torch.device(device).type,
            sampler=new_sampler,
            drop_last=True,
            shuffle=new_sampler is None,
        )
        return new_sampler, loader

    for resume_epoch in range(1, max_epochs, 2):
        for num_workers in [0, 4]:
            # --- reference run: record every batch of a full run ---
            sampler, orig_dataloader = _build_loader(num_workers)

            seen_batchs = []

            def update_fn(_, batch):
                # The moved tensor is not used; the original batch is recorded.
                batch.to(device)
                seen_batchs.append(batch)

            engine = DeterministicEngine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch - 1)

            torch.manual_seed(87)
            engine.run(
                orig_dataloader,
                max_epochs=max_epochs,
                epoch_length=epoch_length,
            )

            # --- resumed run: must see the recorded batches from the resume epoch ---
            batch_checker = BatchChecker(seen_batchs, init_counter=resume_epoch * epoch_length)

            # `sampler` is rebound here; the handler registered below reads it
            # when it is called, so it sees this second sampler.
            sampler, resume_dataloader = _build_loader(num_workers)

            def update_fn(_, batch):
                # The moved tensor is not used; the original batch is checked.
                batch.to(device)
                assert batch_checker.check(
                    batch
                ), f"{num_workers} {resume_epoch} | {batch_checker.counter}: {batch_checker.true_batch} vs {batch}"

            engine = DeterministicEngine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch - 1)

            resume_state_dict = dict(
                epoch=resume_epoch,
                max_epochs=max_epochs,
                epoch_length=epoch_length,
                rng_states=None,
            )
            engine.load_state_dict(resume_state_dict)

            torch.manual_seed(87)
            engine.run(resume_dataloader)

            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs