def get_batch(data_iterator): """Generate a batch""" args = get_args() tokenizer = get_tokenizer() # Items and their type. keys = ['text'] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. tokens_ = data_b['text'].long() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss) return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch_pipe(data):
    """A modification of get_batch() to work with the latest batch instead of an iterator."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    if args.fp16:
        # Cast to fp16 because pipeline parallelism skips the FP16 wrapper.
        return fp32_to_fp16((tokens, position_ids, attention_mask)), \
            fp32_to_fp16((labels, loss_mask))
    else:
        return (tokens, position_ids, attention_mask), (labels, loss_mask)
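
# get_batch_pipe() hands the pipeline engine two tuples: the model inputs
# (tokens, position_ids, attention_mask) and the loss inputs (labels, loss_mask).
# A rough standalone sketch of an fp32 -> fp16 cast over such nested tuples;
# this assumes fp32_to_fp16 recursively halves float32 tensors and is only an
# illustration, not the library's actual implementation:
def _cast_floats_to_half(val):
    """Recursively cast float32 tensors inside tuples/lists to float16."""
    if isinstance(val, (tuple, list)):
        return type(val)(_cast_floats_to_half(v) for v in val)
    if torch.is_tensor(val) and val.dtype == torch.float32:
        return val.half()
    return val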
def get_batch(data_iterator): """Build the batch.""" # Items and their type. keys = [ 'text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask' ] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. tokens = data_b['text'].long() types = data_b['types'].long() sentence_order = data_b['is_random'].long() loss_mask = data_b['loss_mask'].float() lm_labels = data_b['labels'].long() padding_mask = data_b['padding_mask'].long() return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
def get_batch(data_iterator): """Build the batch.""" keys = ['text_enc', 'text_dec', 'labels', 'loss_mask', 'enc_mask', 'dec_mask', 'enc_dec_mask'] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. tokens_enc = data_b['text_enc'].long() tokens_dec = data_b['text_dec'].long() labels = data_b['labels'].long() loss_mask = data_b['loss_mask'].float() enc_mask = (data_b['enc_mask'] < 0.5) dec_mask = (data_b['dec_mask'] < 0.5) enc_dec_mask = (data_b['enc_dec_mask'] < 0.5) return tokens_enc, tokens_dec, loss_mask, labels, \ enc_mask, dec_mask, enc_dec_mask
def get_batch(data_iterator): """Build the batch.""" tokenizer = get_tokenizer() # Items and their type. keys = ['text', 'labels', 'loss_mask', 'padding_mask'] datatype = torch.int64 # Broadcast data. if data_iterator is not None: data = next(data_iterator) else: data = None data_b = mpu.broadcast_data(keys, data, datatype) # Unpack. tokens = data_b['text'].long() loss_mask = data_b['loss_mask'].float() lm_labels = data_b['labels'].long() padding_mask = data_b['padding_mask'].long() # Get the masks and postition ids. attention_mask, position_ids = get_tape_masks_and_position_ids( tokens, tokenizer.cls, reset_position_ids=True, reset_attention_mask=True) return tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids
def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Support function for get_batch / get_batch_pipe (to avoid code repetition)."""
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        neox_args.reset_position_ids,
        neox_args.reset_attention_mask,
        neox_args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
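
# _get_batch() factors out the unpacking logic shared by get_batch and
# get_batch_pipe. A caller that pulls the next batch from an iterator could look
# roughly like the sketch below; get_batch_sketch is a hypothetical wrapper and
# the real callers in the repository may differ:
def get_batch_sketch(neox_args, tokenizer, data_iterator):
    """Pull the next batch from the iterator and unpack it via _get_batch."""
    keys = ['text']
    datatype = torch.int64
    data = next(data_iterator) if data_iterator is not None else None
    return _get_batch(neox_args, tokenizer, keys, data, datatype)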
def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask',
            'context_types', 'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()
    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)
    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask
def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_pad_mask',
            'block_tokens', 'block_pad_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask, \
           block_tokens, block_pad_mask, block_indices