Example No. 1
def squad_features(
        context: str, question: str, answer: Union[str, None],
        start_char_pos: Union[int, None],
        tokenizer: BertTokenizerFast) -> Tuple[List[int], List[int], int, int]:
    """ Squad feature extractor
    Implement the feature extractor from a Squad sample for your model
    Return values should follow [CLS + question + SEP + context + SEP] form.
    In addition, because start_char_pos is based on character index, you should convert it to proper token index.
    Check the test cases to know the functionality in detail.

    Note: input_ids and token_type_ids follow the transformers library documentation
    https://huggingface.co/transformers/glossary.html

    Arguments:
    context -- Context string
    question -- Question string
    answer -- Answer string. If the answer is None, return None for start_token_pos and end_token_pos
    start_char_pos -- Character index which the answer starts from in the context.
                      If the answer is None, this argument is also None.
    tokenizer -- Tokenizer to encode text strings.
                 Explanation: https://huggingface.co/transformers/model_doc/bert.html#berttokenizerfast

    Returns:
    input_ids -- Input ids
    token_type_ids -- Token type ids
    start_token_pos -- Token index which the answer starts from in the input_ids list. 
                       None if no answer is given.
    end_token_pos -- Token index at which the answer ends in the input_ids list.
                     This index is inclusive: it points at the last answer token.
                     None if no answer is given.
    """
    ### YOUR CODE HERE (~18 lines)
    encoded_dict = tokenizer.encode_plus(question, context)
    input_ids = encoded_dict["input_ids"]
    token_type_ids = encoded_dict["token_type_ids"]
    input_ids_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # print("Input (tokens): ", input_ids_tokens)
    if answer is None and start_char_pos is None:
        start_token_pos = None
        end_token_pos = None
        return input_ids, token_type_ids, start_token_pos, end_token_pos

    # The context starts right after [CLS] + question + [SEP], i.e. after all
    # tokens whose token_type_id is 0.
    start_token_pos = token_type_ids.count(0)
    # Offset by the number of context tokens that precede the answer.
    start_token_pos += len(tokenizer.tokenize(context[:start_char_pos]))
    # The end index is inclusive: it points at the last answer token.
    end_token_pos = start_token_pos + len(tokenizer.tokenize(answer)) - 1
    # Extract tokenized answer part only
    tokenized_answer = " ".join(
        tokenizer.convert_ids_to_tokens(
            input_ids[start_token_pos:end_token_pos + 1]))

    subword_prefix_original = "##" if "##" in tokenized_answer else ""
    subword_prefix = "##"
    # Strip the "##" subword markers so we can compare against the raw answer.
    tokenized_answer = tokenized_answer.replace('#', '')
    if (tokenized_answer != answer.lower() and start_token_pos == end_token_pos
            and answer in tokenized_answer):
        # A single word whose subword tokenization differs from the answer:
        # re-split the token so the answer becomes its own subword.
        new_subword_list = [
            subword_prefix_original + tokenized_answer[:len(answer)],
            subword_prefix + tokenized_answer[len(answer):]
        ]
        # print('new_subword_list : ', new_subword_list)
        input_ids = (input_ids[:start_token_pos] +
                     tokenizer.convert_tokens_to_ids(new_subword_list) +
                     input_ids[end_token_pos + 1:])
        # One token became two, so extend token_type_ids to keep lengths equal.
        token_type_ids.append(1)

    # print("Input ids: ", input_ids)
    # input_ids_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    # print("Input (tokens) (ADJUSTED): ", input_ids_tokens)
    # print("Segmend Ids: ", token_type_ids)
    # print('START_CHAR_POS: ', start_char_pos)
    # print("ANSWER: ", answer)
    # print("START: ", start_token_pos)
    # print("END: ", end_token_pos)
    # print("ANSWER SPAN: ", input_ids_tokens[start_token_pos:end_token_pos+1])
    assert len(input_ids) == len(token_type_ids)

    ### END YOUR CODE

    return input_ids, token_type_ids, start_token_pos, end_token_pos
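
For context, here is a minimal usage sketch. It is not part of the exercise: the "bert-base-uncased" checkpoint and the toy context/question/answer below are assumptions chosen for illustration.

# Hypothetical usage of squad_features; the checkpoint choice is an assumption.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
context = "The quick brown fox jumps over the lazy dog."
question = "What does the fox jump over?"
answer = "the lazy dog"
start_char_pos = context.index(answer)

input_ids, token_type_ids, start_tok, end_tok = squad_features(
    context, question, answer, start_char_pos, tokenizer)
# The inclusive token span [start_tok, end_tok] should decode back to the answer.
print(tokenizer.convert_ids_to_tokens(input_ids[start_tok:end_tok + 1]))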
Example No. 2
def main():
    args = set_args()
    logger = create_logger(args)
    # Use the GPU only when the user requests it and one is available
    args.cuda = torch.cuda.is_available() and not args.no_cuda
    device = 'cuda' if args.cuda else 'cpu'
    logger.info('using device:{}'.format(device))
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    tokenizer = BertTokenizerFast(vocab_file=args.vocab_path,
                                  sep_token="[SEP]",
                                  pad_token="[PAD]",
                                  cls_token="[CLS]")
    # tokenizer = BertTokenizer(vocab_file=args.vocab_path)
    model = GPT2LMHeadModel.from_pretrained(args.model_path)
    model = model.to(device)
    model.eval()
    if args.save_samples_path:
        if not os.path.exists(args.save_samples_path):
            os.makedirs(args.save_samples_path)
        samples_file = open(args.save_samples_path + '/samples.txt',
                            'a',
                            encoding='utf8')
        samples_file.write("聊天记录{}:\n".format(datetime.now()))
    # Store the chat history; each utterance is stored as a list of token ids
    history = []
    print('Start chatting with the chatbot; press CTRL + Z to exit')

    while True:
        try:
            text = input("user:"******"你好"
            if args.save_samples_path:
                samples_file.write("user:{}\n".format(text))
            text_ids = tokenizer.encode(text, add_special_tokens=False)
            history.append(text_ids)
            input_ids = [tokenizer.cls_token_id]  # every input starts with [CLS]

            for history_id, history_utr in enumerate(
                    history[-args.max_history_len:]):
                input_ids.extend(history_utr)
                input_ids.append(tokenizer.sep_token_id)
            input_ids = torch.tensor(input_ids).long().to(device)
            input_ids = input_ids.unsqueeze(0)
            response = []  # the response generated from the context
            # generate at most max_len tokens
            for _ in range(args.max_len):
                outputs = model(input_ids=input_ids)
                logits = outputs.logits
                next_token_logits = logits[0, -1, :]
                # Apply a repetition penalty to each token already generated,
                # lowering its probability of being sampled again.
                for token_id in set(response):
                    next_token_logits[token_id] /= args.repetition_penalty
                next_token_logits = next_token_logits / args.temperature
                # Set the [UNK] logit to negative infinity so the model can
                # never predict the [UNK] token.
                next_token_logits[tokenizer.convert_tokens_to_ids(
                    '[UNK]')] = -float('Inf')
                filtered_logits = top_k_top_p_filtering(next_token_logits,
                                                        top_k=args.topk,
                                                        top_p=args.topp)
                # torch.multinomial draws num_samples elements from the candidate
                # set without replacement; higher-weighted elements are more likely
                # to be drawn, and the element indices are returned.
                next_token = torch.multinomial(F.softmax(filtered_logits,
                                                         dim=-1),
                                               num_samples=1)
                if next_token == tokenizer.sep_token_id:  # [SEP] marks the end of the response
                    break
                response.append(next_token.item())
                input_ids = torch.cat((input_ids, next_token.unsqueeze(0)),
                                      dim=1)
                # his_text = tokenizer.convert_ids_to_tokens(curr_input_tensor.tolist())
                # print("his_text:{}".format(his_text))
            history.append(response)
            text = tokenizer.convert_ids_to_tokens(response)
            print("chatbot:" + "".join(text))
            if args.save_samples_path:
                samples_file.write("chatbot:{}\n".format("".join(text)))
        except KeyboardInterrupt:
            if args.save_samples_path:
                samples_file.close()
            break
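
The script above assumes a top_k_top_p_filtering helper defined elsewhere. Below is a sketch of what that helper typically looks like, adapted from the widely circulated Hugging Face generation example; the original repository's version may differ.

import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a 1D logits tensor with top-k and/or nucleus (top-p) filtering."""
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove all tokens whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token above the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits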
Example No. 3
def main(args):
    torch.cuda.set_device(args.local_rank)
    world_size = int(os.getenv('WORLD_SIZE', 1))
    if world_size > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method='env://',
        )

    if get_rank() == 0 and args.seq_len_dir is not None:
        mkdir(args.seq_len_dir)

    loader = get_bert_pretrain_data_loader(
        args.path,
        local_rank=args.local_rank,
        shuffle_buffer_size=args.shuffle_buffer_size,
        shuffle_buffer_warmup_factor=args.shuffle_buffer_warmup_factor,
        vocab_file=args.vocab_file,
        data_loader_kwargs={
            'batch_size': args.batch_size,
            'num_workers': args.workers,
            'prefetch_factor': args.prefetch
        },
        mlm_probability=args.mlm_probability,
        base_seed=args.seed,
        log_dir=args.log_dir,
        log_level=getattr(logging, args.log_level),
        return_raw_samples=args.debug,
        start_epoch=args.start_epoch,
        sequence_length_alignment=args.sequence_length_alignment,
        ignore_index=args.ignore_index,
    )
    if os.path.isfile(args.vocab_file):
        test_tokenizer = BertTokenizerFast(args.vocab_file)
    else:
        test_tokenizer = BertTokenizerFast.from_pretrained(args.vocab_file)

    meter = AverageMeter(warmup=args.warmup)

    lens_shape = (args.epochs, min(len(loader), args.iters_per_epoch))
    min_lens, max_lens, batch_sizes, padded_lens = (
        np.zeros(lens_shape, dtype=np.uint16),
        np.zeros(lens_shape, dtype=np.uint16),
        np.zeros(lens_shape, dtype=np.uint16),
        np.zeros(lens_shape, dtype=np.uint16),
    )
    seq_len_hist = Histogram()
    padded_zero_hist = Histogram()

    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        barrier()
        epoch_timer_start = time.time()
        batch_timer_start = time.time()
        total_samples = 0
        for i, data in enumerate(loader):
            if i >= args.iters_per_epoch:
                break
            if not args.debug:
                (input_ids, token_type_ids, attention_mask, labels,
                 next_sentence_labels) = (
                     data['input_ids'],
                     data['token_type_ids'],
                     data['attention_mask'],
                     data['labels'],
                     data['next_sentence_labels'],
                 )
            batch_timer_stop = time.time()
            elapsed = batch_timer_stop - batch_timer_start
            meter.update(elapsed)

            if args.debug:
                current_samples = len(data[0]) * world_size
            else:
                current_samples = input_ids.size(0) * world_size
                assert input_ids.size() == token_type_ids.size()
                assert input_ids.size() == attention_mask.size()
                assert input_ids.size() == labels.size()
                assert next_sentence_labels.dim() == 1
                assert input_ids.size(0) == next_sentence_labels.size(0)
                seq_lens = get_batch_seq_lens(attention_mask)
                seq_len_hist.update_with_tensor(seq_lens)
                (
                    min_lens[epoch - args.start_epoch, i],
                    max_lens[epoch - args.start_epoch, i],
                ) = seq_lens.min(), seq_lens.max()
                batch_sizes[epoch - args.start_epoch, i] = input_ids.size(0)
                padded_lens[epoch - args.start_epoch, i] = input_ids.size(1)
                padded_zero_hist.update_with_tensor(
                    input_ids.size(1) - seq_lens)

            total_samples += current_samples
            current_throughput = current_samples / elapsed
            if (i + 1) % args.log_freq == 0 and get_rank() == 0:
                avg_throughput = total_samples / meter.sum
                print('avg_throughput={}, avg_latency={} ms, '
                      'min_latency={} ms, max_latency={} ms, '
                      'current_throughput={}, current_latency={} ms'.format(
                          avg_throughput,
                          meter.avg * 1000,
                          meter.min * 1000,
                          meter.max * 1000,
                          current_throughput,
                          elapsed * 1000,
                      ))
                if args.debug:
                    print('len(data[0])={}'.format(len(data[0])))
                    print('sample=({} <SEP> {} - {})'.format(
                        data[0][0],
                        data[1][0],
                        data[2][0],
                    ))
                else:
                    print("Min length={} Max length={} Diff={}".format(
                        min_lens[epoch - args.start_epoch, i],
                        max_lens[epoch - args.start_epoch, i],
                        max_lens[epoch - args.start_epoch, i] -
                        min_lens[epoch - args.start_epoch, i],
                    ))
                    print('input_ids.size()={}'.format(input_ids.size()))
                    print('input_ids[0]={}'.format(input_ids[0]))
                    print('convert_ids_to_tokens(input_ids[0])={}'.format(
                        test_tokenizer.convert_ids_to_tokens(
                            input_ids[0].tolist())))
                    print('token_type_ids[0]={}'.format(token_type_ids[0]))
                    print('attention_mask[0]={}'.format(attention_mask[0]))
                    print('labels[0]={}'.format(labels[0]))
                    print('next_sentence_labels[0]={}'.format(
                        next_sentence_labels[0]))
                    mask = labels[0] != args.ignore_index
                    input_ids[0, mask] = labels[0, mask]
                    print('original sequence={}'.format(
                        test_tokenizer.convert_ids_to_tokens(
                            input_ids[0].tolist())))
            barrier()
            batch_timer_start = time.time()
        epoch_timer_stop = time.time()
        epoch_elapsed = epoch_timer_stop - epoch_timer_start
        if args.local_rank == 0:
            avg_throughput = total_samples / meter.sum
            print('epoch={}, epoch_elapsed={}, avg_throughput={}, '
                  'total_samples={}'.format(
                      epoch,
                      epoch_elapsed,
                      avg_throughput,
                      total_samples,
                  ))
        assert meter.iters == min(len(loader), args.iters_per_epoch)
        meter.reset()

    if args.seq_len_dir is not None:
        # Save the sequence lengths to file
        np.savez_compressed(
            os.path.join(args.seq_len_dir, 'lens_{}.npz'.format(get_rank())),
            min_lens=min_lens,
            max_lens=max_lens,
            batch_sizes=batch_sizes,
            padded_lens=padded_lens,
            seq_len_hist=seq_len_hist.hist,
            padded_zero_hist=padded_zero_hist.hist,
        )
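
This loader benchmark leans on a few small helpers (AverageMeter, Histogram, get_batch_seq_lens) defined elsewhere in the repository. Here is a plausible sketch of get_batch_seq_lens under the usual right-padding convention; this is an assumption, and the real implementation may differ.

import torch

def get_batch_seq_lens(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask has shape (batch_size, padded_len), with 1 for real
    # tokens and 0 for padding, so each row sum is that sequence's length.
    return attention_mask.sum(dim=-1)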