Exemplo n.º 1
0
def test_gather_tensors():
    """gather() concatenates two single-tensor batches along dim 0."""
    first = Batch(torch.zeros(1, 1))
    second = Batch(torch.zeros(1, 1))

    merged = gather([first, second])

    assert merged.size() == (2, 1)
Exemplo n.º 2
0
def profile_sequential_module(
    module: nn.Sequential,
    input: Union[Tensor, Sequence[Tensor]],
    chunks: int,
    param_scale: float,
    device: torch.device,
) -> List[int]:
    """Profile the CUDA memory footprint of every layer of a sequential module.

    Similar to the 'profile_sizes' function in torchgpipe, but instead of
    passing in a batch of size 1, it passes in the whole batch for a more
    accurate estimate of the sizes; moreover, negative memory deltas for
    latent variables are clamped to zero.

    reference: torchgpipe.balance.profile.profile_sizes

    :param module: pytorch sequential module to be profiled
    :type module: nn.Sequential
    :param input: input tensor or a sequence (will be cast to tuple) of
        tensors; every tensor is moved to ``device`` before profiling
    :type input: Union[Tensor, Sequence[Tensor]]
    :param chunks: number of chunks for a single batch specified in GPipe
    :type chunks: int
    :param param_scale: scaling factor for parameters (SGD: 2-3, Adam: 4-5,
        etc.); check the GPipe documentation for more details
    :type param_scale: float
    :param device: device for the size profiling run; must be a CUDA device
    :type device: torch.device
    :raises ValueError: if ``device`` is not a CUDA device
    :return: list of integers representing the sizes of all the layers in
        the sequential model in bytes
    :rtype: List[int]
    """
    if device.type != 'cuda':
        raise ValueError('size profiler supports only CUDA device')

    if isinstance(input, Tensor):
        # A single tensor must be moved to the profiling device as well;
        # the ``detach`` call in the loop below cuts it from any graph.
        _batch = Batch(input.to(device))
    else:
        # Cast a sequence of tensors into a tuple of detached device tensors.
        _batch = Batch(tuple(_i.detach().to(device) for _i in input))
    _layer_sizes_in_byte: List[int] = []

    for layer in layerwise_sandbox(module, device):
        detach(_batch)

        # Detect memory usage at forward. The delta can come out negative
        # when memory is released during the call, so clamp it at zero.
        _memory_before = torch.cuda.memory_allocated(device)
        _batch = _batch.call(layer)
        _memory_after = torch.cuda.memory_allocated(device)
        _latent_size = max(0, _memory_after - _memory_before)

        # Analyze size of parameters.
        param_size = sum(p.storage().size() * p.storage().element_size()
                         for p in layer.parameters())

        # Combine size of parameters and activations with normalize
        # scales: the whole batch was profiled, so split latents per chunk.
        _size = _latent_size / chunks + param_size * param_scale
        _layer_sizes_in_byte.append(int(_size))

    return _layer_sizes_in_byte
Exemplo n.º 3
0
    def forward(self,
                input: TensorOrTensors) -> TensorOrTensors:  # type: ignore
        """Performs the forward propagation. :class:`stash` or :class:`pop`
        commands will be handled by portals silently. The portals won't be
        exposed to users.

        Args:
            input: tensor or tensors fed to ``self.dispatch``.

        Returns:
            The output of ``self.dispatch`` after stashed skip tensors have
            been saved to the current skip tracker.

        Raises:
            RuntimeError:
                illegal 'stash' or 'pop' is found.

        """
        skip_tracker = current_skip_tracker()
        stashed_tensors: Dict[str, Optional[Tensor]] = {}

        # Load skip tensors that might be popped.
        poppable_tensors: Dict[str, Optional[Tensor]] = {}
        batch = Batch(input)
        for ns, name in self.poppable():
            try:
                poppable_tensors[name] = skip_tracker.load(batch, ns, name)
            except KeyError:
                # The internal lookup failure is not useful to the caller;
                # report only the user-facing problem.
                raise RuntimeError(f"'{name}' has not been stashed") from None
        input = batch.tensor_or_tensors

        # Handle skip commands.
        def handle_stash(name: str, tensor: Optional[Tensor]) -> None:
            if name not in self.stashable_names:
                raise RuntimeError(
                    f"'{name}' has not been declared as stashable")
            stashed_tensors[name] = tensor

        def handle_pop(name: str) -> Optional[Tensor]:
            if name not in self.poppable_names:
                raise RuntimeError(
                    f"'{name}' has not been declared as poppable")
            return poppable_tensors.pop(name)

        output = self.dispatch(input, handle_stash, handle_pop)

        # All declared skips must be stashed or popped. Names are sorted so
        # the error message does not depend on set iteration order.
        not_stashed = self.stashable_names - stashed_tensors.keys()
        if not_stashed:
            comma_names = ', '.join(f"'{n}'" for n in sorted(not_stashed))
            raise RuntimeError(f'{comma_names} must be stashed but have not')

        not_popped = poppable_tensors.keys()
        if not_popped:
            comma_names = ', '.join(f"'{n}'" for n in sorted(not_popped))
            raise RuntimeError(f'{comma_names} must be popped but have not')

        # Save stashed skip tensors.
        batch = Batch(output)
        for ns, name in self.stashable():
            tensor = stashed_tensors[name]
            skip_tracker.save(batch, ns, name, tensor)
        output = batch.tensor_or_tensors

        return output
Exemplo n.º 4
0
def test_gather_tuples():
    """gather() concatenates tuple batches element-wise along dim 0."""
    left = (torch.zeros(1, 1), torch.zeros(2, 2))
    right = (torch.zeros(1, 1), torch.zeros(2, 2))

    merged = gather([Batch(left), Batch(right)])

    assert isinstance(merged, tuple)
    assert merged[0].size() == (2, 1)
    assert merged[1].size() == (4, 2)
Exemplo n.º 5
0
def test_batch_setitem_by_slice():
    """Slice assignment replaces every tensor, keeping atomicity as-is."""
    single = Batch(torch.tensor(42))
    pair = Batch((torch.tensor(42), torch.tensor(21)))

    single[:] = (torch.tensor(0),)
    pair[:] = (torch.tensor(0),)

    # The atomic batch still holds exactly one tensor.
    assert single.atomic
    assert single[0].item() == 0

    # The non-atomic batch shrinks to the length of the assigned tuple.
    assert not pair.atomic
    assert len(pair) == 1
    assert pair[0].item() == 0
Exemplo n.º 6
0
def test_batch_setitem_by_index():
    """Index assignment replaces a single tensor, leaving the rest intact."""
    single = Batch(torch.tensor(42))
    pair = Batch((torch.tensor(42), torch.tensor(21)))

    single[0] = torch.tensor(0)
    pair[0] = torch.tensor(0)

    assert single.atomic
    assert single[0].item() == 0

    # Only the first slot changed; the second keeps its original value.
    assert not pair.atomic
    assert len(pair) == 2
    assert pair[0].item() == 0
    assert pair[1].item() == 21
Exemplo n.º 7
0
def test_forward_lockstep():
    """Partitions run micro-batches in a staggered, lockstep order."""
    timeline = []

    class DelayedLog(nn.Module):
        """Identity layer that sleeps, then logs (call index, partition id)."""

        def __init__(self, j, seconds):
            super().__init__()
            self.calls = 0
            self.partition_id = j
            self.seconds = seconds

        def forward(self, x):
            time.sleep(self.seconds)
            timeline.append((self.calls, self.partition_id))
            self.calls += 1
            return x

    batches = [Batch(torch.rand(1, 1)) for _ in range(3)]
    fast = nn.Sequential(DelayedLog(0, seconds=0))
    slow = nn.Sequential(DelayedLog(1, seconds=0.1))

    Pipeline(batches, [fast, slow]).run()

    # Expected timeline: (Logs are recorded at !)
    #
    # Partition #0: 0! 1!   2!
    # Partition #1:    000! 111! 222!
    #
    expected = [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (2, 1)]
    assert timeline == expected
Exemplo n.º 8
0
 def compute(
     batch: Batch = batch,
     partition: nn.Sequential = partition,
     skip_tracker: SkipTrackerThroughPotals = skip_trackers[i],
 ) -> Batch:
     """Run ``partition`` on ``batch`` with ``skip_tracker`` installed.

     The defaults capture the enclosing ``batch``, ``partition`` and
     ``skip_trackers[i]`` at definition time.
     """
     with use_skip_tracker(skip_tracker):
         result = batch.call(partition)
     return result
Exemplo n.º 9
0
    def checkpoint(self) -> Batch:
        """Wrap ``self.batch`` with :class:`Checkpoint` and return the result."""
        is_atomic = self.batch.atomic
        tensors = tuple(self.batch)

        # The phony is looked up by the device of the first input tensor.
        device = self.batch[0].device
        phony = Checkpointing.phonies[device]

        output = Checkpoint.apply(phony, self.recomputed, self.function,
                                  is_atomic, *tensors)
        return Batch(output)
Exemplo n.º 10
0
def checkpoint(function: Function, input: TensorOrTensors) -> TensorOrTensors:
    """Makes a checkpoint with a simple interface like
    :func:`torch.utils.checkpoint.checkpoint`. It's only used to test or debug
    :class:`Checkpoint` and :class:`Recompute` without boilerplate.

    The input is wrapped in a :class:`Batch`, checkpointed, recomputed, and
    then unwrapped back into a tensor or tuple of tensors.
    """
    checkpointing = Checkpointing(function, Batch(input))
    output = checkpointing.checkpoint()
    checkpointing.recompute(output)
    return output.tensor_or_tensors
Exemplo n.º 11
0
    def checkpoint(self) -> Batch:
        """Wrap ``self.batch`` with :class:`Checkpoint` and return the result."""
        is_atomic = self.batch.atomic
        tensors = tuple(self.batch)
        device = self.batch[0].device

        # A phony that requires grad keeps Checkpoint visible to the autograd
        # engine even when no input tensor requires grad.
        phony = get_phony(device, requires_grad=True)

        output = Checkpoint.apply(phony, self.recomputed, self.rng_states,
                                  self.function, is_atomic, *tensors)
        return Batch(output)
Exemplo n.º 12
0
def test_batch_atomic():
    """A batch built from one tensor is atomic and exposes ``.tensor``."""
    tensor = torch.tensor(42)
    batch = Batch(tensor)

    assert batch.atomic

    assert batch.tensor is tensor
    # An atomic batch has no ``.tensors`` attribute.
    with pytest.raises(AttributeError):
        batch.tensors

    assert list(batch) == [tensor]
    assert len(batch) == 1
    assert batch[0] is tensor
Exemplo n.º 13
0
def test_tensor_life_without_checkpointing():
    """Saving sets the portal's tensor life to 1; loading drops it to 0."""
    key = (None, 'test')
    layout = SkipLayout(num_partitions=2, skip_routes={key: (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    skipped = torch.tensor([2.0])

    tracker.save(batch, None, 'test', skipped)
    assert tracker.portals[key].tensor_life == 1

    tracker.load(batch, None, 'test')
    assert tracker.portals[key].tensor_life == 0
Exemplo n.º 14
0
def test_batch_non_atomic():
    """A batch built from a tuple is non-atomic and exposes ``.tensors``."""
    first, second = torch.tensor(42), torch.tensor(21)
    batch = Batch((first, second))

    assert not batch.atomic

    # A non-atomic batch has no ``.tensor`` attribute.
    with pytest.raises(AttributeError):
        batch.tensor
    assert batch.tensors == (first, second)

    assert list(batch) == [first, second]
    assert len(batch) == 2
    assert batch[0] is first
    assert batch[1] is second
Exemplo n.º 15
0
def test_reuse_portal():
    """Saving twice under the same key reuses the existing portal."""
    key = (None, 'test')
    layout = SkipLayout(num_partitions=2, skip_routes={key: (0, 1)})
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    first = torch.tensor([2.0])
    second = torch.tensor([2.0])

    tracker.save(batch, None, 'test', first)
    original_portal = tracker.portals[key]

    tracker.save(batch, None, 'test', second)
    assert original_portal is tracker.portals[key]
Exemplo n.º 16
0
def test_batch_call():
    """Batch.call preserves the atomicity of the batch it is called on."""
    single = Batch(torch.tensor(42))
    pair = Batch((torch.tensor(42), torch.tensor(21)))

    def identity(x):
        return x

    assert single.call(identity).atomic
    assert not pair.call(identity).atomic
Exemplo n.º 17
0
def test_not_requires_grad():
    """Checkpointing yields grad-requiring output even from a no-grad input."""
    batch = Batch(torch.rand(1, requires_grad=False))
    assert not batch[0].requires_grad

    def double(t):
        return t * 2

    checkpointing = Checkpointing(double, batch)
    batch = checkpointing.checkpoint()
    assert batch[0].requires_grad

    checkpointing.recompute(batch)
    assert batch[0].requires_grad

    batch.tensor.backward()
Exemplo n.º 18
0
def profile_sizes(
    module: nn.Sequential,
    input: TensorOrTensors,
    chunks: int,
    param_scale: float,
    device: torch.device,
) -> List[int]:
    """Profiles CUDA memory usage per layer.

    A single-sample slice of every input tensor is run through each layer
    yielded by ``layerwise_sandbox(module, device)``. The measured forward
    memory delta is scaled by ``batch_size / chunks``. The parameter size is
    scaled by ``param_scale``, and the two are summed.

    Args:
        module: sequential module to profile.
        input: sample tensor or tensors; dim 0 is read as the batch size.
        chunks: number of micro-batches a mini-batch is split into.
        param_scale: multiplier applied to the parameter size.
        device: profiling device; must be a CUDA device.

    Returns:
        Estimated size in bytes for each layer.

    Raises:
        ValueError: ``device`` is not a CUDA device.
    """
    if device.type != 'cuda':
        raise ValueError('size profiler supports only CUDA device')

    batch = Batch(input)
    sizes: List[int] = []

    # Profile with a single sample, then scale latents up to micro-batch size.
    latent_scale = batch[0].size(0) / chunks
    for i, x in enumerate(batch):
        batch[i] = x[:1].detach().to(device).requires_grad_(x.requires_grad)

    for layer in layerwise_sandbox(module, device):
        detach(batch)

        # Detect memory usage at forward. The delta can be negative when
        # memory is released during the call; clamp it so a layer never
        # reports a negative latent size.
        memory_before = torch.cuda.memory_allocated(device)
        batch = batch.call(layer)
        memory_after = torch.cuda.memory_allocated(device)
        latent_size = max(0, memory_after - memory_before)

        # Analyze size of parameters.
        param_size = sum(p.storage().size() * p.storage().element_size()
                         for p in layer.parameters())

        # Combine size of parameters and activations with normalize scales.
        size = latent_size * latent_scale + param_size * param_scale
        sizes.append(int(size))

    return sizes
Exemplo n.º 19
0
def test_no_copy_no_portal():
    """Only cross-partition skips get portals; same-partition skips stay local."""
    copy_key = (None, 'copy')
    local_key = (None, 'not_copy')
    layout = SkipLayout(num_partitions=2,
                        skip_routes={
                            copy_key: (0, 1),
                            local_key: (0, 0),
                        })
    tracker = SkipTrackerThroughPotals(layout)

    batch = Batch(torch.tensor([1.0]))
    crossing = torch.tensor([2.0])
    staying = torch.tensor([2.0])

    tracker.save(batch, None, 'copy', crossing)
    tracker.save(batch, None, 'not_copy', staying)

    assert copy_key in tracker.portals
    assert copy_key not in tracker.tensors
    assert local_key in tracker.tensors
    assert local_key not in tracker.portals
Exemplo n.º 20
0
def profile_times(
    module: nn.Sequential,
    sample: TensorOrTensors,
    timeout: float,
    device: torch.device,
) -> List[int]:
    """Profiles elapsed times per layer.

    Repeatedly runs a forward and a backward pass through each layer yielded
    by ``layerwise_sandbox(module, device)`` until ``timeout`` seconds have
    passed, accumulating the elapsed time of every layer.

    Args:
        module: sequential module to profile.
        sample: sample tensor or tensors fed to the first layer.
        timeout: total profiling budget in seconds.
        device: device on which the sample is placed and layers are run.

    Returns:
        Total elapsed time per layer, in microseconds.

    Raises:
        ValueError: some parameter of ``module`` already has a gradient.
    """
    if any(p.grad is not None for p in module.parameters()):
        raise ValueError('some parameter already has gradient')

    _batch = Batch(sample)
    for i, x in enumerate(_batch):
        _batch[i] = x.detach().to(device).requires_grad_(x.requires_grad)

    time_bufs: List[List[float]] = [[] for _ in module]
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # with system clock adjustments and distort both timeout and timings.
    begun_at = time.perf_counter()

    while time.perf_counter() - begun_at < timeout:
        batch = _batch

        for i, layer in enumerate(layerwise_sandbox(module, device)):
            detach(batch)

            # CUDA kernels run asynchronously; synchronize so the timer
            # brackets the actual work.
            if device.type == 'cuda':
                torch.cuda.synchronize(device)
            tick = time.perf_counter()

            # Forward
            batch = batch.call(layer)

            # Backward: each output requiring grad is used as its own
            # gradient.
            backward_tensors = tuple(y for y in batch if y.requires_grad)
            if backward_tensors:
                torch.autograd.backward(backward_tensors, backward_tensors)

            if device.type == 'cuda':
                torch.cuda.synchronize(device)
            tock = time.perf_counter()

            time_bufs[i].append(tock - tick)

    us = 1_000_000
    return [sum(int(t * us) for t in buf) for buf in time_bufs]
Exemplo n.º 21
0
 def compute(batch: Batch = batch,
             partition: nn.Sequential = partition) -> Batch:
     """Run ``partition`` on ``batch``.

     The defaults capture the enclosing ``batch`` and ``partition`` at
     definition time.
     """
     result = batch.call(partition)
     return result
Exemplo n.º 22
0
 def _42():
     """Return an atomic batch wrapping the scalar tensor 42."""
     value = torch.tensor(42)
     return Batch(value)
Exemplo n.º 23
0
 def log_thread_id():
     """Record the calling thread's ident in ``thread_ids``; return an empty batch."""
     thread_ids.add(threading.current_thread().ident)
     return Batch(())
Exemplo n.º 24
0
 def counter():
     nonlocal count
     time.sleep(0.1)
     count += 1
     return Batch(())
Exemplo n.º 25
0
 def detect_grad_enabled():
     """Return a batch whose tensor requires grad iff grad mode is enabled."""
     grad_enabled = torch.is_grad_enabled()
     return Batch(torch.rand(1, requires_grad=grad_enabled))