Example #1
import os
from multiprocessing import Pool

import torch

# Project-local helpers assumed importable here: BertWordPieceTokenizer,
# indexed_dataset, read_split.
def main_multi_task(args):
    from argparse import ArgumentParser
    parser = ArgumentParser()

    # parser.add_argument("--tokenizer", type=str, help="where to load vocabulary")
    parser.add_argument("--data", type=str)
    parser.add_argument("--out", type=str, help="output path")
    parser.add_argument("--prefix", type=str, default="train")
    parser.add_argument("--workers", type=int, default=6)
    args = parser.parse_args(args)

    tokenizer = BertWordPieceTokenizer("bert-base-chinese",
                                       cache_dir="temp_cache_dir")

    data_bin = os.path.join(args.out, "{}-CLM.bin".format(args.prefix))
    data_idx = os.path.join(args.out, "{}-CLM.idx".format(args.prefix))
    data_ds = indexed_dataset.IndexedDatasetBuilder(data_bin)

    def consume(worker_result):
        # Append each tokenized sequence to the on-disk dataset.
        for ids in worker_result:
            data_ds.add_item(torch.IntTensor(ids))

    pool = Pool(processes=args.workers)
    worker_result = []

    for i in range(args.workers):
        w = pool.apply_async(read_split,
                             (args.data, tokenizer, i, args.workers, 0, 10),
                             callback=consume)
        worker_result.append(w)
    pool.close()
    pool.join()

    data_ds.finalize(data_idx)
    print("| write data into {}".format(args.out))
Example #2
# Assumes the imports from Example #1 plus the project-local
# CLMTaskDataset and FuseSampler classes.
def multi_task_loader(args):
    tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
    datapath = args.data
    train_prefix = args.train_prefix
    valid_prefix = args.valid_prefix
    train_data = os.path.join(datapath, train_prefix)
    valid_datas = [os.path.join(datapath, prefix) for prefix in valid_prefix.split(",")]

    train_data = CLMTaskDataset(train_data, tokenizer, args.train_batch, args.max_tokens,
                                world_size=args.world_size, max_lens=args.max_lens,
                                no_cache=args.no_cache, use_cls_special=args.use_cls_special)
    print("| Load train dataset: {}".format(len(train_data)))
    train_sampler = FuseSampler(train_data, args.world_size, args.rank)
    train = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler, collate_fn=train_data.collate,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
    print("| Train loader batches: {}".format(len(train)))
    valids = []
    for data in valid_datas:
        d = CLMTaskDataset(data, tokenizer, args.valid_batch, args.max_tokens,
                           world_size=args.world_size, max_lens=args.max_lens,
                           no_cache=args.no_cache, use_cls_special=args.use_cls_special)
        print("| Load valid dataset: {}".format(len(d)))
        sampler = FuseSampler(d, args.world_size, args.rank)
        d = torch.utils.data.DataLoader(d, batch_sampler=sampler, collate_fn=d.collate,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        print("| Valid loader batches: {}".format(len(d)))
        valids.append(d)
    return (train, valids), tokenizer
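
A usage sketch, assuming args carries every field referenced above (data, train_prefix, valid_prefix, the batch and length limits, world_size, rank, num_workers):

# Hypothetical driver loop; the batch layout comes from CLMTaskDataset.collate.
(train_loader, valid_loaders), tokenizer = multi_task_loader(args)
for batch in train_loader:
    pass  # training step would go here
for loader in valid_loaders:
    for batch in loader:
        pass  # evaluation step would go here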
Example #3
# Assumes ArgumentParser and torch are imported, plus the project-local
# BertModel, convert_model, BertWordPieceTokenizer, and generate helpers.
def main():
    parser = ArgumentParser()
    parser.add_argument("--model-config",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--outlens", type=int, default=30)
    parser.add_argument("--beam", type=int, default=1)
    parser.add_argument("--checkpoints", type=str)
    parser.add_argument("--data", type=str, default="file")

    args = parser.parse_args()
    args.load_model = True

    model = BertModel(None, args)
    state_dict = convert_model(torch.load(args.checkpoints)['sd'])
    model.load_state_dict(state_dict)
    model.to(args.device)
    tokenizer = BertWordPieceTokenizer("bert-base-chinese",
                                       cache_dir="temp_cache_dir")
    generate(model,
             tokenizer,
             args.device,
             args.data,
             sample=True,
             top_k=5,
             beam_size=args.beam,    # was hard-coded to 6, ignoring --beam
             outlens=args.outlens)   # was hard-coded to 30, ignoring --outlens
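
Run as a script, this entry point only needs a checkpoint and an input file; the script name and paths below are placeholders:

# Hypothetical invocation:
#   python generate_clm.py --checkpoints model.pt --data prompts.txt --beam 6 --outlens 30
if __name__ == "__main__":
    main()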
Example #4
def model_init(app):
    # Lightweight attribute container for everything the server needs.
    ArgsSet = type('ArgsSet', (object,), {})
    client = ArgsSet()
    parser = ArgumentParser()
    parser.add_argument("--model-config", type=str, default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available()
                        else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--outlens", type=int, default=30)
    parser.add_argument("--beam", type=int, default=1)
    parser.add_argument("--gpt-checkpoints", type=str)
    parser.add_argument("--port", type=int, default=8866)

    args = parser.parse_args()
    args.load_model = True
    args.fp32_embedding = False
    args.fp32_layernorm = False
    args.fp32_tokentypes = False
    args.layernorm_epsilon = 1e-12

    gpt = BertModel(None, args)
    state_dict = convert_model(torch.load(args.gpt_checkpoints)['sd'])
    gpt.load_state_dict(state_dict)
    gpt.to(args.device)
    gpt.eval()
    tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
    print(" Load model from {}".format(args.gpt_checkpoints))

    client.tokenizer = tokenizer
    client.gpt = gpt
    client.gpt_beam = SequenceGenerator(gpt, tokenizer, beam_size=args.beam, max_lens=args.outlens)
    client.device = args.device
    client.port = args.port
    client.generator = sample_sequence

    return client
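
A wiring sketch: app is never used inside model_init, so any application object (or None) works here, and the request-handling step is an assumption:

# Hypothetical server bootstrap.
client = model_init(app=None)
print("serving on port {}".format(client.port))
# A request handler would then call client.generator (sample_sequence)
# for sampling, or client.gpt_beam for beam-search decoding.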
Example #5
# Assumes the imports from Example #2 plus the project-local FuseDataset.
def make_loaders_2(args):
    tokenizer = BertWordPieceTokenizer("bert-base-chinese", cache_dir="temp_cache_dir")
    if args.no_nsp:
        train, valid_dataset = FuseDataset.load_dataset_no_nsp(tokenizer, args)
    else:
        train, valid_dataset = FuseDataset.load_dataset(tokenizer, args)
    print("| Load train dataset :{}".format(len(train)))
    for d in valid_dataset:
        print("| Load valid dataset :{}".format(len(d)))
    train_simpler = FuseSampler(train, args.world_size, args.rank)
    train = torch.utils.data.DataLoader(train, batch_sampler=train_simpler, collate_fn=train.collate,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
    print("| After train batch size {}".format(len(train)))
    valids = []
    for data in valid_dataset:
        sampler = FuseSampler(data, args.world_size, args.rank)
        loader = torch.utils.data.DataLoader(data, batch_sampler=sampler, collate_fn=data.collate,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        print("| Valid loader batches: {}".format(len(loader)))
        valids.append(loader)

    return (train, valids), tokenizer
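
A sketch for filling in world_size and rank on a distributed run, assuming torch.distributed was initialized elsewhere; it falls back to single-process values otherwise:

import torch.distributed as dist

if dist.is_available() and dist.is_initialized():
    args.world_size = dist.get_world_size()
    args.rank = dist.get_rank()
else:
    args.world_size, args.rank = 1, 0
(train_loader, valid_loaders), tokenizer = make_loaders_2(args)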
Example #6
# Assumes the same imports and helpers as Example #4.
def model_init(app):
    # Lightweight attribute container for everything the server needs.
    ArgsSet = type('ArgsSet', (object,), {})
    client = ArgsSet()
    parser = ArgumentParser()
    parser.add_argument("--model-config",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--outlens", type=int, default=30)
    parser.add_argument("--beam", type=int, default=1)
    parser.add_argument("--fuse-checkpoints", type=str)
    parser.add_argument("--gpt-checkpoints", type=str)
    parser.add_argument("--qa-style-checkpoints", type=str)
    parser.add_argument("--multi-task", type=str)
    parser.add_argument("--split-sentence-with-task-embedding-checkpoints",
                        type=str)
    parser.add_argument("--special-cls-checkpoints", type=str)

    parser.add_argument("--port", type=int, default=8866)

    args = parser.parse_args()
    args.load_model = True
    args.fp32_embedding = False
    args.fp32_layernorm = False
    args.fp32_tokentypes = False
    args.layernorm_epsilon = 1e-12

    fuse_model = BertModel(None, args)
    state_dict = convert_model(torch.load(args.fuse_checkpoints)['sd'])
    fuse_model.load_state_dict(state_dict)
    fuse_model.to(args.device)
    fuse_model.eval()
    print("| Load model from {}".format(args.fuse_checkpoints))

    gpt = BertModel(None, args)
    state_dict = convert_model(torch.load(args.gpt_checkpoints)['sd'])
    gpt.load_state_dict(state_dict)
    gpt.to(args.device)
    gpt.eval()
    tokenizer = BertWordPieceTokenizer("bert-base-chinese",
                                       cache_dir="temp_cache_dir")
    print(" Load model from {}".format(args.gpt_checkpoints))

    # Plain BERT model; no fine-tuned checkpoint is loaded for this one,
    # and the fp32/layernorm settings above still apply.
    args.load_model = False
    bert = BertModel(None, args)
    bert.to(args.device)
    bert.eval()

    client.tokenizer = tokenizer
    client.fuse_model = fuse_model
    client.fuse_beam = SequenceGenerator(fuse_model,
                                         tokenizer,
                                         beam_size=args.beam,
                                         max_lens=args.outlens)
    client.gpt = gpt
    client.gpt_beam = SequenceGenerator(gpt,
                                        tokenizer,
                                        beam_size=args.beam,
                                        max_lens=args.outlens)
    client.bert = bert
    client.device = args.device
    client.port = args.port
    client.generator = sample_sequence

    # Multi-task model

    multi_task = BertModel(None, args)
    state_dict = convert_model(torch.load(args.multi_task)['sd'])
    print("| Load model from {}".format(args.multi_task))
    multi_task.load_state_dict(state_dict)
    multi_task.to(args.device)
    multi_task.eval()
    client.multi_task_model = multi_task
    client.multi_task_beam = SequenceGenerator(multi_task,
                                               tokenizer,
                                               beam_size=args.beam,
                                               max_lens=args.outlens)

    # QA-style model
    qa_style = BertModel(None, args)
    state_dict = convert_model(torch.load(args.qa_style_checkpoints)['sd'])
    qa_style.load_state_dict(state_dict)
    qa_style.to(args.device)
    qa_style.eval()
    print(" Load model from {}".format(args.qa_style_checkpoints))
    client.qa_task_model = qa_style

    # Special-CLS-token model
    special_cls_model = BertModel(None, args)
    state_dict = convert_model(torch.load(args.special_cls_checkpoints)['sd'])
    special_cls_model.load_state_dict(state_dict)
    special_cls_model.to(args.device)
    special_cls_model.eval()
    print(" Load model from {}".format(args.special_cls_checkpoints))
    client.special_cls_model = special_cls_model
    client.special_beam = SequenceGenerator(special_cls_model,
                                            tokenizer,
                                            beam_size=args.beam,
                                            max_lens=args.outlens)

    # Split-sentence model with task embedding
    split_sentence_model = BertModel(None, args)
    state_dict = convert_model(
        torch.load(args.split_sentence_with_task_embedding_checkpoints)['sd'])
    split_sentence_model.load_state_dict(state_dict)
    split_sentence_model.to(args.device)
    split_sentence_model.eval()
    print(" Load model from {}".format(
        args.split_sentence_with_task_embedding_checkpoints))
    client.split_sentence_model = split_sentence_model
    client.split_sentence_beam = SequenceGenerator(split_sentence_model,
                                                   tokenizer,
                                                   beam_size=args.beam,
                                                   max_lens=args.outlens)

    return client
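
The construct/restore/move/freeze sequence repeats for every model above; a refactoring sketch, assuming the same BertModel and convert_model helpers:

def load_eval_model(checkpoint_path, args):
    # Build one BertModel, restore its converted 'sd' state dict,
    # move it to the target device, and switch to inference mode.
    model = BertModel(None, args)
    state_dict = convert_model(torch.load(checkpoint_path)['sd'])
    model.load_state_dict(state_dict)
    model.to(args.device)
    model.eval()
    print("| Load model from {}".format(checkpoint_path))
    return model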