Example #1
0
def create_model(pyreader_name, ernie_config, task_group):
    """Build the ERNIE multi-task pre-training graph and its data loader.

    Declares all feed variables, wires them into an ErnieModel, accumulates
    the masked-LM loss plus per-task (and optional contrastive) losses, and
    wraps the feeds in a DataLoader.

    Args:
        pyreader_name: unused here; kept for caller compatibility.
        ernie_config: model configuration forwarded to ErnieModel.
        task_group: list of task dicts; each entry is consumed by
            ernie.get_task_output and must carry "loss_weight" and
            "constart" (contrastive-loss switch) keys.

    Returns:
        (pyreader, fetch_vars): the DataLoader and a dict holding
        "graph_vars" (all persistable fetch targets) and "checkpoints".
    """
    src_ids = fluid.layers.data(name='src_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    pos_ids = fluid.layers.data(name='pos_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    sent_ids = fluid.layers.data(name='sent_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    task_ids = fluid.layers.data(name='task_ids',
            shape=[-1, args.max_seq_len, 1], dtype='int64')
    input_mask = fluid.layers.data(name='input_mask',
            shape=[-1, args.max_seq_len, args.max_seq_len], dtype='float32')
    mask_label = fluid.layers.data(name='mask_label',
            shape=[-1, 1], dtype='int64')
    mask_pos = fluid.layers.data(name='mask_pos',
            shape=[-1, 1], dtype='int64')
    lm_weight = fluid.layers.data(name='lm_weight',
            shape=[1], dtype='float32', append_batch_size=False)
    batch_mask = fluid.layers.data(name='batch_mask',
            shape=[-1, 1], dtype='float32')
    loss_mask = fluid.layers.data(name="loss_mask",
            shape=[-1, 1], dtype='float32')
    gather_idx = fluid.layers.data(name="gather_idx",
            shape=[-1, 1], dtype='int64')

    # One (label, weight) feed pair per task, appended after the general
    # inputs in task_group order.
    task_params_all = []
    for index, task in enumerate(task_group):
        task_label = fluid.layers.data(name='task_label_' + str(index),
            shape=[-1, 1], dtype='int64')
        task_weight = fluid.layers.data(name='task_weight_' + str(index),
            shape=[1], dtype='float32', append_batch_size=False)
        task_params_all.extend([task_label, task_weight])

    fluid.reader.keep_data_loader_order(False)
    general_inputs = [src_ids, pos_ids, sent_ids, task_ids, input_mask,
                      mask_label, mask_pos, lm_weight, batch_mask, loss_mask,
                      gather_idx]
    feed_list = general_inputs + task_params_all
    pyreader = fluid.io.DataLoader.from_generator(
            feed_list=feed_list,
            capacity=70, iterable=False)

    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        task_ids=task_ids,
        input_mask=input_mask,
        config=ernie_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()

    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]
    # Task feeds start right after the general inputs. Derived from the list
    # length instead of the previous hard-coded 11, so adding/removing a
    # general input cannot silently desynchronize the offset.
    index = len(general_inputs)
    total_constract_loss = 0
    for task in task_group:
        task_labels = feed_list[index]
        task_weight = feed_list[index + 1]
        task_loss, task_acc = ernie.get_task_output(task, task_labels, gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])
        index += 2

    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)
    # Fetch targets must be persistable so they are not garbage-collected
    # as intermediate results during execution.
    for var in graph_vars:
        var.persistable = True

    fetch_vars = {"graph_vars": graph_vars,
                  "checkpoints": checkpoints}

    return pyreader, fetch_vars
Example #2
0
def create_model(pyreader_name, ernie_config, task_group):
    """Declare feed variables, build the ERNIE graph, and collect outputs.

    Returns (fetch_vars, names): a dict with "graph_vars"/"checkpoints"
    and the ordered list of feed-variable names.
    """
    ## assemble the feed spec: one (name, shape, dtype) triple per input
    shapes = [[bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, 1], [bsz, 1], [1], [bsz, 1],
              [bsz, 1], [bsz, 1]]
    names = [
        "src_ids", "pos_ids", "sent_ids", "task_ids", "input_mask",
        "mask_label", "mask_pos", "lm_weight", "batch_mask", "loss_mask",
        "gather_idx"
    ]
    dtypes = [
        "int64", "int64", "int64", "int64", "float32", "int64", "int64",
        "float32", "float32", "float32", "int64"
    ]
    cnt_general_input = len(shapes)

    # every task contributes two extra feeds: a label and a weight
    for task_idx, _ in enumerate(task_group):
        names.extend(['task_label_' + str(task_idx),
                      'task_weight_' + str(task_idx)])
        shapes.extend([[bsz, 1], [1]])
        dtypes.extend(["int64", "float32"])

    assert len(shapes) == len(names) == len(
        dtypes), "The three fields must have same size"
    inputs = [
        fluid.layers.data(name=feed_name,
                          shape=feed_shape,
                          dtype=feed_dtype,
                          append_batch_size=False)
        for feed_name, feed_shape, feed_dtype in zip(names, shapes, dtypes)
    ]

    general_data = inputs[:cnt_general_input]
    task_params = inputs[cnt_general_input:]
    (src_ids, pos_ids, sent_ids, task_ids, input_mask, mask_label, mask_pos,
     lm_weight, batch_mask, loss_mask, gather_idx) = general_data

    ## build graph
    ernie = ErnieModel(src_ids=src_ids,
                       position_ids=pos_ids,
                       sentence_ids=sent_ids,
                       task_ids=task_ids,
                       input_mask=input_mask,
                       config=ernie_config,
                       weight_sharing=args.weight_sharing,
                       use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()
    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]

    total_constract_loss = 0
    cursor = 0
    for task in task_group:
        task_labels = task_params[cursor]
        task_weight = task_params[cursor + 1]
        cursor += 2
        task_loss, task_acc = ernie.get_task_output(task, task_labels,
                                                    gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])

    ## build output
    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)
    #for var in graph_vars:
    #    var.persistable = True

    fetch_vars = {"graph_vars": graph_vars, "checkpoints": checkpoints}

    return fetch_vars, names