Example #1
def create_model(pyreader_name, ernie_config, task_group):
    """create_model"""
    ## get input
    bsz = -1  # dynamic batch dimension; assumed here, since the excerpt never defines `bsz`
    shapes = [[bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, 1], [bsz, 1], [1], [bsz, 1],
              [bsz, 1], [bsz, 1]]
    names = [
        "src_ids", "pos_ids", "sent_ids", "task_ids", "input_mask",
        "mask_label", "mask_pos", "lm_weight", "batch_mask", "loss_mask",
        "gather_idx"
    ]
    dtypes = [
        "int64", "int64", "int64", "int64", "float32", "int64", "int64",
        "float32", "float32", "float32", "int64"
    ]
    cnt_general_input = len(shapes)

    for index, task in enumerate(task_group):
        shapes.extend([[bsz, 1], [1]])
        names.extend(['task_label_' + str(index), 'task_weight_' + str(index)])
        dtypes.extend(["int64", "float32"])

    assert len(shapes) == len(names) == len(
        dtypes), "shapes, names and dtypes must have the same length"
    inputs = []
    for i in range(len(shapes)):
        inputs.append(
            fluid.layers.data(name=names[i],
                              shape=shapes[i],
                              dtype=dtypes[i],
                              append_batch_size=False))

    general_data, task_params = inputs[:cnt_general_input], inputs[cnt_general_input:]
    src_ids, pos_ids, sent_ids, task_ids, input_mask, \
        mask_label, mask_pos, lm_weight, batch_mask, loss_mask, gather_idx = general_data

    ## build graph
    ernie = ErnieModel(src_ids=src_ids,
                       position_ids=pos_ids,
                       sentence_ids=sent_ids,
                       task_ids=task_ids,
                       input_mask=input_mask,
                       config=ernie_config,
                       weight_sharing=args.weight_sharing,
                       use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()
    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]

    index = 0
    total_constract_loss = 0
    for task in task_group:
        task_labels = task_params[index]
        task_weight = task_params[index + 1]
        task_loss, task_acc = ernie.get_task_output(task, task_labels,
                                                    gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])
        index += 2

    ## build output
    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)
    #for var in graph_vars:
    #    var.persistable = True

    fetch_vars = {"graph_vars": graph_vars, "checkpoints": checkpoints}

    return fetch_vars, names
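
This variant builds its layers into whatever program is currently active, so the caller is expected to wrap it in a program_guard and feed batches by name. A minimal driver sketch (hypothetical wiring; `args`, `ernie_config`, and `task_group` are assumed to come from the surrounding training script):

import paddle.fluid as fluid

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    fetch_vars, feed_names = create_model("train_reader", ernie_config, task_group)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(startup_program)
# `feed_names` preserves the order of the data layers, so a batch can be fed as a
# {name: ndarray} dict: exe.run(train_program, feed=batch, fetch_list=fetch_vars["graph_vars"])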
Example #2
def create_model(pyreader_name, ernie_config, task_group):
    """create_model"""
    src_ids = fluid.layers.data(name='src_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    pos_ids = fluid.layers.data(name='pos_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    sent_ids = fluid.layers.data(name='sent_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    task_ids = fluid.layers.data(name='task_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    input_mask = fluid.layers.data(name='input_mask',
            shape=[-1, args.max_seq_len, args.max_seq_len], dtype='float32')
    mask_label = fluid.layers.data(name='mask_label',
            shape=[-1, 1], dtype='int64')
    mask_pos = fluid.layers.data(name='mask_pos',
            shape=[-1, 1], dtype='int64')
    lm_weight = fluid.layers.data(name='lm_weight',
            shape=[1], dtype='float32', append_batch_size=False)
    batch_mask = fluid.layers.data(name='batch_mask',
            shape=[-1, 1], dtype='float32')
    loss_mask = fluid.layers.data(name='loss_mask',
            shape=[-1, 1], dtype='float32')
    gather_idx = fluid.layers.data(name='gather_idx',
            shape=[-1, 1], dtype='int64')

    task_params_all = []
    for index, task in enumerate(task_group):
        name_label = 'task_label_' + str(index)
        name_weight = 'task_weight_' + str(index)
        task_label = fluid.layers.data(name=name_label,
            shape=[-1, 1], dtype='int64')
        task_weight = fluid.layers.data(name=name_weight,
            shape=[1], dtype='float32', append_batch_size=False)
        task_params_all.extend([task_label, task_weight])

    fluid.reader.keep_data_loader_order(False)  # do not force batches to keep generation order across devices
    feed_list = [src_ids, pos_ids, sent_ids, task_ids, input_mask, \
                  mask_label, mask_pos, lm_weight, batch_mask, loss_mask, gather_idx] + task_params_all
    pyreader = fluid.io.DataLoader.from_generator(
            feed_list=feed_list,
            capacity=70, iterable=False)

    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        task_ids=task_ids,
        input_mask=input_mask,
        config=ernie_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()
    
    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]
    index = 11  # task-specific inputs start after the 11 general inputs in feed_list
    total_constract_loss = 0
    for task in task_group:
        task_labels = feed_list[index]
        task_weight = feed_list[index + 1]
        task_loss, task_acc = ernie.get_task_output(task, task_labels, gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])
        index += 2
    
    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)
    for var in graph_vars:
        var.persistable = True

    fetch_vars = {"graph_vars": graph_vars,
                  "checkpoints": checkpoints}

    return pyreader, fetch_vars
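
Because the DataLoader here is created with iterable=False, the caller starts it explicitly and runs the executor until the loader signals exhaustion. A sketch of that loop (hypothetical; `batch_generator` stands in for the real reader and `num_epochs` for the schedule):

import paddle.fluid as fluid

pyreader, fetch_vars = create_model("train_reader", ernie_config, task_group)
pyreader.set_batch_generator(batch_generator, places=fluid.cuda_places())

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(fluid.default_startup_program())

for epoch in range(num_epochs):
    pyreader.start()                      # begin draining the generator
    try:
        while True:
            outs = exe.run(fetch_list=fetch_vars["graph_vars"])
    except fluid.core.EOFException:
        pyreader.reset()                  # rearm the loader for the next epoch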
Example #3
def create_model(args, phase, micro_bsz, dp_sharding_rank, dp_worldsize, topo):
    if args.use_sop:
        from reader.pretraining_ds_ernie_full_sent import make_pretrain_dataset
    else:
        from reader.pretraining_ds_mlm import make_pretrain_dataset

    # mask_label/mask_pos feed the MLM head; labels feeds the SOP head
    if args.use_sop:
        input_fields = {
            'names':
            ['src_ids', 'sent_ids', 'mask_label', 'mask_pos', 'labels'],
            'shapes': [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                       [-1, 1], [-1, 1], [-1, 1]],
            'dtypes': ['int64', 'int64', 'int64', 'int64', 'int64'],
            'lod_levels': [0, 0, 0, 0, 0],
        }
    else:
        input_fields = {
            'names': ['src_ids', 'sent_ids', 'mask_label', 'mask_pos'],
            'shapes': [[-1, args.max_seq_len, 1], [-1, args.max_seq_len, 1],
                       [-1, 1], [-1, 1]],
            'dtypes': ['int64', 'int64', 'int64', 'int64'],
            'lod_levels': [0, 0, 0, 0],
        }

    with fluid.device_guard("gpu:0"):
        inputs = [
            fluid.data(name=input_fields['names'][i],
                       shape=input_fields['shapes'][i],
                       dtype=input_fields['dtypes'][i],
                       lod_level=input_fields['lod_levels'][i])
            for i in range(len(input_fields['names']))
        ]
    if args.use_sop:
        (src_ids, sent_ids, mask_label, mask_pos, labels) = inputs
    else:
        (src_ids, sent_ids, mask_label, mask_pos) = inputs
    train_file_list = glob.glob(args.data_dir + "/*")
    vocab = {}
    with open(args.vocab_file) as r:
        for line in r:
            lines = line.strip().split('\t')
            vocab[lines[0]] = int(lines[1])

    log.debug("========= worker: {} of {} ==========".format(
        dp_sharding_rank, dp_worldsize))

    data_reader = make_pretrain_dataset('pt', train_file_list, True, vocab,
                                        micro_bsz, len(vocab),
                                        args.max_seq_len, dp_sharding_rank,
                                        dp_worldsize)
    with fluid.device_guard("gpu:0"):
        data_loader = fluid.io.DataLoader.from_generator(feed_list=inputs,
                                                         capacity=70,
                                                         iterable=False)
    places = fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0)))

    def data_gen():
        yield from data_reader

    data_loader.set_batch_generator(data_gen, places)

    ernie_config = ErnieConfig(args.ernie_config_file)._config_dict
    ernie_config["preln"] = args.preln

    # weight sharing between embedding and output projection is disabled
    # under tensor (mp) or pipeline (pp) parallelism
    weight_sharing = (topo.mp.size == 1 and topo.pp.size == 1)
    with fluid.device_guard("gpu:0"):
        ernie = ErnieModel(src_ids,
                           sent_ids,
                           ernie_config,
                           weight_sharing=weight_sharing,
                           topo=topo)
    checkpoints = ernie._checkpoints
    checkpoints.pop(-1)

    with fluid.device_guard(f'gpu:{args.num_pp-1}'):
        mask_lm_loss, mean_mask_lm_loss = ernie.get_lm_output(
            mask_label, mask_pos)
        total_loss = mean_mask_lm_loss

        if args.use_sop:
            sop_acc, mean_sop_loss = ernie.get_next_sentence_output(labels)
            total_loss += mean_sop_loss

        if topo.pp.size > 1:
            mask_lm_loss.persistable = True
            mean_mask_lm_loss.persistable = True
            # checkpoints.extend([mask_lm_loss.name, mean_mask_lm_loss.name])
            if args.use_sop:
                mean_sop_loss.persistable = True
                sop_acc.persistable = True
                # checkpoints.extend([mean_sop_loss.name, sop_acc.name])
            total_loss.persistable = True
            # checkpoints.append(total_loss.name)

    if args.use_sop:
        graph_vars = {
            'data_loader': data_loader,
            'mask_lm_loss': mask_lm_loss,
            'mean_mask_lm_loss': mean_mask_lm_loss,
            'sop_loss': mean_sop_loss,
            'sop_acc': sop_acc,
            'total_loss': total_loss,
            'checkpoints': checkpoints
        }
    else:
        graph_vars = {
            'data_loader': data_loader,
            'mask_lm_loss': mask_lm_loss,
            'mean_mask_lm_loss': mean_mask_lm_loss,
            'total_loss': total_loss,
            'checkpoints': checkpoints,
        }
    return graph_vars
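
The returned dict bundles the loader with its fetch targets, so a consumer only needs the keys. A sketch of how it might be driven (hypothetical; the fleet/pipeline setup implied by the device_guard annotations is assumed to happen elsewhere, and a single device is shown for brevity):

import os
import paddle.fluid as fluid

graph_vars = create_model(args, 'train', micro_bsz, dp_sharding_rank, dp_worldsize, topo)

exe = fluid.Executor(fluid.CUDAPlace(int(os.environ.get('FLAGS_selected_gpus', 0))))
exe.run(fluid.default_startup_program())

graph_vars['data_loader'].start()
try:
    while True:
        loss, = exe.run(fetch_list=[graph_vars['total_loss']])
except fluid.core.EOFException:
    graph_vars['data_loader'].reset()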