Code example #1
0
    def async_send_inner(self, batch: Batch, index: int) -> Tuple[Batch, PipeMessage]:
        """Run ``batch`` through the first local partition and package the output.

        A task is built with ``create_task_without_skip_trackers`` from
        ``self.checkpoint_stop``, ``index``, ``self.group.rank()``, ``batch`` and
        ``self.partitions[0].module``. The task is computed and then finalized
        with its own result. The result's tensors are then wrapped in a
        ``PipeMessage`` addressed to the rank that follows this process's rank
        in ``get_pipeline_parallel_ranks()``. The message targets the
        ``EVENT_LOOP_ACTIVATIONS_QUEUE`` queue.

        Nothing is transmitted here; the message is only returned.
        NOTE(review): presumably the caller sends it. Confirm at call sites.

        Args:
            batch: input passed to the task.
            index: forwarded to the task and used as the message body's index.

        Returns:
            ``(result, message)``: the value returned by ``task.compute()`` and
            the outgoing message carrying its tensors.

        NOTE(review): if this rank is the last entry of the pipeline rank
        list, the next-rank lookup raises IndexError. Presumably this is only
        called on non-final stages; confirm.
        """
        work = create_task_without_skip_trackers(
            self.checkpoint_stop,
            index,
            self.group.rank(),
            batch,
            self.partitions[0].module,
        )
        output = work.compute()
        work.finalize(output)

        pipeline_ranks = get_pipeline_parallel_ranks()
        my_rank = torch.distributed.get_rank()
        # The destination is whichever rank follows ours in pipeline order.
        next_rank = pipeline_ranks[pipeline_ranks.index(my_rank) + 1]

        # Positional fields: message type, index, source Location,
        # destination Location, and a trailing 0 whose meaning is defined by
        # AsyncMessageBody (not visible here).
        payload = AsyncMessageBody(
            AsyncMessageType.Activations,
            index,
            Location(my_rank, 0),
            Location(next_rank, 0),
            0,
        )
        outgoing = PipeMessage(
            my_rank,
            next_rank,
            queue_name=EVENT_LOOP_ACTIVATIONS_QUEUE,
            args=payload,
            # Materialise the result's iterable contents into a fresh tuple.
            tensors=tuple(list(output)),
        )
        return output, outgoing
Code example #2
0
File: test_pipe.py  Project: wns823/fairscale
    def check_partitions(model, balance, expected_order, expected_ranks):
        """Assert that partitioning ``model`` by ``balance`` yields the expected layout.

        Every rank in ``range(len(balance))`` instantiates its partitions
        through ``instantiate_partition`` with a ``FakeGroup`` and
        ``Pipe.AsyncSchedule``. Every resulting partition module must be an
        ``nn.Sequential``. All invocations from all ranks are then walked in
        ascending ``order``, and the following are checked:

        * orders are exactly ``0, 1, 2, ...``;
        * the Location chain is consistent: the first invocation has
          ``source is None``-equivalent (``== None``) and ``this == Location(0, 0)``;
          each later invocation's ``source`` equals the previous ``this`` and
          its ``this`` equals the previous ``dest``;
        * the modules each invocation runs are, by identity, the modules of
          ``model`` listed in ``expected_order``;
        * the stage of each invocation matches ``expected_ranks``.

        Args:
            model: a list of modules or an nn.Sequential.
            balance: the balance argument to Pipe.
            expected_order: for each invocation in execution order, the indices
                into ``model`` of the modules it runs (grouped by nn.Sequential).
            expected_ranks: the expected ``this.stage`` of each invocation, in
                execution order.
        """
        world_size = len(balance)
        all_invocations = []
        owner_of = {}  # Invocation -> the partition (wrapper) that holds it

        # Gather every Invocation produced on every rank, remembering which
        # instantiated partition each one belongs to.
        for rank in range(world_size):
            group = FakeGroup(rank, world_size)
            for partition in instantiate_partition(model, balance, group, Pipe.AsyncSchedule):
                assert isinstance(partition.module, nn.Sequential)
                for invocation in partition.invocations:
                    all_invocations.append(invocation)
                    owner_of[invocation] = partition

        executed_children = []
        previous_location = None
        expected_here = Location(0, 0)
        stages = []

        # Walk invocations in execution order and verify that each one's
        # Locations link onto its predecessor's.
        ordered = sorted(all_invocations, key=lambda x: x.order)
        for position, invocation in enumerate(ordered):
            assert invocation.order == position
            assert invocation.source == previous_location
            assert invocation.this == expected_here
            previous_location, expected_here = invocation.this, invocation.dest
            executed_children.append(list(owner_of[invocation].module.children()))
            stages.append(invocation.this.stage)

        # NOTE: the number of groups is not compared, so zip() stops at the
        # shorter of executed_children / expected_order and any surplus entries
        # on either side go unchecked.
        for actual, wanted in zip(executed_children, expected_order):
            assert len(actual) == len(wanted), f"{wanted}"
            assert [id(child) for child in actual] == [id(model[i]) for i in wanted], f"{wanted}"

        assert stages == expected_ranks