def get_batch_pipe(data):
    """A modification of get_batch() to work with the latest batch instead
    of an iterator."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Broadcast the token tensor across the model-parallel group.
    data_b = mpu.broadcast_data(['text'], data, torch.int64)

    # Shift by one position: inputs at i predict the label at i+1.
    full_tokens = data_b['text'].long()
    tokens = full_tokens[:, :-1].contiguous()
    labels = full_tokens[:, 1:].contiguous()

    # Derive masks and position ids from the shifted inputs.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    inputs = (tokens, position_ids, attention_mask)
    targets = (labels, loss_mask)
    if args.fp16:
        # Cast to fp16 because pipeline parallelism skips the FP16 wrapper.
        return fp32_to_fp16(inputs), fp32_to_fp16(targets)
    return inputs, targets
def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Pull the next batch if an iterator was supplied; a None iterator
    # presumably means this rank only receives via broadcast — confirm
    # against mpu.broadcast_data semantics.
    data = next(data_iterator) if data_iterator is not None else None
    data_b = mpu.broadcast_data(['text'], data, torch.int64)

    # Shift tokens so position i predicts token i+1.
    raw = data_b['text'].long()
    tokens = raw[:, :-1].contiguous()
    labels = raw[:, 1:].contiguous()

    # Build attention mask, loss mask and position ids for the inputs.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch(context_tokens):
    """Generate batch from context tokens."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Reshape to (micro_batch_size, -1) and move onto the GPU.
    tokens = context_tokens.view(
        args.micro_batch_size, -1).contiguous().cuda()

    # The loss mask is not needed here; keep attention mask and position ids.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, attention_mask, position_ids
def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""
    # The loss mask is discarded and both reset flags are False, so the
    # eod_token argument is never consulted; passing None is safe.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=None,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False)
    return attention_mask, position_ids
def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Support function for get_batch / get_batch pipe (to avoid code
    repetition)."""
    # Broadcast the requested keys across the model-parallel group.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Shift by one: inputs cover positions [0, n-1), labels cover [1, n).
    raw_tokens = data_b['text'].long()
    tokens = raw_tokens[:, :-1].contiguous()
    labels = raw_tokens[:, 1:].contiguous()

    # Build attention mask, loss mask and position ids for the inputs.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        neox_args.reset_position_ids,
        neox_args.reset_attention_mask,
        neox_args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()
    tokenizer = get_tokenizer()

    # The padding mask serves as the loss mask (byte tensor on GPU).
    loss_mask = batch['pad_mask'].long().cuda().contiguous().byte()

    # Shift by one token: model inputs vs. prediction targets.
    text = batch['text'].long().cuda().contiguous()
    tokens = text[:, :-1].contiguous()
    labels = text[:, 1:].contiguous()

    # Loss mask from get_ltor_masks_and_position_ids is unused here —
    # the pad-mask-derived one above is returned instead.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    return tokens, labels, attention_mask, position_ids, loss_mask
def get_batch(neox_args, context_tokens: torch.Tensor):
    """
    Generate batch from context tokens. Attention mask and position ids
    are created. Returned tensors will be on CUDA.

    neox_args: NeoXArgs with tokenizer, reset_position_ids,
        reset_attention_mask and eod_mask_loss
    context_tokens: torch tensor with dimensions [batch, context_size]

    returns: tuple of torch tensors (tokens, attention_mask, position_ids)
        on CUDA
    """
    # Place the context tokens on the GPU.
    tokens = context_tokens.contiguous().cuda()

    # The loss mask is not needed for generation; drop it.
    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        neox_args.tokenizer.eod,
        neox_args.reset_position_ids,
        neox_args.reset_attention_mask,
        neox_args.eod_mask_loss)

    return tokens, attention_mask, position_ids