def get_batch(data_iterator, args, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text', 'types', 'mask', 'mask_labels', 'pad_mask', 'sample_id']
    datatype = torch.int64
    keys2 = ['clickscores', 'hrsscores']
    datatype2 = torch.float64
    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)
    data_b2 = mpu.broadcast_data(keys2, data, datatype2)

    # Unpack.

    tokens = data_b['text'].long()
    batch_size, num_urls, seq_length = tokens.size()
    tokens = data_b['text'].view(-1, seq_length).long()
    types = data_b['types'].view(-1, seq_length).long()
    loss_mask = data_b['mask'].view(-1, seq_length).float()
    lm_labels = data_b['mask_labels'].view(-1, seq_length).long()
    padding_mask = data_b['pad_mask'].view(-1, seq_length).float()
    clickscores = data_b2['clickscores'].view(batch_size, num_urls).float()
    hrsscores = data_b2['hrsscores'].view(batch_size, num_urls).float()
    sample_id = data_b['sample_id'].view(batch_size).long()
    # Get the masks and position ids.
    batch_size, seq_length = tokens.size()
    keep_mask = 1.0 - padding_mask
    attention_mask = keep_mask.view(batch_size, 1, seq_length, 1) * keep_mask.view(
        batch_size, 1, 1, seq_length)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    # Convert the attention mask to half precision when running in fp16.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, loss_mask, lm_labels, padding_mask, attention_mask, position_ids, clickscores, hrsscores, sample_id
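
The attention mask built above is just an outer product of "keep" vectors derived from the padding mask, and the position ids are a broadcast arange. A minimal standalone sketch of both constructions, using a made-up 2x4 padding mask (1 = pad, 0 = real token):

import torch

padding_mask = torch.tensor([[0., 0., 1., 1.],
                             [0., 0., 0., 0.]])    # hypothetical [batch, seq] pad mask
batch_size, seq_length = padding_mask.size()

keep = 1.0 - padding_mask                          # 1 where attention is allowed
# Position i may attend to position j only if both i and j are real tokens.
attention_mask = keep.view(batch_size, 1, seq_length, 1) * keep.view(batch_size, 1, 1, seq_length)

# Position ids are simply 0..seq_length-1, repeated for every row of the batch.
position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand(batch_size, seq_length)

print(attention_mask[0, 0])    # 4x4 block that is 1 only in the top-left 2x2 corner
print(position_ids)
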
Example #2
def get_batch(data_iterator, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text', 'types', 'is_random', 'mask', 'mask_labels', 'pad_mask']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens = data_b['text'].long()
    types = data_b['types'].long()
    next_sentence = data_b['is_random'].long()
    loss_mask = data_b['mask'].float()
    lm_labels = data_b['mask_labels'].long()
    padding_mask = data_b['pad_mask'].byte()

    return tokens, types, next_sentence, loss_mask, lm_labels, padding_mask
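
For reference, a sketch of the kind of per-batch dict this function expects from the rank that actually runs the data loader; the field names come from the keys list above, while the shapes and values here are made up for illustration. mpu.broadcast_data then broadcasts these fields, cast to the given datatype, within the model-parallel group so only one rank needs to touch the dataset.

import torch

batch_size, seq_length = 2, 8
example_batch = {
    'text':        torch.randint(0, 30000, (batch_size, seq_length), dtype=torch.int64),  # token ids
    'types':       torch.zeros(batch_size, seq_length, dtype=torch.int64),                # segment ids
    'is_random':   torch.tensor([0, 1], dtype=torch.int64),                               # next-sentence labels
    'mask':        torch.zeros(batch_size, seq_length, dtype=torch.int64),                # marks masked (MLM) positions
    'mask_labels': torch.full((batch_size, seq_length), -1, dtype=torch.int64),           # original ids at masked positions
    'pad_mask':    torch.zeros(batch_size, seq_length, dtype=torch.int64),                # 1 where the input is padding
}
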
Example #3
def get_two_batch(data, args):
    keys = ['text', 'target', 'loss_mask']
    datatype = torch.int64
    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)
    source_tokens = data_b['text'].long()
    target_tokens = data_b['target'].long()
    loss_mask = data_b['loss_mask'].float()
    labels = target_tokens[:, 1:].contiguous()
    loss_mask = loss_mask[:, 1:].contiguous()
    target_tokens = target_tokens[:, :-1].contiguous()
    _, _, source_position_ids = get_masks_and_position_ids(
        source_tokens,
        args.eod_token,
        reset_position_ids=False,
        reset_attention_mask=False,
        loss_mask=None,
        attention_mask=None,
        set_loss_mask=False)
    target_mask, _, target_position_ids = get_masks_and_position_ids(
        target_tokens,
        args.eod_token,
        reset_position_ids=False,
        reset_attention_mask=False,
        loss_mask=None,
        attention_mask=None,
        set_loss_mask=False)
    if args.fp16:
        target_mask = target_mask.half()
    return source_tokens, target_tokens, source_position_ids, target_position_ids, labels, target_mask, loss_mask
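
The slicing of target_tokens above is the standard one-token shift for teacher forcing: the decoder is fed the sequence without its last token and is trained to predict the sequence without its first token. A tiny illustration with made-up ids:

import torch

target_tokens = torch.tensor([[101, 7, 8, 9, 102]])   # hypothetical id sequence
labels        = target_tokens[:, 1:].contiguous()     # [[  7, 8, 9, 102]] - what the model should predict
decoder_input = target_tokens[:, :-1].contiguous()    # [[101, 7, 8,   9]] - what the model is fed
print(decoder_input.tolist(), labels.tolist())
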
Example #4
def process_batch(batch, args):
    """Process batch and produce inputs for the model."""
    keys = ["text", "label"]
    if args.pretrained_bert:
        keys += ["padding_mask", "types"]
    else:
        keys += ["mask", "position"]
        if args.cloze_eval:
            if args.fast_decode:
                keys += [
                    "dec_text", "dec_position", "dec_mask", "dec_target",
                    "dec_logit_mask"
                ]
            else:
                keys += ["target", "logit_mask"]
                if args.segment_length > 0:
                    keys += ["segment_id"]
                if args.continuous_prompt:
                    keys += ["prompt_pos"]
    if args.variable_num_choices:
        keys.append("loss_mask")
    # Broadcast data.
    datatype = torch.int64
    data_b = mpu.broadcast_data(keys, batch, datatype)

    if "padding_mask" in data_b:
        attention_mask = data_b['padding_mask'].float().cuda().contiguous()
        if args.fp16:
            attention_mask = attention_mask.half()
        data_b["padding_mask"] = attention_mask
    return data_b
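
Which keys end up being broadcast depends entirely on the flag combination in args. One way to see what a given configuration requests is to run just the key-selection logic with a hypothetical Namespace (flag names copied from the function above):

from argparse import Namespace

args = Namespace(pretrained_bert=False, cloze_eval=True, fast_decode=False,
                 segment_length=0, continuous_prompt=False, variable_num_choices=False)

keys = ["text", "label"]
if args.pretrained_bert:
    keys += ["padding_mask", "types"]
else:
    keys += ["mask", "position"]
    if args.cloze_eval:
        if args.fast_decode:
            keys += ["dec_text", "dec_position", "dec_mask", "dec_target", "dec_logit_mask"]
        else:
            keys += ["target", "logit_mask"]
            if args.segment_length > 0:
                keys += ["segment_id"]
            if args.continuous_prompt:
                keys += ["prompt_pos"]
if args.variable_num_choices:
    keys.append("loss_mask")
print(keys)   # ['text', 'label', 'mask', 'position', 'target', 'logit_mask']
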
Example #5
def get_batch(data, args):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text', 'loss_mask']
    if args.transformer_xl or args.block_lm:
        keys += ['target', 'attention_mask']
    if args.block_lm:
        keys += ['position_id']
    datatype = torch.int64

    # Broadcast data.
    data_b = mpu.broadcast_data(keys, data, datatype)
    # Unpack.
    if args.transformer_xl:
        tokens = data_b['text'].long()
        labels = data_b['target'].long()
        attention_mask = data_b['attention_mask'].float()
        loss_mask = data_b['loss_mask'].float()
    elif args.block_lm:
        tokens = data_b['text'].long()
        labels = data_b['target'].long()
        attention_mask = data_b['attention_mask'].long()
        loss_mask = data_b['loss_mask'].float()
        position_ids = data_b['position_id'].long()
    else:
        tokens_ = data_b['text'].long()
        loss_mask = data_b['loss_mask'].float()
        labels = tokens_[:, 1:].contiguous()
        loss_mask = loss_mask[:, 1:].contiguous()
        tokens = tokens_[:, :-1].contiguous()
        attention_mask = None

    # Get the masks and position ids.
    if not args.block_lm:
        attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
            tokens,
            args.eod_token,
            args.reset_position_ids,
            args.reset_attention_mask,
            loss_mask=loss_mask,
            attention_mask=attention_mask,
            mem_length=args.mem_length,
            set_loss_mask=not args.transformer_xl)
        # Convert the attention mask to half precision when running in fp16.
        if args.fp16:
            attention_mask = attention_mask.half()
    return tokens, labels, loss_mask, attention_mask, position_ids
def get_batch(data_iterator, args, timers):
    ''' get_batch subdivides the source data into chunks of
    length args.seq_length. If source is equal to the example
    output of the data loading example, with a seq_length limit
    of 2, we'd get the following two Variables for i = 0:
    ┌ a g m s ┐ ┌ b h n t ┐
    └ b h n t ┘ └ c i o u ┘
    Note that despite the name of the function, the subdivision of data is not
    done along the batch dimension (i.e. dimension 1), since that was handled
    by the data loader. The chunks are along dimension 0, corresponding
    to the seq_len dimension in the LSTM. A Variable representing an appropriate
    shard reset mask of the same dimensions is also returned.
    '''
    # Items and their type.
    keys = ['text', 'pad_mask']
    datatype = torch.int64

    # Broadcast data.
    timers('data loader').start()
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    timers('data loader').stop()
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    lm_labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
    padding_mask = data_b['pad_mask'].byte()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask)

    # Convert the attention mask to half precision when running in fp16.
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, lm_labels, attention_mask, position_ids, padding_mask
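
get_masks_and_position_ids itself is not shown in this listing. In Megatron-style code it typically builds a lower-triangular (causal) attention mask, a loss mask, and arange position ids, optionally resetting them at EOD tokens. The sketch below is only an assumption about its default (no-reset) behaviour, written as a stand-in helper rather than the real function:

import torch

def sketch_masks_and_position_ids(tokens):
    """Assumed default behaviour: causal mask, all-ones loss mask, arange positions."""
    batch_size, seq_length = tokens.size()
    # Lower-triangular mask: position i may attend only to positions <= i.
    attention_mask = torch.tril(torch.ones(1, 1, seq_length, seq_length, device=tokens.device))
    loss_mask = torch.ones(batch_size, seq_length, dtype=torch.float, device=tokens.device)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    return attention_mask, loss_mask, position_ids

tokens = torch.randint(0, 100, (2, 5))
attn, loss_mask, pos = sketch_masks_and_position_ids(tokens)
print(attn[0, 0])   # 5x5 lower-triangular matrix of ones
print(pos[0])       # tensor([0, 1, 2, 3, 4])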