Beispiel #1
0
    def __init__(
        self,
        batches: List[Batch],
        partitions: List[nn.Sequential],
        devices: Optional[List[torch.device]] = None,
        copy_streams: Optional[List[List[AbstractStream]]] = None,
        skip_layout: Optional[SkipLayout] = None,
        checkpoint_stop: int = 0,
    ) -> None:
        """Store the pipeline inputs, filling in defaults for omitted ones.

        Args:
            batches: Stored unchanged on ``self.batches``.
            partitions: Stored unchanged on ``self.partitions``.
            devices: When omitted, one CPU device is created per partition.
            copy_streams: When omitted, a nested list is built with one inner
                list per device; each inner list repeats the result of a
                single ``current_stream(device)`` call ``len(batches)`` times.
            skip_layout: When omitted, it is obtained by calling
                ``inspect_skip_layout(partitions)``.
            checkpoint_stop: Stored unchanged on ``self.checkpoint_stop``.
        """
        self.batches = batches
        self.partitions = partitions

        # Fall back to CPU for every partition when no devices are given.
        if devices is None:
            devices = []
            for _ in partitions:
                devices.append(torch.device('cpu'))
        self.devices = devices

        # Default copy streams: the same stream object fills every batch slot
        # of a given device.
        if copy_streams is None:
            copy_streams = []
            for device in devices:
                stream = current_stream(device)
                copy_streams.append([stream] * len(batches))
        self.copy_streams = copy_streams

        self.skip_layout = (
            inspect_skip_layout(partitions)
            if skip_layout is None
            else skip_layout
        )
        self.checkpoint_stop = checkpoint_stop
Beispiel #2
0
def test_adjoining_partitions():
    """A value stashed in one partition and popped in the next is copied."""
    stasher = nn.Sequential(StashFoo())
    popper = nn.Sequential(PopFoo())
    layout = inspect_skip_layout([stasher, popper])

    # Materialize the copy policy of each partition, in partition order.
    policy = []
    for index in range(2):
        policy.append(list(layout.copy_policy(index)))

    # Only the second partition has an entry: 'foo' coming from partition 0
    # with no namespace.
    expected = [[], [(0, None, 'foo')]]
    assert policy == expected
Beispiel #3
0
def test_inner_partition():
    """A stash and pop inside the same partition yields no copy entries."""
    stash_and_pop = nn.Sequential(StashFoo(), PopFoo())
    passthrough = nn.Sequential(Pass())
    layout = inspect_skip_layout([stash_and_pop, passthrough])

    # Materialize the copy policy of each partition, in partition order.
    policy = []
    for index in range(2):
        policy.append(list(layout.copy_policy(index)))

    # Neither partition is expected to have any copy entry.
    expected = [[], []]
    assert policy == expected
Beispiel #4
0
def test_no_skippables():
    """Partitions made only of ``Pass`` modules yield no copy entries."""
    first = nn.Sequential(Pass())
    second = nn.Sequential(Pass())
    layout = inspect_skip_layout([first, second])

    # Materialize the copy policy of each partition, in partition order.
    policy = []
    for index in range(2):
        policy.append(list(layout.copy_policy(index)))

    # Neither partition is expected to have any copy entry.
    expected = [[], []]
    assert policy == expected
Beispiel #5
0
def test_pop_2_from_different_partitions():
    """One partition pops values stashed by two different earlier partitions."""
    foo_source = nn.Sequential(StashFoo())
    bar_source = nn.Sequential(StashBar())
    consumer = nn.Sequential(PopBar(), PopFoo())
    layout = inspect_skip_layout([foo_source, bar_source, consumer])

    # Materialize the copy policy of each partition, in partition order.
    policy = []
    for index in range(3):
        policy.append(list(layout.copy_policy(index)))

    # The consumer pops 'bar' before 'foo', but the plan is sorted by source
    # partition index, so 'foo' (from partition 0) is listed first.
    expected = [[], [], [(0, None, 'foo'), (1, None, 'bar')]]
    assert policy == expected
Beispiel #6
0
def test_namespace():
    """Same-named skips isolated in different namespaces stay distinct."""
    ns1 = Namespace()
    ns2 = Namespace()

    first_source = nn.Sequential(StashFoo().isolate(ns1))
    second_source = nn.Sequential(StashFoo().isolate(ns2))
    consumer = nn.Sequential(PopFoo().isolate(ns2), PopFoo().isolate(ns1))
    layout = inspect_skip_layout([first_source, second_source, consumer])

    # Materialize the copy policy of each partition, in partition order.
    policy = []
    for index in range(3):
        policy.append(list(layout.copy_policy(index)))

    # The consumer pops 'foo' of ns2 before 'foo' of ns1, but the plan is
    # sorted by source partition index, so the ns1 entry is listed first.
    expected = [[], [], [(0, ns1, 'foo'), (1, ns2, 'foo')]]
    assert policy == expected
Beispiel #7
0
    def __init__(
        self,
        module: nn.Sequential,
        balance: Optional[Iterable[int]] = None,
        *,
        devices: Optional[Devices] = None,
        chunks: int = chunks,
        checkpoint: str = checkpoint,
        deferred_batch_norm: bool = False,
    ) -> None:
        """Validate the arguments and split ``module`` into partitions.

        Args:
            module: Sequential module to split. It is checked with
                ``verify_module`` and ``verify_skippables`` before splitting.
            balance: Passed to ``split_module`` together with ``module`` and
                the devices. Although it defaults to ``None``, it is required.
            devices: Devices to use; each item is converted with
                ``torch.device``. When omitted, CUDA device indices
                ``0 .. torch.cuda.device_count() - 1`` are used.
            chunks: Coerced with ``int()``; must be positive. The default is
                taken from the name ``chunks`` in the enclosing scope.
            checkpoint: Coerced with ``str()``; one of ``'always'``,
                ``'except_last'`` or ``'never'``. The default is taken from
                the name ``checkpoint`` in the enclosing scope.
            deferred_batch_norm: When true, ``module`` is first converted with
                ``DeferredBatchNorm.convert_deferred_batch_norm(module, chunks)``.

        Raises:
            ValueError: If ``balance`` is ``None``, ``chunks`` is not
                positive, ``checkpoint`` is not an accepted value, or
                ``split_module`` raises ``BalanceError`` (which is attached as
                the cause).
        """
        super().__init__()

        chunks = int(chunks)
        checkpoint = str(checkpoint)

        if balance is None:
            raise ValueError(recommend_auto_balance('balance is required'))
        if chunks <= 0:
            raise ValueError('number of chunks must be positive integer')
        if checkpoint not in ('always', 'except_last', 'never'):
            raise ValueError(
                "checkpoint is not one of 'always', 'except_last', or 'never'")

        verify_module(module)

        # Verify if the underlying skippable modules satisfy integrity. The
        # integrity can be verified before forward() because it is static.
        verify_skippables(module)

        self.chunks = chunks
        self.checkpoint = checkpoint

        if deferred_batch_norm:
            module = DeferredBatchNorm.convert_deferred_batch_norm(
                module, chunks)

        if devices is None:
            # Default to every CUDA device index reported by torch.
            devices = range(torch.cuda.device_count())
        devices = [torch.device(d) for d in devices]
        devices = cast(List[torch.device], devices)

        try:
            self.partitions, self.balance, self.devices = split_module(
                module, balance, devices)
        except BalanceError as exc:
            # Chain explicitly so the traceback presents the BalanceError as
            # the direct cause of the ValueError shown to the caller.
            raise ValueError(recommend_auto_balance(str(exc))) from exc

        # Starts empty; NOTE(review): presumably filled in later by code
        # outside this method - confirm.
        self._copy_streams: List[List[AbstractStream]] = []
        self._skip_layout = inspect_skip_layout(self.partitions)