def test_split_tensor(self):
    """
    Splitting a bare tensor yields chunks with the expected leading dimension.
    """
    full = torch.randn(32, 5)

    # 32 rows divide evenly into chunks of 8 rows
    even_pieces = PipelineHelper.split(full, 8)
    for piece in even_pieces:
        assert piece.shape == (8, 5)

    # 17 does not divide 32: expect one chunk of 17 and one of the 15 leftovers
    first, second = PipelineHelper.split(full, 17)
    assert first.shape == (17, 5)
    assert second.shape == (15, 5)
def test_split_badcases(self):
    """
    Inputs containing an empty nested container must raise ValueError.
    """
    # these are cases that cause infinite loops if we don't catch them.
    base = torch.randn(32, 5)
    for bad_nested in ({'x': {}}, {'y': ()}):
        with self.assertRaises(ValueError):
            PipelineHelper.split((base, bad_nested), 8)
def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state):
    """
    Pipeline application of model parallelism.

    The inputs are split into chunks, each scheduled work item runs a group of
    layers on one chunk, and the per-chunk results are joined at the end.

    :return:
        a pair ``(tensor_out, new_incr_state)`` where ``new_incr_state`` maps
        each layer number to the joined incremental state of that layer.
    """
    chunks = PipelineHelper.split(
        (tensor, encoder_output, encoder_mask, incr_state)
    )
    work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

    # layer number -> list of incremental-state pieces, one appended per chunk
    state_pieces = {}
    for layer_idx, _ in enumerate(self.layers):
        state_pieces[layer_idx] = []

    for chunk_idx, layer_nos, next_device in work_items:
        x, enc_out, enc_mask, chunk_state = chunks[chunk_idx]
        for layer_no in layer_nos:
            layer = self.layers[layer_no]
            x, layer_state = layer(
                x=x,
                encoder_output=enc_out,
                encoder_mask=enc_mask,
                incr_state=chunk_state.get(layer_no),
            )
            state_pieces[layer_no].append(layer_state)
        # don't move incr state, it's always on the correct device
        x, enc_out, enc_mask = PipelineHelper.chunk_to(
            (x, enc_out, enc_mask), next_device
        )
        chunks[chunk_idx] = (x, enc_out, enc_mask, chunk_state)

    tensor_out = PipelineHelper.join([chunk[0] for chunk in chunks])
    new_incr_state = {}
    for layer_no, pieces in state_pieces.items():
        new_incr_state[layer_no] = PipelineHelper.join(pieces)
    return tensor_out, new_incr_state
def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state):
    """
    Pipeline application of model parallelism.

    Incremental state is collected as one ``{layer_no: state}`` dict per chunk
    and joined across chunks at the end.
    """
    chunks = PipelineHelper.split(
        (tensor, encoder_output, encoder_mask, incr_state)
    )
    work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

    # one dict per chunk, filled in as the layers run on that chunk
    per_chunk_state = [{} for _ in chunks]

    for chunk_idx, layer_nos, next_device in work_items:
        x, enc_out, enc_mask, chunk_state = chunks[chunk_idx]
        for layer_no in layer_nos:
            layer = self.layers[layer_no]
            x, layer_state = layer(
                x=x,
                encoder_output=enc_out,
                encoder_mask=enc_mask,
                incr_state=chunk_state.get(layer_no),
            )
            per_chunk_state[chunk_idx][layer_no] = layer_state
        # the whole chunk, incremental state included, moves to the next device
        moved = PipelineHelper.chunk_to(
            (x, enc_out, enc_mask, chunk_state), next_device
        )
        chunks[chunk_idx] = moved

    tensor_out = PipelineHelper.join([chunk[0] for chunk in chunks])
    joined_state = PipelineHelper.join(per_chunk_state)
    return tensor_out, joined_state
def test_schedule_work_items(self):
    """
    Check the order in which (layer group, chunk) work items are scheduled.
    """
    # pretend we have 8 layers and 4 gpus, and they are unevenly distributed:
    # layer 0 -> cuda:0, layers 1-3 -> cuda:1, layers 4-5 -> cuda:2,
    # layers 6-7 -> cuda:3
    device_per_layer = (
        ['cuda:0'] + ['cuda:1'] * 3 + ['cuda:2'] * 2 + ['cuda:3'] * 2
    )
    model = torch.nn.ModuleList()
    for device in device_per_layer:
        layer = IdentityLayer()
        layer._mp_gpu = device
        model.append(layer)

    # there are 2 chunks, each 16 x 7 in size
    chunks = PipelineHelper.split(torch.randn(32, 7), 16)
    work_items = list(PipelineHelper.schedule_work_items(model, chunks))

    # expected (layer_nos, chunk_idx) for each work item, in schedule order
    expected = [
        ([0], 0),
        ([1, 2, 3], 0),
        ([0], 1),
        ([4, 5], 0),
        ([1, 2, 3], 1),
        ([6, 7], 0),
        ([4, 5], 1),
        ([6, 7], 1),
    ]
    assert len(work_items) == 8
    for position, (layer_nos, chunk_idx) in enumerate(expected):
        item = work_items[position]
        assert item.layer_nos == layer_nos and item.chunk_idx == chunk_idx
def test_split_tuple(self):
    """
    Splitting a tuple of tensors yields tuples of equally split tensors.
    """
    base = torch.randn(32, 5)
    triple = (base, base, base)
    for part in PipelineHelper.split(triple, 8):
        assert isinstance(part, tuple)
        assert len(part) == 3
        # every member of the tuple is split the same way
        for member in part:
            assert member.shape == (8, 5)
def test_split_dict(self):
    """
    Splitting a dict of tensors yields dicts with the same keys, split values.
    """
    base = torch.randn(32, 5)
    keys = ('x', 'y')
    for part in PipelineHelper.split({'x': base, 'y': base}, 8):
        assert isinstance(part, dict)
        # both keys must survive the split ...
        for key in keys:
            assert key in part
        # ... and each value must be chunked along the first dimension
        for key in keys:
            assert part[key].shape == (8, 5)
def _apply_model_parallel_with_extra(
    self,
    tensor,
    encoder_output,
    encoder_mask,
    incr_state,
    extra_output: torch.Tensor = None,
    extra_mask: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Pipeline application of model parallelism with extra inputs.

    Copy paste from TransformerDecoder._apply_model_parallel, except that
    ``extra_output`` and ``extra_mask`` are split, passed to every layer and
    moved between devices together with the other inputs.
    """
    chunks = PipelineHelper.split(
        (tensor, encoder_output, encoder_mask, incr_state, extra_output, extra_mask)
    )
    work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

    # layer number -> list of incremental-state pieces, one appended per chunk
    state_pieces = {}
    for layer_idx, _ in enumerate(self.layers):
        state_pieces[layer_idx] = []

    for chunk_idx, layer_nos, next_device in work_items:
        x, enc_out, enc_mask, chunk_state, extra_out, extra_msk = chunks[chunk_idx]
        for layer_no in layer_nos:
            layer = self.layers[layer_no]
            x, layer_state = layer(
                x=x,
                encoder_output=enc_out,
                encoder_mask=enc_mask,
                incr_state=chunk_state.get(layer_no),
                extra_output=extra_out,
                extra_mask=extra_msk,
            )
            state_pieces[layer_no].append(layer_state)
        # don't move incr state, it's always on the correct device
        x, enc_out, enc_mask, extra_out, extra_msk = PipelineHelper.chunk_to(
            (x, enc_out, enc_mask, extra_out, extra_msk), next_device
        )
        chunks[chunk_idx] = (x, enc_out, enc_mask, chunk_state, extra_out, extra_msk)

    tensor_out = PipelineHelper.join([chunk[0] for chunk in chunks])
    new_incr_state = {}
    for layer_no, pieces in state_pieces.items():
        new_incr_state[layer_no] = PipelineHelper.join(pieces)
    return tensor_out, new_incr_state  # type: ignore
def test_split_complex(self):
    """
    Splitting a nested (tensor, dict-of-tensors) item preserves its structure.
    """
    base = torch.randn(32, 5)
    nested = (base, {'x': base, 'y': base})
    keys = ('x', 'y')
    for part in PipelineHelper.split(nested, 8):
        assert isinstance(part, tuple)
        assert len(part) == 2
        tensor_part, dict_part = part

        # the tensor half is chunked directly
        assert isinstance(tensor_part, torch.Tensor)
        assert tensor_part.shape == (8, 5)

        # the dict half keeps its keys, with each value chunked
        assert isinstance(dict_part, dict)
        for key in keys:
            assert key in dict_part
        for key in keys:
            assert dict_part[key].shape == (8, 5)
def _apply_model_parallel(self, tensor, mask):
    """
    Pipeline application of model parallelism.

    Splits ``(tensor, mask)`` into chunks, runs each scheduled group of layers
    on its chunk, and returns the joined output tensor.
    """
    chunks = PipelineHelper.split((tensor, mask))
    work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

    for chunk_idx, layer_nos, next_device in work_items:
        x, x_mask = chunks[chunk_idx]
        for layer_no in layer_nos:
            layer = self.layers[layer_no]
            x = layer(x, x_mask)
        # hand the chunk over to the device of the next work item
        chunks[chunk_idx] = PipelineHelper.chunk_to((x, x_mask), next_device)

    # the joined mask is not needed; only the tensor is returned
    joined_tensor, _ = PipelineHelper.join(chunks)
    return joined_tensor
def test_split_emptydict(self):
    """
    An empty dict next to a tensor is split into one empty dict per chunk.
    """
    # test a horrible edge case where d is an empty dict, and we need to
    # return a BUNCH of empty dicts
    pieces = PipelineHelper.split((torch.randn(32, 5), {}), 8)
    assert len(pieces) == 4
    for piece in pieces:
        assert isinstance(piece, tuple)
        tensor_part, dict_part = piece
        assert isinstance(tensor_part, torch.Tensor)
        assert tensor_part.shape == (8, 5)
        assert isinstance(dict_part, dict)
        assert dict_part == {}
def _apply_model_parallel(self, tensor, mask, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Override to return attention weights.

    Same pipeline as the base implementation, but every layer call also
    returns weights, which are carried along in the chunk and returned
    together with the output tensor.
    """
    chunks = PipelineHelper.split((tensor, mask))
    work_items = PipelineHelper.schedule_work_items(self.layers, chunks)

    for chunk_idx, layer_nos, next_device in work_items:
        chunk = chunks[chunk_idx]
        if len(chunk) == 2:
            # chunk still has its initial (tensor, mask) form: no weights yet
            x, x_mask = chunk
            weights = None
        else:
            # chunk was already processed and stored as (tensor, mask, weights)
            x, x_mask, weights = chunk

        for layer_no in layer_nos:
            layer = self.layers[layer_no]
            x, weights = layer(x, x_mask, **kwargs)

        chunks[chunk_idx] = PipelineHelper.chunk_to(
            (x, x_mask, weights), next_device
        )

    tensor_out, _, weights_out = PipelineHelper.join(chunks)
    return tensor_out, weights_out