Example #1
def loss_func(loss_mask, sentence_order, output_tensor):
    lm_loss_, sop_logits = output_tensor

    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    if sop_logits is not None:
        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()
        loss = lm_loss + sop_loss
        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])
        return loss, {
            'lm loss': averaged_losses[0],
            'sop loss': averaged_losses[1]
        }

    else:
        loss = lm_loss
        averaged_losses = average_losses_across_data_parallel_group([lm_loss])
        return loss, {'lm loss': averaged_losses[0]}
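
Every example in this listing assumes Megatron-LM's usual surroundings (torch, torch.nn.functional as F, math, the mpu module, and helpers such as get_args, get_timers, and get_batch). The one helper they all share is average_losses_across_data_parallel_group; a minimal sketch of how such a helper can be implemented, assuming torch.distributed is initialized and mpu.get_data_parallel_group() returns the data-parallel process group:

import torch

def average_losses_across_data_parallel_group(losses):
    # Pack the scalar losses into one tensor, sum them across the
    # data-parallel group, and divide by the group size.
    averaged = torch.cat([loss.clone().detach().view(1) for loss in losses])
    group = mpu.get_data_parallel_group()
    torch.distributed.all_reduce(averaged, group=group)
    averaged = averaged / torch.distributed.get_world_size(group=group)
    return averaged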
Example #2
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    # `batch` may be an iterator or an already-built batch; only the
    # non-iterator case raises TypeError from next().
    try:
        batch_ = next(batch)
    except TypeError:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        logits = output_tensor

        # Cross-entropy loss.
        loss_func = torch.nn.CrossEntropyLoss()
        loss = loss_func(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
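
The pipeline-stage branching above implies a contract with the caller: non-last stages hand back an activation tensor, while the last stage hands back (loss, stats). A hypothetical driver loop making that contract explicit; recv_forward and send_forward stand in for whatever point-to-point communication the schedule uses and are assumptions, not part of these examples:

def run_forward(batch, model):
    # The first stage reads data; later stages receive activations upstream.
    input_tensor = None if mpu.is_pipeline_first_stage() else recv_forward()
    output = _cross_entropy_forward_step(batch, model, input_tensor)
    if mpu.is_pipeline_last_stage():
        loss, stats = output   # ready for the backward pass and logging
        return loss, stats
    send_forward(output)       # ship activations to the next stage
    return None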
Example #3
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    query_logits, context_logits = model(query_tokens, query_mask,
                                    query_types, context_tokens,
                                    context_mask, context_types)

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1
    assert mpu.get_tensor_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    global_batch_size = dist.get_world_size() * micro_batch_size
    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits) 

    # scores are inner products between query and context embeddings
    retrieval_scores = torch.matmul(all_query_logits,
                        torch.transpose(all_context_logits, 0, 1))
    # scaling the retriever scores
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                    k=softmax_scores.shape[1], sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
            for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

    labels = torch.arange(global_batch_size).long().cuda()
    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])

    # Scale the retrieval loss
    loss = loss * mpu.get_data_parallel_world_size()

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
    return loss, stats_dict
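
Both retrieval examples route the query/context embeddings through AllgatherFromDataParallelRegion so that every rank scores its queries against the full global batch while gradients still reach the local shard. A minimal sketch of such an autograd function, assuming the default process group is the data-parallel group; returning only the local gradient slice drops the cross-rank terms, which is presumably what the loss scaling by mpu.get_data_parallel_world_size() above compensates for:

import torch
import torch.distributed as dist

class AllgatherFromDataParallelRegion(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        # Gather each rank's rows and stack them along dim 0.
        world_size = dist.get_world_size()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[dist.get_rank()] = input_
        dist.all_gather(tensor_list, input_)
        return torch.cat(tensor_list, dim=0).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # Hand back only the gradient slice for this rank's rows.
        world_size = dist.get_world_size()
        chunks = grad_output.chunk(world_size, dim=0)
        return chunks[dist.get_rank()].contiguous()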
Example #4
def loss_func(loss_mask, output_tensor):
    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}
Example #5
def loss_func(loss_mask, output_tensor):
    lm_loss_ = output_tensor.float()
    lm_loss = torch.sum(
        lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

    loss = lm_loss
    averaged_losses = average_losses_across_data_parallel_group([lm_loss])

    return loss, {'lm loss': averaged_losses[0]}
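
Standalone loss functions like Examples #1, #4, and #5 take the model output as their last argument because the forward step binds the batch-dependent pieces in first. A plausible wiring, sketched with functools.partial around the GPT-style get_batch from Example #10:

from functools import partial

def forward_step(data_iterator, model):
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
    # The scheduler calls the returned closure once the output exists.
    return output_tensor, partial(loss_func, loss_mask)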
Example #6
def cross_entropy_loss_func(labels, output_tensor):
    logits = output_tensor

    # Cross-entropy loss.
    loss = F.cross_entropy(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}
Example #7
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, loss_mask, lm_labels, padding_mask, attention_mask, position_ids \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    extended_attention_mask = bert_extended_attention_mask(
        padding_mask) + attention_mask

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None,
                                  lm_labels=lm_labels,
                                  position_ids=position_ids)
        else:
            output_tensor = model(tokens,
                                  extended_attention_mask,
                                  tokentype_ids=None)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor,
                              extended_attention_mask,
                              position_ids=position_ids)

    if mpu.is_pipeline_last_stage():
        lm_loss_, _ = output_tensor

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss

        averaged_losses = average_losses_across_data_parallel_group([
            lm_loss,
        ])

        return loss, {'lm loss': averaged_losses[0]}
    return output_tensor
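
bert_extended_attention_mask is not shown in these examples; a sketch of what such a helper typically computes in Megatron-style code (an assumption about its behavior, not its verbatim source): it expands a [batch, seq] padding mask into a [batch, 1, seq, seq] mask in which True marks positions to block.

def bert_extended_attention_mask(attention_mask):
    # [b, s] -> [b, s, s]: position (i, j) is kept only when both token i
    # and token j are real (non-padding) tokens.
    attention_mask_b1s = attention_mask.unsqueeze(1)
    attention_mask_bs1 = attention_mask.unsqueeze(2)
    attention_mask_bss = attention_mask_b1s * attention_mask_bs1
    # Add the head dimension and flip to a "True means mask out" convention.
    return attention_mask_bss.unsqueeze(1) < 0.5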
Example #8
def loss_func(labels, output_tensor):
    logits = output_tensor.contiguous().float()
    loss = F.cross_entropy(logits, labels)

    outputs = torch.argmax(logits, -1)
    correct = (outputs == labels).float()
    accuracy = torch.mean(correct)

    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])

    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
Example #9
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, types, sentence_order, loss_mask, lm_labels, padding_mask \
        = get_batch(data_iterator)
    timers('batch-generator').stop()

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens,
                                  padding_mask,
                                  tokentype_ids=types,
                                  lm_labels=lm_labels)
        else:
            output_tensor = model(tokens, padding_mask, tokentype_ids=types)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask, lm_labels=lm_labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, padding_mask)

    if mpu.is_pipeline_last_stage():
        lm_loss_, sop_logits = output_tensor

        sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                                   sentence_order.view(-1),
                                   ignore_index=-1)
        sop_loss = sop_loss.float()

        lm_loss_ = lm_loss_.float()
        loss_mask = loss_mask.float()
        lm_loss = torch.sum(
            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

        loss = lm_loss + sop_loss

        averaged_losses = average_losses_across_data_parallel_group(
            [lm_loss, sop_loss])

        return loss, {
            'lm loss': averaged_losses[0],
            'sop loss': averaged_losses[1]
        }
    return output_tensor
Example #10
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    # Forward pass through the model.
    if mpu.is_pipeline_first_stage():
        assert input_tensor is None
        if mpu.is_pipeline_last_stage():
            output_tensor = model(tokens,
                                  position_ids,
                                  attention_mask,
                                  labels=labels)
        else:
            output_tensor = model(tokens, position_ids, attention_mask)
    elif mpu.is_pipeline_last_stage():
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask, labels=labels)
    else:
        assert input_tensor is not None
        output_tensor = model(input_tensor, attention_mask)

    if mpu.is_pipeline_last_stage():
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}
    return output_tensor
Example #11
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_pad_mask, \
    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()


    # Forward model.
    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
    micro_batch_size = query_logits.shape[0]
    global_batch_size = dist.get_world_size() * micro_batch_size  # recall we assert that tensor_model_parallel_size == 1

    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)

    # scores are inner products between query and block embeddings
    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
    softmaxed = F.softmax(retrieval_scores, dim=1)
    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

    def topk_accuracy(k):
        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])

    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
    retrieval_loss = retrieval_loss.float()
    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])

    # create stats_dict with retrieval loss and all specified top-k accuracies
    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)

    return retrieval_loss, stats_dict
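
In both retrieval forward steps the labels are just arange(global_batch_size): after the all-gather, query i's positive block sits in row i of the gathered block embeddings, and every other row serves as an in-batch negative. A toy demonstration with made-up sizes:

import torch
import torch.nn.functional as F

# 4 queries and 4 blocks: the score matrix is queries x blocks, and the
# correct block for query i is block i, hence labels = arange(4).
query_emb = torch.randn(4, 16)
block_emb = torch.randn(4, 16)
scores = query_emb.matmul(block_emb.t())
labels = torch.arange(4)
loss = F.cross_entropy(scores, labels)  # in-batch contrastive loss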
Example #12
def forward_step(data_iterator, model, input_tensor):
    """Forward step."""
    timers = get_timers()
    assert input_tensor is None

    # Get the batch.
    timers("batch-generator").start()
    (
        images,
        labels,
    ) = get_batch(data_iterator)
    timers("batch-generator").stop()

    # Forward model.
    logits = model(images).contiguous().float()
    loss = F.cross_entropy(logits, labels)

    outputs = torch.argmax(logits, -1)
    correct = (outputs == labels).float()
    accuracy = torch.mean(correct)

    averaged_loss = average_losses_across_data_parallel_group([loss, accuracy])

    return loss, {"loss": averaged_loss[0], "accuracy": averaged_loss[1]}
Example #13
def _cross_entropy_forward_step(batch, model, input_tensor):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()
    assert input_tensor is None

    # Get the batch.
    timers("batch generator").start()
    # `batch` may be an iterator or an already-built batch; only the
    # non-iterator case raises TypeError from next().
    try:
        batch_ = next(batch)
    except TypeError:
        batch_ = batch
    images, labels = process_batch(batch_)
    timers("batch generator").stop()

    # Forward model.
    logits = model(images).contiguous().float()

    # Cross-entropy loss.
    loss = F.cross_entropy(logits, labels)

    # Reduce loss for logging.
    average_loss = average_losses_across_data_parallel_group([loss])

    return loss, {"lm loss": average_loss[0]}
Example #14
def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
    args = get_args()

    local_batch_size = query_tokens.shape[0]
    group, rank, world_size = get_group_world_size_rank()
    # Recall we assert that model_parallel_size == 1.
    global_batch_size = world_size * local_batch_size

    query_logits, context_logits = output_tensor

    if world_size > 1:
        input_ = torch.empty_like(context_logits).copy_(
            context_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check that the all-gather kept the ranks in order.
        assert tensor_list[rank].sum().item() == \
            context_logits.sum().item()

        # Swap the local slice back in to preserve its gradient.
        tensor_list[rank] = context_logits
        all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

        # Query tensors.
        input_ = torch.empty_like(query_logits).copy_(
            query_logits).detach_()
        tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
        tensor_list[rank].copy_(input_)
        torch.distributed.all_gather(tensor_list, input_, group=group)

        # Check that the all-gather kept the ranks in order.
        assert tensor_list[rank].sum().item() == query_logits.sum().item()

        # Swap the local slice back in to preserve its gradient.
        tensor_list[rank] = query_logits
        all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
    else:
        all_query_logits = query_logits
        all_context_logits = context_logits

    retrieval_scores = torch.matmul(
        all_query_logits, torch.transpose(all_context_logits, 0, 1))
    # Scale the retrieval scores.
    if args.retriever_score_scaling:
        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

    if args.train_with_neg:
        # If the world size is 3, the local batch size is 4, and the
        # local context size is 8, we want
        # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19].
        labels = []
        local_context_size = context_tokens.shape[0]
        for i in range(world_size):
            j = i * local_context_size
            labels.extend(list(range(j, j + local_batch_size)))
        labels = torch.LongTensor(labels).cuda()
        assert len(labels) == global_batch_size
    else:
        labels = torch.arange(global_batch_size).long().cuda()

    # Cross-entropy loss.
    softmax_scores = F.log_softmax(retrieval_scores, dim=1)

    loss = F.nll_loss(softmax_scores, labels, reduction='mean')

    max_score, max_idxs = torch.max(softmax_scores, 1)
    correct_predictions_count = (max_idxs == labels).sum().float()

    # Reduce loss for logging.
    reduced_loss = average_losses_across_data_parallel_group(
        [loss, correct_predictions_count])

    # Scale the loss to keep gradients correct in supervised retrieval.
    loss = loss * mpu.get_data_parallel_world_size()

    return loss, {
        'lm loss': reduced_loss[0],
        'correct_prediction_count': reduced_loss[1]
    }
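
The train_with_neg branch exists because each rank contributes local_context_size contexts of which only the first local_batch_size are positives, so after concatenation the positives sit at offsets that are multiples of local_context_size. A quick check of the layout quoted in the comment above (world size 3, local batch size 4, local context size 8):

world_size, local_batch_size, local_context_size = 3, 4, 8
labels = []
for i in range(world_size):
    j = i * local_context_size
    labels.extend(range(j, j + local_batch_size))
assert labels == [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]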
Example #15
def retrieval_loss(model, dataloader):
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
        args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
            context_tokens, context_mask, context_types, _, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(query_tokens,
                query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            retrieval_scores = torch.matmul(query_logits,
                                    torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                                     k=softmax_scores.shape[1],
                                                     sorted=True)

            def topk_accuracy(k):
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
                        range(local_batch_size)])])

            def get_rank():
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
                        for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k) for k in \
                args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group([rank, \
                *topk_accs])

            # create stats_dict with retrieval loss and all specified
            # top-k accuracies
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                zip(args.retriever_report_topk_accuracies, losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
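
Note that retrieval_loss accumulates raw sums (topk_accuracy and get_rank return counts, not means) and leaves normalization to the caller; a plausible reporting snippet:

stats_dict, total = retrieval_loss(model, dataloader)
for name, value in stats_dict.items():
    # Divide the accumulated sums by the number of evaluated queries.
    print('{}: {:.3f}'.format(name, value.item() / total))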