Example No. 1
def get_batch(data_iterator):
    """Generate a batch"""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens, tokenizer.eod, args.reset_position_ids,
        args.reset_attention_mask, args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
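The only subtle step above is the one-token shift: tokens_[:, :-1] is fed to the model while tokens_[:, 1:] serves as the next-token targets, and get_ltor_masks_and_position_ids derives the causal mask, loss mask, and position ids from the shifted input. Below is a minimal, self-contained sketch in plain PyTorch of those shapes and masks; the toy batch, the eod id, and the mask conventions are illustrative assumptions, not the Megatron implementation.

import torch

# Hypothetical toy batch: 2 sequences of length 6, with 0 standing in for <eod>.
tokens_ = torch.tensor([[5, 3, 8, 0, 7, 2],
                        [4, 4, 9, 1, 0, 6]], dtype=torch.int64)

tokens = tokens_[:, :-1].contiguous()   # model input: positions 0 .. n-2
labels = tokens_[:, 1:].contiguous()    # targets:     positions 1 .. n-1

b, s = tokens.shape
# Left-to-right (causal) mask: position i may attend to positions <= i.
attention_mask = torch.tril(torch.ones(s, s, dtype=torch.bool)).unsqueeze(0)
position_ids = torch.arange(s, dtype=torch.int64).unsqueeze(0).expand(b, s)
# Simple loss mask in the spirit of eod_mask_loss: skip <eod> targets
# (the real eod_mask_loss behavior may differ in detail).
loss_mask = (labels != 0).float()

print(tokens.shape, labels.shape, attention_mask.shape, loss_mask.shape)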
Example No. 2
def get_batch_pipe(data):
    """A modification of get_batch() to work with the latest batch instead of an iterator. """
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens, tokenizer.eod, args.reset_position_ids,
        args.reset_attention_mask, args.eod_mask_loss)

    # Package outputs the way the pipeline engine expects them.
    if args.fp16:
        # cast to fp16 because pipeline parallelism skips the FP16 wrapper.
        return fp32_to_fp16(
            (tokens, position_ids, attention_mask)), fp32_to_fp16(
                (labels, loss_mask))
    else:
        return (tokens, position_ids, attention_mask), (labels, loss_mask)
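fp32_to_fp16 walks the nested tuples and casts floating-point tensors to half precision, since the pipeline engine consumes these tuples directly and the usual FP16 module wrapper never sees them. A minimal sketch of what such a helper could look like (the name _fp32_to_fp16_sketch and its exact behavior are assumptions; the real Megatron/DeepSpeed utility may handle more container types and edge cases):

import torch

def _fp32_to_fp16_sketch(nested):
    # Recursively cast fp32 tensors inside tuples/lists to fp16; leave
    # integer tensors (tokens, position ids) and other objects untouched.
    if torch.is_tensor(nested):
        return nested.half() if nested.dtype == torch.float32 else nested
    if isinstance(nested, (tuple, list)):
        return type(nested)(_fp32_to_fp16_sketch(x) for x in nested)
    return nested

# Tokens stay int64; the float attention mask becomes fp16.
batch = (torch.zeros(2, 5, dtype=torch.int64), torch.rand(2, 1, 5, 5))
print([t.dtype for t in _fp32_to_fp16_sketch(batch)])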
Example No. 3
def get_batch(data_iterator):
    """Build the batch."""

    # Items and their type.
    keys = [
        'text', 'types', 'labels', 'is_random', 'loss_mask', 'padding_mask'
    ]
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    sentence_order = data_b['is_random'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask
Example No. 4
def get_batch(data_iterator):
    """Build the batch."""

    keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
            'enc_mask', 'dec_mask', 'enc_dec_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_enc = data_b['text_enc'].long()
    tokens_dec = data_b['text_dec'].long()
    labels = data_b['labels'].long()
    loss_mask = data_b['loss_mask'].float()

    enc_mask = (data_b['enc_mask'] < 0.5)
    dec_mask = (data_b['dec_mask'] < 0.5)
    enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)

    return tokens_enc, tokens_dec, loss_mask, labels, \
           enc_mask, dec_mask, enc_dec_mask
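The < 0.5 comparisons convert the 0/1 masks, which arrive as int64 from broadcast_data, into boolean tensors in which True marks the positions to be masked out: a value of 1 ("attend") becomes False and a value of 0 becomes True. A tiny illustration with a hypothetical 1-D padding mask (the actual enc_mask/dec_mask tensors are typically full sequence-by-sequence attention masks):

import torch

# Hypothetical 0/1 mask as int64 (1 = keep, 0 = mask out).
mask_int = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.int64)
mask_bool = mask_int < 0.5
print(mask_bool)  # tensor([[False, False, False,  True,  True]])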
Example No. 5
def get_batch(data_iterator):
    """Build the batch."""

    tokenizer = get_tokenizer()
    # Items and their type.
    keys = ['text', 'labels', 'loss_mask', 'padding_mask']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    loss_mask = data_b['loss_mask'].float()
    lm_labels = data_b['labels'].long()
    padding_mask = data_b['padding_mask'].long()

    # Get the masks and position ids.
    attention_mask, position_ids = get_tape_masks_and_position_ids(
        tokens,
        tokenizer.cls,
        reset_position_ids=True,
        reset_attention_mask=True)

    return tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids
Example No. 6
def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Support function for get_batch / get_batch pipe (to avoid code repetition)"""
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens, tokenizer.eod, neox_args.reset_position_ids,
        neox_args.reset_attention_mask, neox_args.eod_mask_loss)

    return tokens, labels, loss_mask, attention_mask, position_ids
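A possible pair of thin wrappers reusing the helper above (a sketch assuming the _get_batch defined here, torch, and a neox_args object carrying the tokenizer; the actual GPT-NeoX entry points may differ):

def get_batch(neox_args, data_iterator):
    # Iterator path: only ranks that own the iterator pull the next batch.
    keys = ['text']
    datatype = torch.int64
    data = next(data_iterator) if data_iterator is not None else None
    return _get_batch(neox_args, neox_args.tokenizer, keys, data, datatype)

def get_batch_pipe(data, neox_args):
    # Pipeline path: the engine hands the batch dict in directly.
    keys = ['text']
    datatype = torch.int64
    tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
        neox_args, neox_args.tokenizer, keys, data, datatype)
    return (tokens, position_ids, attention_mask), (labels, loss_mask)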
Example No. 7
def get_open_retrieval_batch(data_iterator):
    """Build the batch."""
    # Items and their type.
    keys = [
        'row_id', 'context', 'context_mask', 'context_types',
        'context_pad_mask'
    ]
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask
Example No. 8
def get_ict_batch(data_iterator):
    """Build the batch."""
    # Items and their type.
    keys = [
        'query_tokens', 'query_pad_mask', 'block_tokens', 'block_pad_mask',
        'block_data'
    ]
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_pad_mask = data_b['query_pad_mask'].long()
    block_tokens = data_b['block_tokens'].long()
    block_pad_mask = data_b['block_pad_mask'].long()
    block_indices = data_b['block_data'].long()

    return query_tokens, query_pad_mask,\
           block_tokens, block_pad_mask, block_indices