def test_gather_tensors():
    a = torch.zeros(1, 1)
    b = torch.zeros(1, 1)

    ab = gather([Batch(a), Batch(b)])

    assert ab.size() == (2, 1)
def profile_sequential_module(
    module: nn.Sequential,
    input: Union[Tensor, Sequence[Tensor]],
    chunks: int,
    param_scale: float,
    device: torch.device,
) -> List[int]:
    """Similar to the 'profile_sizes' function in torchgpipe, but instead of
    passing in a batch of size 1, it passes in the whole batch for a more
    accurate estimate of the sizes; moreover, it fixes the issue of negative
    memory allocation for latent variables.

    Reference: torchgpipe.balance.profile.profile_sizes

    :param module: PyTorch sequential module to be profiled
    :type module: nn.Sequential
    :param input: input tensor or a sequence (will be cast to a tuple) of tensors
    :type input: Union[Tensor, Sequence[Tensor]]
    :param chunks: number of chunks for a single batch, as specified in GPipe
    :type chunks: int
    :param param_scale: scaling factor for parameters (SGD: 2-3, Adam: 4-5,
        etc.); check the GPipe documentation for more details
    :type param_scale: float
    :param device: device for the size profiling run; must be a GPU
    :type device: torch.device

    :return: list of integers representing the sizes of all the layers in the
        sequential model, in bytes
    :rtype: List[int]
    """
    if device.type != 'cuda':
        raise ValueError('size profiler supports only CUDA device')

    # Cast everything in the batch into a tuple of tensors if the given
    # input is a sequence of tensors.
    _batch = Batch(input) if isinstance(input, Tensor) else \
        Batch(tuple([_i.detach().to(device) for _i in input]))

    _layer_sizes_in_byte: List[int] = []
    for layer in layerwise_sandbox(module, device):
        detach(_batch)

        # Detect memory usage at forward.
        _memory_before = torch.cuda.memory_allocated(device)
        _batch = _batch.call(layer)
        _memory_after = torch.cuda.memory_allocated(device)
        _latent_size = max(0, _memory_after - _memory_before)

        # Analyze size of parameters.
        param_size = sum(p.storage().size() * p.storage().element_size()
                         for p in layer.parameters())

        # Combine size of parameters and activations with normalized scales.
        _size = _latent_size / chunks + param_size * param_scale
        _layer_sizes_in_byte.append(int(_size))

    return _layer_sizes_in_byte
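# Usage sketch (not from the original source): how profile_sequential_module
# might be called to estimate per-layer sizes for a toy model. It assumes a
# CUDA device is available; the model, sample shape, chunks, and param_scale
# values below are hypothetical placeholders.
def _example_profile_sequential_module() -> None:
    import torch
    from torch import nn

    device = torch.device('cuda')
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
    sample = torch.rand(128, 32, device=device)  # a whole batch, not a single example

    sizes = profile_sequential_module(
        model, sample, chunks=4, param_scale=4.0, device=device)
    # One entry per layer, in bytes; larger values suggest heavier partitions.
    print(sizes)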
def forward(self, input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
    """Performs the forward propagation. :class:`stash` or :class:`pop`
    commands will be handled by portals silently. The portals won't be
    exposed to users.

    Raises:
        RuntimeError: illegal 'stash' or 'pop' is found.

    """
    skip_tracker = current_skip_tracker()
    stashed_tensors: Dict[str, Optional[Tensor]] = {}

    # Load skip tensors that might be popped.
    poppable_tensors = {}
    batch = Batch(input)
    for ns, name in self.poppable():
        try:
            poppable_tensors[name] = skip_tracker.load(batch, ns, name)
        except KeyError:
            raise RuntimeError(f"'{name}' has not been stashed")
    input = batch.tensor_or_tensors

    # Handle skip commands.
    def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
        if name not in self.stashable_names:
            raise RuntimeError(f"'{name}' has not been declared as stashable")
        stashed_tensors[name] = tensor

    def handle_pop(name: str) -> Optional[Tensor]:
        if name not in self.poppable_names:
            raise RuntimeError(f"'{name}' has not been declared as poppable")
        return poppable_tensors.pop(name)

    output = self.dispatch(input, handle_stash, handle_pop)

    # All declared skips must be stashed or popped.
    not_stashed = self.stashable_names - stashed_tensors.keys()
    if not_stashed:
        comma_names = ', '.join("'%s'" % n for n in not_stashed)
        raise RuntimeError(f'{comma_names} must be stashed but have not')

    not_popped = poppable_tensors.keys()
    if not_popped:
        comma_names = ', '.join("'%s'" % n for n in not_popped)
        raise RuntimeError(f'{comma_names} must be popped but have not')

    # Save stashed skip tensors.
    batch = Batch(output)
    for ns, name in self.stashable():
        tensor = stashed_tensors[name]
        skip_tracker.save(batch, ns, name, tensor)
    output = batch.tensor_or_tensors

    return output
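# Usage sketch (an assumption, not confirmed by this snippet): in
# torchgpipe-style code, stash/pop commands are yielded from modules wrapped
# by the @skippable decorator, and the forward() above resolves them through
# handle_stash/handle_pop. The layer names and skip name below are hypothetical.
def _example_skippable_modules() -> None:
    import torch
    from torch import nn
    from torchgpipe.skip import pop, skippable, stash

    @skippable(stash=['skip'])
    class Stash(nn.Module):
        def forward(self, x):
            yield stash('skip', x)    # routed to handle_stash above
            return x * 2

    @skippable(pop=['skip'])
    class Pop(nn.Module):
        def forward(self, x):
            skip = yield pop('skip')  # routed to handle_pop above
            return x + skip

    model = nn.Sequential(Stash(), Pop())
    out = model(torch.rand(1, 4))
    print(out.shape)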
def test_gather_tuples():
    a = (torch.zeros(1, 1), torch.zeros(2, 2))
    b = (torch.zeros(1, 1), torch.zeros(2, 2))

    ab = gather([Batch(a), Batch(b)])

    assert isinstance(ab, tuple)
    assert ab[0].size() == (2, 1)
    assert ab[1].size() == (4, 2)
def test_batch_setitem_by_slice():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    a[:] = (torch.tensor(0),)
    b[:] = (torch.tensor(0),)

    assert a.atomic
    assert a[0].item() == 0

    assert not b.atomic
    assert len(b) == 1
    assert b[0].item() == 0
def test_batch_setitem_by_index():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    a[0] = torch.tensor(0)
    b[0] = torch.tensor(0)

    assert a.atomic
    assert a[0].item() == 0

    assert not b.atomic
    assert len(b) == 2
    assert b[0].item() == 0
    assert b[1].item() == 21
def test_forward_lockstep():
    timeline = []

    class DelayedLog(nn.Module):
        def __init__(self, j, seconds):
            super().__init__()
            self.i = 0
            self.j = j
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)

            timeline.append((self.i, self.j))
            self.i += 1

            return x

    batches = [Batch(torch.rand(1, 1)) for _ in range(3)]
    partitions = [nn.Sequential(DelayedLog(0, seconds=0)),
                  nn.Sequential(DelayedLog(1, seconds=0.1))]

    pipeline = Pipeline(batches, partitions)
    pipeline.run()

    # Expected timeline: (Logs are recorded at !)
    #
    # Partition #0: 0! 1! 2!
    # Partition #1:    000! 111! 222!
    #
    assert timeline == [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
def compute(batch: Batch = batch,
            partition: nn.Sequential = partition,
            skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
            ) -> Batch:
    with use_skip_tracker(skip_tracker):
        return batch.call(partition)
def checkpoint(self) -> Batch:
    """Returns a batch applied by :class:`Checkpoint`."""
    input_atomic = self.batch.atomic
    input = tuple(self.batch)

    phony = Checkpointing.phonies[self.batch[0].device]

    output = Checkpoint.apply(phony, self.recomputed, self.function,
                              input_atomic, *input)
    return Batch(output)
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.
    """
    batch = Batch(input)

    chk = Checkpointing(function, batch)
    batch = chk.checkpoint()
    chk.recompute(batch)

    return batch.tensor_or_tensors
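# Usage sketch (not from the original source): since the checkpoint helper
# above mirrors torch.utils.checkpoint.checkpoint, a simple doubling function
# can be checkpointed and backpropagated through directly. The lambda and
# tensor shape are hypothetical.
def _example_checkpoint() -> None:
    import torch

    x = torch.rand(4, requires_grad=True)
    y = checkpoint(lambda t: t * 2, x)  # forward is recomputed during backward
    y.sum().backward()
    print(x.grad)  # expected to be all 2s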
def checkpoint(self) -> Batch:
    """Returns a batch applied by :class:`Checkpoint`."""
    input_atomic = self.batch.atomic
    input = tuple(self.batch)

    # Use a phony which requires grad to ensure that Checkpoint can be
    # tracked by the autograd engine even when none of the input tensors
    # require grad.
    phony = get_phony(self.batch[0].device, requires_grad=True)

    output = Checkpoint.apply(phony, self.recomputed, self.rng_states,
                              self.function, input_atomic, *input)
    return Batch(output)
def test_batch_atomic():
    x = torch.tensor(42)
    b = Batch(x)

    assert b.atomic

    assert b.tensor is x
    with pytest.raises(AttributeError):
        b.tensors

    assert list(b) == [x]
    assert len(b) == 1
    assert b[0] is x
def test_tensor_life_without_checkpointing():
    skip_layout = SkipLayout(num_partitions=2,
                             skip_routes={(None, 'test'): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    tensor = torch.tensor([2.0])

    skip_tracker.save(batch, None, 'test', tensor)
    assert skip_tracker.portals[(None, 'test')].tensor_life == 1

    skip_tracker.load(batch, None, 'test')
    assert skip_tracker.portals[(None, 'test')].tensor_life == 0
def test_batch_non_atomic():
    x, y = torch.tensor(42), torch.tensor(21)
    b = Batch((x, y))

    assert not b.atomic

    with pytest.raises(AttributeError):
        b.tensor
    assert b.tensors == (x, y)

    assert list(b) == [x, y]
    assert len(b) == 2
    assert b[0] is x
    assert b[1] is y
def test_reuse_portal():
    skip_layout = SkipLayout(num_partitions=2,
                             skip_routes={(None, 'test'): (0, 1)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, 'test', a)
    portal = skip_tracker.portals[(None, 'test')]

    skip_tracker.save(batch, None, 'test', b)
    assert portal is skip_tracker.portals[(None, 'test')]
def test_batch_call():
    a = Batch(torch.tensor(42))
    b = Batch((torch.tensor(42), torch.tensor(21)))

    def f(x):
        return x

    assert a.call(f).atomic
    assert not b.call(f).atomic
def test_not_requires_grad():
    x = Batch(torch.rand(1, requires_grad=False))
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()
def profile_sizes(
    module: nn.Sequential,
    input: TensorOrTensors,
    chunks: int,
    param_scale: float,
    device: torch.device,
) -> List[int]:
    """Profiles CUDA memory usage per layer."""
    if device.type != 'cuda':
        raise ValueError('size profiler supports only CUDA device')

    batch = Batch(input)
    sizes: List[int] = []

    latent_scale = batch[0].size(0) / chunks
    for i, x in enumerate(batch):
        batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)

    for layer in layerwise_sandbox(module, device):
        detach(batch)

        # Detect memory usage at forward.
        memory_before = torch.cuda.memory_allocated(device)
        batch = batch.call(layer)
        memory_after = torch.cuda.memory_allocated(device)
        latent_size = memory_after - memory_before

        # Analyze size of parameters.
        param_size = sum(p.storage().size() * p.storage().element_size()
                         for p in layer.parameters())

        # Combine size of parameters and activations with normalized scales.
        size = latent_size * latent_scale + param_size * param_scale
        sizes.append(int(size))

    return sizes
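# Usage sketch (not from the original source): profile_sizes keeps only one
# example per tensor (x[:1]) and rescales activations by batch_size / chunks,
# so the sample below stands in for a full training batch. It assumes a CUDA
# device is available; the model and hyperparameters are hypothetical.
def _example_profile_sizes() -> None:
    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    sample = torch.rand(64, 16)

    sizes = profile_sizes(model, sample, chunks=8, param_scale=4.0,
                          device=torch.device('cuda'))
    print(sizes)  # estimated bytes per layer, used for partition balancing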
def test_no_copy_no_portal():
    skip_layout = SkipLayout(num_partitions=2,
                             skip_routes={(None, 'copy'): (0, 1),
                                          (None, 'not_copy'): (0, 0)})
    skip_tracker = SkipTrackerThroughPotals(skip_layout)

    batch = Batch(torch.tensor([1.0]))
    a = torch.tensor([2.0])
    b = torch.tensor([2.0])

    skip_tracker.save(batch, None, 'copy', a)
    skip_tracker.save(batch, None, 'not_copy', b)

    assert (None, 'copy') in skip_tracker.portals
    assert (None, 'copy') not in skip_tracker.tensors
    assert (None, 'not_copy') in skip_tracker.tensors
    assert (None, 'not_copy') not in skip_tracker.portals
def profile_times(
    module: nn.Sequential,
    sample: TensorOrTensors,
    timeout: float,
    device: torch.device,
) -> List[int]:
    """Profiles elapsed times per layer."""
    if any(p.grad is not None for p in module.parameters()):
        raise ValueError('some parameter already has gradient')

    _batch = Batch(sample)
    for i, x in enumerate(_batch):
        _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)

    time_bufs: List[List[float]] = [[] for _ in module]
    begun_at = time.time()

    while time.time() - begun_at < timeout:
        batch = _batch

        for i, layer in enumerate(layerwise_sandbox(module, device)):
            detach(batch)

            if device.type == 'cuda':
                torch.cuda.synchronize(device)
            tick = time.time()

            # Forward
            batch = batch.call(layer)

            # Backward
            backward_tensors = tuple(y for y in batch if y.requires_grad)
            if backward_tensors:
                torch.autograd.backward(backward_tensors, backward_tensors)

            if device.type == 'cuda':
                torch.cuda.synchronize(device)
            tock = time.time()

            time_bufs[i].append(tock - tick)

    us = 1_000_000
    return [sum(int(t * us) for t in buf) for buf in time_bufs]
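# Usage sketch (not from the original source): profile_times repeatedly runs a
# forward and backward pass per layer until the timeout and returns cumulative
# elapsed microseconds per layer; CUDA synchronization is conditional, so a CPU
# device also works. The model, sample, and timeout below are hypothetical.
def _example_profile_times() -> None:
    import torch
    from torch import nn

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    sample = torch.rand(8, 16)

    times = profile_times(model, sample, timeout=0.5,
                          device=torch.device('cpu'))
    print(times)  # one cumulative time (in microseconds) per layer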
def compute(batch: Batch = batch,
            partition: nn.Sequential = partition) -> Batch:
    return batch.call(partition)
def _42():
    return Batch(torch.tensor(42))
def log_thread_id():
    thread_id = threading.current_thread().ident
    thread_ids.add(thread_id)
    return Batch(())
def counter():
    nonlocal count
    time.sleep(0.1)
    count += 1
    return Batch(())
def detect_grad_enabled():
    x = torch.rand(1, requires_grad=torch.is_grad_enabled())
    return Batch(x)