Example #1
import torch
# mpu (Megatron-style model-parallel utilities) and get_batch come from the
# surrounding GLM-style training code.


def seq2seq_forward_step(data, model, args, timers, mems):
    """Forward step for seq2seq finetuning: compute the masked (optionally
    label-smoothed) language-modeling loss over the target segment."""

    # Get the batch.
    if timers is not None:
        timers('batch generator').start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)
    if timers is not None:
        timers('batch generator').stop()
    # Forward model.
    logits, *mems = model(tokens, position_ids, attention_mask, *mems)
    # Keep only the target part of the sequence; the source (encoder) part of
    # length src_seq_length does not contribute to the loss.
    logits = logits[:, args.src_seq_length:]
    loss_mask = loss_mask[:, args.src_seq_length:]
    labels = labels[:, args.src_seq_length:]
    losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                              labels)
    if args.label_smoothing > 0.0:
        epsilon = args.label_smoothing
        smooth_loss = -torch.nn.functional.log_softmax(logits,
                                                       dim=-1).mean(dim=-1)
        losses = (1 - epsilon) * losses + epsilon * smooth_loss
    loss_mask = loss_mask.reshape(-1)
    # Average the loss over the non-masked target tokens.
    loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum()
    return loss, mems, 'bert'
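
A minimal, self-contained sketch of the masked, label-smoothed loss computed above, using plain torch.nn.functional in place of the model-parallel mpu.vocab_parallel_cross_entropy; the tensor shapes and the smoothing value are illustrative, not taken from the source.

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11                  # illustrative sizes
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(vocab, (batch, seq_len))
loss_mask = torch.ones(batch, seq_len)
loss_mask[:, -1] = 0                              # pretend the last position is padding
epsilon = 0.1                                     # stands in for args.label_smoothing

# Per-token cross entropy, unreduced so the mask can be applied afterwards.
losses = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1),
                         reduction='none').reshape(batch, seq_len)
# Uniform label smoothing: mean negative log-probability over the vocabulary.
smooth_loss = -F.log_softmax(logits, dim=-1).mean(dim=-1)
losses = (1 - epsilon) * losses + epsilon * smooth_loss

loss_mask = loss_mask.reshape(-1)
loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum()
print(loss.item())
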
Example #2
import torch
# mpu, get_batch and global_tokenizer come from the surrounding GLM-style
# training code.


def lm_forward_step(data, model, args, timers, mems, eval_metric=None):
    """Forward step: return the masked LM loss, or an accuracy/classification
    metric when eval_metric is set."""

    # Get the batch.
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data, args)

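    # Debug helper: decode one sample and print its masked context and each
    # generated piece alongside its labels and position ids.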
    def print_masked_text(batch_id):
        block_position_ids = position_ids[:, 1]
        position_ids_ = position_ids[:, 0]
        output_tokens = []
        sep = attention_mask[batch_id].item()
        for i, token in enumerate(tokens[batch_id, :sep].tolist()):
            if global_tokenizer is not None:
                token = global_tokenizer.IdToToken(token)
                if token.startswith('[MASK'):
                    token = f"[{position_ids_[batch_id, i].item()}, {token}]"
                # Merge WordPiece continuations ("##...") onto the previous token.
                if (token.startswith('##') and output_tokens
                        and not output_tokens[-1].endswith(']')):
                    output_tokens[-1] += token[2:]
                else:
                    output_tokens.append(token)
            else:
                output_tokens.append(str(token))
        print(" ".join(output_tokens))
        last_index = None
        for i in range(sep, tokens.size(1)):
            if global_tokenizer.IdToToken(
                    tokens[batch_id, i].item()).startswith("<|startofpiece"):
                if last_index is not None:
                    print(
                        global_tokenizer.DecodeIds(
                            tokens[batch_id, last_index:i].tolist()), "|",
                        global_tokenizer.DecodeIds(
                            labels[batch_id, last_index:i].tolist()))
                    print(position_ids_[batch_id, last_index:i].tolist(),
                          block_position_ids[batch_id, last_index:i].tolist())
                last_index = i
        if last_index is not None:
            print(
                global_tokenizer.DecodeIds(tokens[batch_id,
                                                  last_index:].tolist()), "|",
                global_tokenizer.DecodeIds(labels[batch_id,
                                                  last_index:].tolist()))
            print(position_ids_[batch_id, last_index:].tolist(),
                  block_position_ids[batch_id, last_index:].tolist())

    # Forward model.
    if args.continuous_prompt:
        prompt_pos = data["prompt_pos"].long().cuda()
        logits, *mems = model(tokens,
                              position_ids,
                              attention_mask,
                              *mems,
                              prompt_pos=prompt_pos)
    else:
        logits, *mems = model(tokens, position_ids, attention_mask, *mems)

    if eval_metric is None or eval_metric == 'loss':
        losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(),
                                                  labels)
        loss_mask = loss_mask.view(-1)
        # For perplexity evaluation the summed token loss is left un-normalized
        # for fair comparison; in training mode it is averaged over the
        # non-masked tokens.
        loss = torch.sum(losses.view(-1) * loss_mask)
        if eval_metric is None:
            loss = loss / loss_mask.sum()
        return loss, mems, 'bert'
    elif eval_metric in ('accuracy', 'classify'):
        outputs = torch.argmax(logits, -1)
        correct = (outputs == labels).float()
        # Treat masked positions as correct so the product over the sequence is
        # 1 only when every unmasked position matches its label.
        correct[(1 - loss_mask).bool()] = 1
        correct = correct.prod(-1)
        if eval_metric == 'accuracy':
            correct = correct.sum()
        return correct, mems, 'bert'
    else:
        raise NotImplementedError(
            "Metric {} not implemented".format(eval_metric))