def test_dataloader_restarts():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, num_replicas
    adaptdl.collective.initialize("0.0.0.0")
    dataset_size = 100
    init_batch_size = 10
    dataset = TensorDataset(torch.rand(dataset_size))
    dataloader = AdaptiveDataLoader(dataset, batch_size=init_batch_size)
    # Load data samples in the following order:
    # 2 batches (20 samples) using 1 replica (local_bsz = 10, batch_size = 10)
    # 5 batches (60 samples) using 4 replicas (local_bsz = 3, batch_size = 12)
    # 2 batches (20 samples) using 2 replicas (local_bsz = 5, batch_size = 10)
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        if num_restarts() == 0 and idx == 2:
            adaptdl.checkpoint.save_all_states()
            return 4  # Restart with 4 replicas.
        if num_restarts() == 1 and idx == 5:
            adaptdl.checkpoint.save_all_states()
            return 2  # Restart with 2 replicas.
        assert current_dataloader() is dataloader._elastic
        local_bsz = batch[0].size(0)
        assert dataloader.current_local_bsz == local_bsz
        assert local_bsz == math.ceil(init_batch_size / num_replicas())
        assert dataloader.current_batch_size == num_replicas() * local_bsz
    # After the last 2 batches.
    assert idx == 1
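
# The scaling rule asserted above is plain ceil-division of the requested
# batch size across replicas; when the division is uneven, the effective
# global batch size grows accordingly. A standalone sketch of that rule
# (the helper name is ours, not part of adaptdl):
import math

def expected_batch_sizes(batch_size, num_replicas):
    """Return (local_bsz, global_bsz) under the ceil-division rule."""
    local_bsz = math.ceil(batch_size / num_replicas)
    return local_bsz, num_replicas * local_bsz

assert expected_batch_sizes(10, 1) == (10, 10)
assert expected_batch_sizes(10, 4) == (3, 12)  # 5 batches x 12 = 60 samples
assert expected_batch_sizes(10, 2) == (5, 10)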
def test_save_load():
    import pickle
    from adaptdl.checkpoint import (State, save_all_states, save_state,
                                    load_state)
    from adaptdl.env import replica_rank, num_restarts

    class TestState(State):
        def __init__(self, name):
            super().__init__(name)
            self.synced = False

        def sync(self):
            self.synced = True

        def save(self, fileobj):
            assert replica_rank() == 0  # Should only be called from rank 0.
            pickle.dump(self.value, fileobj)

        def load(self, fileobj):
            # Should load the correct value.
            self.value = pickle.load(fileobj)

    state_1 = TestState("state_1")
    state_2 = TestState("state_2")
    if num_restarts() == 0:
        # Save state with sync=True.
        state_1.value = 1
        save_state(state_1)
        assert state_1.synced
        return 2  # Restart with 2 replicas.
    elif num_restarts() == 1:
        load_state(state_1)
        assert state_1.value == 1
        # Save state with sync=False.
        state_2.value = 2
        save_state(state_2, sync=False)
        assert not state_2.synced
        return 2  # Restart with 2 replicas.
    elif num_restarts() == 2:
        load_state(state_1)
        assert state_1.value == 1
        load_state(state_2)
        assert state_2.value == 2
        return 2
    elif num_restarts() == 3:
        # Save all state.
        state_1.value = 10
        state_2.value = 20
        save_all_states()
        assert state_1.synced and state_2.synced
        return 2  # Restart with 2 replicas.
    elif num_restarts() == 4:
        load_state(state_1)
        load_state(state_2)
        assert state_1.value == 10
        assert state_2.value == 20
    else:
        assert False
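
# A minimal sketch (not part of the test suite) of how application code
# might use the same State API to persist a custom value across restarts.
# The subclass and variable names here are hypothetical; the save/load
# contract is the one exercised by test_save_load above.
import pickle
from adaptdl.checkpoint import State, save_state, load_state

class CounterState(State):
    def __init__(self, name):
        super().__init__(name)
        self.count = 0

    def save(self, fileobj):
        # Called with a writable binary file object (from rank 0 only).
        pickle.dump(self.count, fileobj)

    def load(self, fileobj):
        # Called with a readable binary file object.
        self.count = pickle.load(fileobj)

counter = CounterState("epoch_counter")
load_state(counter)  # Returns False if nothing was saved yet.
counter.count += 1
save_state(counter, sync=False)  # sync=True (the default) also calls sync().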
def test_profile(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit,
        _metrics_state)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_sync_time(1.0)
        profile_sync_time(2.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 1
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 3.0
        assert profile[key]["optim_step_time"] > 0.0
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 1
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 3.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0
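
# For reference, the profile key asserted above appears to be the tuple
# (num_nodes, num_replicas, local_bsz), and each entry accumulates counts
# and timings across commits. A sketch of one entry's shape, inferred only
# from the assertions in these tests (field values are illustrative):
example_profile_entry = {
    (1, 1, 2): {                 # (num_nodes, num_replicas, local_bsz)
        "accum_count": 0,        # commits with accumulation_step=True
        "optim_count": 1,        # commits with accumulation_step=False
        "optim_sync_time": 3.0,  # sum of values passed to profile_sync_time()
        "optim_step_time": 0.1,  # measured wall-clock step time (> 0)
    },
}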
def test_bptt_iterator():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    adaptdl.collective.initialize("0.0.0.0")
    # Load the iterator with 500 words:
    # 1 batch (5x10) using 1 replica. Restart after one iteration.
    # 9 batches (5x5) using 2 replicas to consume remaining batches.
    TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                                init_token='<sos>',
                                eos_token='<eos>')
    fields = [('text', TEXT)]
    examples = [torchtext.data.Example.fromlist([['The'] * 500], fields)]
    dataset = torchtext.data.Dataset(examples, fields)
    TEXT.build_vocab(dataset)
    bptt_iter = AdaptiveBPTTIterator(dataset, batch_size=10, bptt_len=5)
    for idx, batch in enumerate(bptt_iter):
        if num_restarts() == 0 and idx == 1:
            assert batch.text.shape == (5, 10)
            adaptdl.checkpoint.save_all_states()
            return 2
        if adaptdl.env.num_replicas() == 2:
            assert batch.text.shape == (5, 5) or batch.text.shape == (4, 5)
    if adaptdl.env.num_replicas() == 2:
        assert idx == 8
def test_duplicate():
    from adaptdl.env import num_restarts
    from adaptdl.checkpoint import State
    state_1 = State("state_1")  # noqa: F841
    state_2 = State("state_2")  # noqa: F841
    with pytest.raises(ValueError):
        state_dup = State("state_1")  # noqa: F841
    return [2, 0][num_restarts()]
def test_accumulator_restarts():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts, replica_rank
    adaptdl.collective.initialize("0.0.0.0")
    accum = Accumulator()
    if num_restarts() == 0:
        accum["a"] += 15  # Idempotent.
        assert "a" not in accum
        with accum.synchronized():
            assert "a" in accum
            assert accum["a"] == 15
        assert "a" not in accum
    if num_restarts() == 0:
        accum["a"] -= 5  # Idempotent.
        adaptdl.checkpoint.save_all_states()
        return 4  # Restart with 4 replicas.
    if num_restarts() == 1:
        # Idempotent.
        accum.update({"a": replica_rank(), "b": replica_rank()})
        assert len(accum) == 0
        with accum.synchronized():
            assert len(accum) == 2
            assert accum["a"] == 16
            assert accum["b"] == 6
        assert len(accum) == 0
    if num_restarts() == 1:
        adaptdl.checkpoint.save_all_states()
        return 2  # Restart with 2 replicas.
    if num_restarts() == 2:
        # Idempotent.
        accum -= {"b": 5, "c": 5}
        with accum.synchronized():
            assert accum["a"] == 16
            assert accum["b"] == -4
            assert accum["c"] == -10
        accum.clear()
        with accum.synchronized():
            assert not accum
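
# A minimal sketch (outside the tests) of the Accumulator usage pattern
# exercised above: writes made outside `synchronized()` are buffered
# locally and stay invisible, then are summed across all replicas and
# become readable inside the `synchronized()` context. Assumes Accumulator
# is importable from adaptdl.torch; the helper name is hypothetical.
from adaptdl.torch import Accumulator

def mean_validation_loss(model, loader, criterion):
    accum = Accumulator()
    for inputs, targets in loader:
        loss = criterion(model(inputs), targets)
        accum["loss_sum"] += loss.item()   # Buffered local update.
        accum["count"] += targets.size(0)  # Buffered local update.
    with accum.synchronized():             # All-reduced across replicas.
        return accum["loss_sum"] / accum["count"]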
def load_state(state):
    """
    Load the given `State` object from persistent storage. If the object
    was previously saved, then `State.load` will be invoked with a
    readable file object to load from.

    Arguments:
        state (State): `State` object to load from persistent storage.

    Returns:
        `True` if state was previously saved and `State.load` was invoked,
        `False` otherwise.
    """
    if from_ray():
        from ray.tune import session
        checkpoint_dir = session.get_session().get_checkpoint()
    else:
        checkpoint_dir = checkpoint_path()
    if checkpoint_dir is None:
        return False
    ckpt_dirs = os.listdir(checkpoint_dir)
    if not ckpt_dirs:
        LOG.info(f"No checkpoint found in {checkpoint_dir}.")
        return False
    latest_restart_id = 0
    for dir_name in ckpt_dirs:
        if dir_name.startswith(CKPT_DIR_PREFIX):
            restart_id = int(dir_name[len(CKPT_DIR_PREFIX):])
            latest_restart_id = max(latest_restart_id, restart_id)
    if latest_restart_id != num_restarts() - 1:
        LOG.warning("Cannot find checkpoint from the last restart. "
                    f"Loading checkpoint from restart {latest_restart_id}.")
    ckpt_dir = os.path.join(checkpoint_dir,
                            f"{CKPT_DIR_PREFIX}{latest_restart_id}")
    name = _STATES_TO_NAMES[state]
    state_file = os.path.join(ckpt_dir, name)
    if not os.path.isfile(state_file):
        LOG.warning(f"Cannot find state file {state_file}.")
        return False
    with open(state_file, "rb") as f:
        state.load(f)
    return True
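
# For illustration: the scan above implies a layout where each restart
# writes one subdirectory (named by restart id) under the checkpoint
# directory, containing one file per registered State name, and loading
# picks the highest id. A standalone sketch of that selection logic (the
# prefix value here is assumed, not the actual CKPT_DIR_PREFIX):
import os

def latest_checkpoint_dir(checkpoint_dir, prefix="checkpoint-"):
    """Return the checkpoint subdirectory with the highest restart id."""
    ids = [int(name[len(prefix):]) for name in os.listdir(checkpoint_dir)
           if name.startswith(prefix) and name[len(prefix):].isdigit()]
    return os.path.join(checkpoint_dir, f"{prefix}{max(ids)}") if ids else None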
def test_dataloader_break():
    import adaptdl.checkpoint
    import adaptdl.collective
    from adaptdl.env import num_restarts
    if num_restarts() == 0:
        return 2
    adaptdl.collective.initialize("0.0.0.0")
    dataset = TensorDataset(torch.rand(100))
    dataloader = AdaptiveDataLoader(dataset, batch_size=10)
    # Break in the middle of the first for-loop, and start another for-loop.
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
        if idx == 5:
            break
    assert current_dataloader() is None
    for idx, batch in enumerate(dataloader):
        assert current_dataloader() is dataloader._elastic
    assert idx == 9  # Run 10 batches total.
def test_profile_accumulation(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit,
        _metrics_state, _fit_perf_params)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_sync_time(4.0)
        profile_step_commit(accumulation_step=False)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_sync_time(6.0)
        profile_step_commit(accumulation_step=False)
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["accum_step_time"] > 0.0
        assert profile[key]["optim_step_time"] > 0.0
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        # Check that fitting parameters works even without a final
        # accumulation_step=False commit.
        for val in profile.values():
            # Ensure step time is at least sync time.
            val["optim_step_time"] += val["optim_sync_time"]
        _fit_perf_params()
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 3
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        if num_replicas == 1:
            assert len(profile) == 3
        else:
            assert len(profile) == 4
        # With a single replica, the key matches restart 0's local_bsz=3
        # entry, so its two accumulation commits carry over. Parenthesized
        # so the conditional applies to the expected value, not the assert.
        assert profile[key]["accum_count"] == (0 if num_replicas > 1 else 2)
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0