Example #1
    def test_split_tensor(self):
        t = torch.randn(32, 5)
        # an even split: 32 rows in chunks of 8 gives four (8, 5) pieces
        for st in PipelineHelper.split(t, 8):
            assert st.shape == (8, 5)
        # an uneven split: the final chunk holds the remainder
        a, b = PipelineHelper.split(t, 17)
        assert a.shape == (17, 5)
        assert b.shape == (15, 5)
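For the plain-tensor case, split presumably reduces to torch.split along dim 0 (an assumption about this one case, not a description of ParlAI's implementation); this standalone sketch reproduces the shapes asserted above:

import torch

t = torch.randn(32, 5)

# even split: four (8, 5) chunks
assert [tuple(st.shape) for st in torch.split(t, 8)] == [(8, 5)] * 4

# uneven split: the last chunk carries the remainder
a, b = torch.split(t, 17)
assert tuple(a.shape) == (17, 5) and tuple(b.shape) == (15, 5)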
Example #2
    def test_split_badcases(self):
        # test some cases that cause infinite loops if we don't catch them
        t = torch.randn(32, 5)
        with self.assertRaises(ValueError):
            PipelineHelper.split((t, {'x': {}}), 8)
        with self.assertRaises(ValueError):
            PipelineHelper.split((t, {'y': ()}), 8)
Example #3
    def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state):
        """
        Pipeline application of model parallelism.
        """
        chunks = PipelineHelper.split(
            (tensor, encoder_output, encoder_mask, incr_state)
        )
        work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

        new_incr_state = {i: [] for i, _ in enumerate(self.layers)}

        for chunk_idx, layer_nos, next_device in work_items:
            s_tensor, s_enc_out, s_enc_mask, s_incr_state = chunks[chunk_idx]
            for layer_no in layer_nos:
                s_tensor, nis = self.layers[layer_no](
                    x=s_tensor,
                    encoder_output=s_enc_out,
                    encoder_mask=s_enc_mask,
                    incr_state=s_incr_state.get(layer_no),
                )
                new_incr_state[layer_no].append(nis)
            # don't move incr state, it's always on the correct device
            s_tensor, s_enc_out, s_enc_mask = PipelineHelper.chunk_to(
                (s_tensor, s_enc_out, s_enc_mask), next_device
            )
            chunks[chunk_idx] = (s_tensor, s_enc_out, s_enc_mask, s_incr_state)

        tensor_out = PipelineHelper.join([c[0] for c in chunks])
        new_incr_state = {
            layer_no: PipelineHelper.join(pieces)
            for layer_no, pieces in new_incr_state.items()
        }

        return tensor_out, new_incr_state
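PipelineHelper.chunk_to moves every tensor in a chunk onto the next device group, while the incremental state is left alone because each layer produces it on that layer's own device. A minimal sketch of this kind of recursive device move, using a hypothetical move_to helper rather than ParlAI's actual chunk_to:

import torch

def move_to(item, device):
    # recursively move tensors nested in tuples/dicts onto the target device
    if torch.is_tensor(item):
        return item.to(device)
    if isinstance(item, tuple):
        return tuple(move_to(x, device) for x in item)
    if isinstance(item, dict):
        return {k: move_to(v, device) for k, v in item.items()}
    raise ValueError(f"cannot move items of type {type(item)}")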
Example #4
    def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state):
        """
        Pipeline application of model parallelism.
        """
        chunks = PipelineHelper.split(
            (tensor, encoder_output, encoder_mask, incr_state)
        )
        work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

        new_incr_state = [{} for _ in chunks]

        for chunk_idx, layer_nos, next_device in work_items:
            s_tensor, s_enc_out, s_enc_mask, s_incr_state = chunks[chunk_idx]
            for layer_no in layer_nos:
                s_tensor, new_incr_state[chunk_idx][layer_no] = self.layers[layer_no](
                    x=s_tensor,
                    encoder_output=s_enc_out,
                    encoder_mask=s_enc_mask,
                    incr_state=s_incr_state.get(layer_no),
                )
            chunks[chunk_idx] = PipelineHelper.chunk_to(
                (s_tensor, s_enc_out, s_enc_mask, s_incr_state), next_device
            )

        tensor_out = PipelineHelper.join([c[0] for c in chunks])
        new_incr_state = PipelineHelper.join(new_incr_state)

        return tensor_out, new_incr_state
Example #5
    def test_schedule_work_items(self):
        # test that we schedule things correctly
        # pretend we have 8 layers and 4 gpus, and they are unevenly distributed
        model = torch.nn.ModuleList()
        for i in range(8):
            layer = IdentityLayer()
            if i == 0:
                layer._mp_gpu = 'cuda:0'
            elif i in (1, 2, 3):
                layer._mp_gpu = 'cuda:1'
            elif i in (4, 5):
                layer._mp_gpu = 'cuda:2'
            elif i in (6, 7):
                layer._mp_gpu = 'cuda:3'
            model.append(layer)

        # there are 2 chunks, each 16 x 7 in size
        chunks = PipelineHelper.split(torch.randn(32, 7), 16)

        work_items = list(PipelineHelper.schedule_work_items(model, chunks))
        assert len(work_items) == 8
        assert work_items[0].layer_nos == [0] and work_items[0].chunk_idx == 0
        assert work_items[1].layer_nos == [1, 2, 3] and work_items[1].chunk_idx == 0
        assert work_items[2].layer_nos == [0] and work_items[2].chunk_idx == 1
        assert work_items[3].layer_nos == [4, 5] and work_items[3].chunk_idx == 0
        assert work_items[4].layer_nos == [1, 2, 3] and work_items[4].chunk_idx == 1
        assert work_items[5].layer_nos == [6, 7] and work_items[5].chunk_idx == 0
        assert work_items[6].layer_nos == [4, 5] and work_items[6].chunk_idx == 1
        assert work_items[7].layer_nos == [6, 7] and work_items[7].chunk_idx == 1
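The asserted order is the classic pipeline-parallel staircase: chunk 0 advances to the next device group as soon as it finishes one, freeing that group for chunk 1. A standalone sketch of the same interleaving, with a hypothetical pipeline_order helper that is independent of ParlAI's scheduler:

def pipeline_order(device_groups, n_chunks):
    # chunk c reaches device group g at step c + g; sorting by step
    # reproduces the staircase order asserted in the test above
    steps = [
        (chunk + g, chunk, g)
        for chunk in range(n_chunks)
        for g in range(len(device_groups))
    ]
    for _, chunk, g in sorted(steps):
        yield chunk, device_groups[g]

# the four uneven device groups from the test
groups = [[0], [1, 2, 3], [4, 5], [6, 7]]
order = list(pipeline_order(groups, n_chunks=2))
assert order[:4] == [(0, [0]), (0, [1, 2, 3]), (1, [0]), (0, [4, 5])]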
Example #6
    def test_split_tuple(self):
        t = torch.randn(32, 5)
        tup = (t, t, t)
        for stup in PipelineHelper.split(tup, 8):
            assert isinstance(stup, tuple)
            assert len(stup) == 3
            for i in range(3):
                assert stup[i].shape == (8, 5)
Example #7
    def test_split_dict(self):
        t = torch.randn(32, 5)
        d = {'x': t, 'y': t}
        for sd in PipelineHelper.split(d, 8):
            assert isinstance(sd, dict)
            assert 'x' in sd
            assert 'y' in sd
            assert sd['x'].shape == (8, 5)
            assert sd['y'].shape == (8, 5)
Example #8
    def _apply_model_parallel_with_extra(
        self,
        tensor,
        encoder_output,
        encoder_mask,
        incr_state,
        extra_output: torch.Tensor = None,
        extra_mask: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Copied from TransformerDecoder._apply_model_parallel, with the extra
        output/extra mask incorporated.
        """
        chunks = PipelineHelper.split(
            (tensor, encoder_output, encoder_mask, incr_state, extra_output, extra_mask)
        )
        work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

        new_incr_state = {i: [] for i, _ in enumerate(self.layers)}

        for chunk_idx, layer_nos, next_device in work_items:
            s_tensor, s_enc_out, s_enc_mask, s_incr_state, s_extra_out, s_extra_mask = chunks[
                chunk_idx
            ]
            for layer_no in layer_nos:
                s_tensor, nis = self.layers[layer_no](
                    x=s_tensor,
                    encoder_output=s_enc_out,
                    encoder_mask=s_enc_mask,
                    incr_state=s_incr_state.get(layer_no),
                    extra_output=s_extra_out,
                    extra_mask=s_extra_mask,
                )
                new_incr_state[layer_no].append(nis)
            # don't move incr state, it's always on the correct device
            s_tensor, s_enc_out, s_enc_mask, s_extra_out, s_extra_mask = PipelineHelper.chunk_to(
                (s_tensor, s_enc_out, s_enc_mask, s_extra_out, s_extra_mask),
                next_device,
            )
            chunks[chunk_idx] = (
                s_tensor,
                s_enc_out,
                s_enc_mask,
                s_incr_state,
                s_extra_out,
                s_extra_mask,
            )

        tensor_out = PipelineHelper.join([c[0] for c in chunks])
        new_incr_state = {
            layer_no: PipelineHelper.join(pieces)
            for layer_no, pieces in new_incr_state.items()
        }

        return tensor_out, new_incr_state  # type: ignore
Example #9
    def test_split_complex(self):
        t = torch.randn(32, 5)
        item = (t, {'x': t, 'y': t})
        for sitem in PipelineHelper.split(item, 8):
            assert isinstance(sitem, tuple)
            assert len(sitem) == 2
            left, right = sitem
            assert isinstance(left, torch.Tensor)
            assert left.shape == (8, 5)
            assert isinstance(right, dict)
            assert 'x' in right
            assert 'y' in right
            assert right['x'].shape == (8, 5)
            assert right['y'].shape == (8, 5)
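Taken together, the split tests pin down a recursive contract: tensors are chunked along dim 0, while tuples and dicts are split element-wise and reassembled per chunk. A simplified sketch consistent with these tests, not ParlAI's actual implementation (which also special-cases the empty containers covered by test_split_badcases and test_split_emptydict):

import torch

def split(item, chunk_size):
    if torch.is_tensor(item):
        return list(torch.split(item, chunk_size))
    if isinstance(item, tuple):
        # split each element, then zip the pieces into per-chunk tuples
        return [tuple(c) for c in zip(*(split(x, chunk_size) for x in item))]
    if isinstance(item, dict):
        # an empty dict would break this sketch: there is no value from which
        # to infer the number of chunks, hence the special cases in the tests
        keys = list(item)
        pieces = zip(*(split(item[k], chunk_size) for k in keys))
        return [dict(zip(keys, p)) for p in pieces]
    raise ValueError(f"cannot split items of type {type(item)}")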
Example #10
    def _apply_model_parallel(self, tensor, mask):
        """
        Pipeline application of model parallelism.
        """
        chunks = PipelineHelper.split((tensor, mask))
        work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

        for chunk_idx, layer_nos, next_device in work_items:
            s_tensor, s_mask = chunks[chunk_idx]
            for layer_no in layer_nos:
                s_tensor = self.layers[layer_no](s_tensor, s_mask)
            chunks[chunk_idx] = PipelineHelper.chunk_to((s_tensor, s_mask), next_device)

        tensor_out, mask_out = PipelineHelper.join(chunks)
        # the mask is not modified by the layers, so only the tensor is returned
        return tensor_out
Example #11
    def test_split_emptydict(self):
        # test a horrible edge case where d is an empty dict, and we need to
        # return a BUNCH of empty dicts
        t = torch.randn(32, 5)
        d = {}
        tup = (t, d)
        items = PipelineHelper.split(tup, 8)
        assert len(items) == 4
        for item in items:
            assert isinstance(item, tuple)
            a, b = item
            assert isinstance(a, torch.Tensor)
            assert a.shape == (8, 5)
            assert isinstance(b, dict)
            assert b == {}
Example #12
    def _apply_model_parallel(self, tensor, mask, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Override to return attention weights.
        """
        chunks = PipelineHelper.split((tensor, mask))
        work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

        for chunk_idx, layer_nos, next_device in work_items:
            s_weights = None
            try:
                # on the first visit a chunk is still the (tensor, mask) pair
                s_tensor, s_mask = chunks[chunk_idx]
            except ValueError:
                # later visits carry the attention weights along with the chunk
                s_tensor, s_mask, s_weights = chunks[chunk_idx]
            for layer_no in layer_nos:
                s_tensor, s_weights = self.layers[layer_no](s_tensor, s_mask, **kwargs)
            chunks[chunk_idx] = PipelineHelper.chunk_to(
                (s_tensor, s_mask, s_weights), next_device
            )
        joined = PipelineHelper.join(chunks)
        tensor_out, out_mask, weights = joined
        return tensor_out, weights
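The final unpacking works because join is presumably the structural inverse of split: joining a list of (tensor, mask, weights) triples yields one triple of concatenated tensors. A minimal sketch under that assumption, not ParlAI's actual code:

import torch

def join(chunks):
    # concatenate per-chunk pieces back into one nested structure
    first = chunks[0]
    if torch.is_tensor(first):
        return torch.cat(chunks, dim=0)
    if isinstance(first, tuple):
        return tuple(join(list(pieces)) for pieces in zip(*chunks))
    if isinstance(first, dict):
        return {k: join([c[k] for c in chunks]) for k in first}
    raise ValueError(f"cannot join items of type {type(first)}")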