def create_model(pyreader_name, ernie_config, task_group):
    """create_model

    Build the ERNIE pre-training graph: declare every feed variable
    (token/position/sentence/task ids, masks, per-task labels/weights),
    wrap them in a DataLoader, run the ErnieModel backbone, and accumulate
    the total loss = masked-LM loss * lm_weight + per-task losses (plus a
    contrastive loss for tasks whose config sets "constart").

    NOTE(review): `pyreader_name` is never used in this body — confirm
    whether callers still rely on passing it.

    Args:
        pyreader_name: unused (see note above).
        ernie_config: model config forwarded to ErnieModel.
        task_group: iterable of task dicts; each is read for the keys
            "loss_weight" and "constart" below.

    Returns:
        (pyreader, fetch_vars): the DataLoader bound to the feed list, and
        a dict with "graph_vars" (loss/metric tensors to fetch) and
        "checkpoints" (recompute checkpoints from the backbone).
    """
    # --- general feed variables; -1 marks the variable batch dimension ---
    src_ids = fluid.layers.data(name='src_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    pos_ids = fluid.layers.data(name='pos_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    sent_ids = fluid.layers.data(name='sent_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    task_ids = fluid.layers.data(name='task_ids', shape=[-1, args.max_seq_len, 1], dtype='int64')
    input_mask = fluid.layers.data(name='input_mask', shape=[-1, args.max_seq_len, args.max_seq_len], dtype='float32')
    mask_label = fluid.layers.data(name='mask_label', shape=[-1, 1], dtype='int64')
    mask_pos = fluid.layers.data(name='mask_pos', shape=[-1, 1], dtype='int64')
    # Scalar loss weight: append_batch_size=False keeps the shape exactly [1].
    lm_weight = fluid.layers.data(name='lm_weight', shape=[1], dtype='float32', append_batch_size=False)
    batch_mask = fluid.layers.data(name='batch_mask', shape=[-1, 1], dtype='float32')
    loss_mask = fluid.layers.data(name="loss_mask", shape=[-1, 1], dtype='float32')
    gather_idx = fluid.layers.data(name="gather_idx", shape=[-1, 1], dtype='int64')

    # One (label, weight) placeholder pair per task, in task_group order;
    # the loss loop below relies on this exact interleaving.
    task_params_all = []
    for index, task in enumerate(task_group):
        name_label = 'task_label_' + str(index)
        name_weight = 'task_weight_' + str(index)
        task_label = fluid.layers.data(name=name_label, shape=[-1, 1], dtype='int64')
        task_weight = fluid.layers.data(name=name_weight, shape=[1], dtype='float32', append_batch_size=False)
        task_params_all.extend([task_label, task_weight])

    # Allow batches to be delivered out of order across devices.
    fluid.reader.keep_data_loader_order(False)

    # feed_list order must match the order the data generator yields fields.
    feed_list = [src_ids, pos_ids, sent_ids, task_ids, input_mask,
                 mask_label, mask_pos, lm_weight, batch_mask, loss_mask,
                 gather_idx] + task_params_all
    pyreader = fluid.io.DataLoader.from_generator(
        feed_list=feed_list,
        capacity=70,
        iterable=False)

    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        task_ids=task_ids,
        input_mask=input_mask,
        config=ernie_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()
    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]

    # Task placeholders start right after the 11 general inputs above;
    # keep this offset in sync with feed_list if inputs are added/removed.
    index = 11
    total_constract_loss = 0
    for task in task_group:
        task_labels = feed_list[index]
        task_weight = feed_list[index + 1]
        task_loss, task_acc = ernie.get_task_output(task, task_labels, gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            # NOTE(review): "constart"/"constract"/"contract" look like typos
            # for "contrast(ive)"; identifiers and config keys kept as-is
            # because readers/configs elsewhere use these exact spellings.
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])
        index += 2

    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)

    # Mark fetch targets persistable so their values survive executor runs.
    for var in graph_vars:
        var.persistable = True

    fetch_vars = {"graph_vars": graph_vars, "checkpoints": checkpoints}
    return pyreader, fetch_vars
def create_model(pyreader_name, ernie_config, task_group):
    """create_model

    Variant of the graph builder that declares all feed variables from
    parallel shapes/names/dtypes lists (with append_batch_size=False),
    builds the same loss graph, and returns (fetch_vars, names) — the
    caller wires up its own feeding using `names`.

    NOTE(review): `bsz` is read here but not defined in this function —
    presumably a module-level constant (often -1 for a dynamic batch);
    confirm it exists at module scope.
    NOTE(review): this second `create_model` shadows the earlier function
    of the same name when the module is imported — confirm intended.
    NOTE(review): `pyreader_name` is unused in this body.

    Returns:
        (fetch_vars, names): fetch_vars has "graph_vars" (loss/metric
        tensors) and "checkpoints"; names lists every feed variable name
        in declaration order.
    """
    ## get input
    # Parallel lists describing each feed variable. Entry order here fixes
    # the order of `names`, so it is part of the function's contract.
    # NOTE(review): input_mask is declared [bsz, max_seq_len, 1] here, but
    # the other builder uses [bsz, max_seq_len, max_seq_len] — confirm which
    # mask layout this data pipeline produces.
    shapes = [[bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, args.max_seq_len, 1],
              [bsz, args.max_seq_len, 1], [bsz, 1], [bsz, 1], [1],
              [bsz, 1], [bsz, 1], [bsz, 1]]
    names = [
        "src_ids", "pos_ids", "sent_ids", "task_ids", "input_mask",
        "mask_label", "mask_pos", "lm_weight", "batch_mask", "loss_mask",
        "gather_idx"
    ]
    dtypes = [
        "int64", "int64", "int64", "int64", "float32", "int64", "int64",
        "float32", "float32", "float32", "int64"
    ]
    # Number of general (non-task) inputs; used to split `inputs` below.
    cnt_general_input = len(shapes)

    # Append one (label, weight) placeholder description per task.
    for index, task in enumerate(task_group):
        shapes.extend([[bsz, 1], [1]])
        names.extend(['task_label_' + str(index), 'task_weight_' + str(index)])
        dtypes.extend(["int64", "float32"])

    assert len(shapes) == len(names) == len(
        dtypes), "The three fields must have same size"

    # Materialize the placeholders in declaration order.
    inputs = []
    for i in range(len(shapes)):
        inputs.append(
            fluid.layers.data(
                name=names[i],
                shape=shapes[i],
                dtype=dtypes[i],
                append_batch_size=False))

    general_data, task_params = inputs[:cnt_general_input], inputs[
        cnt_general_input:]
    src_ids, pos_ids, sent_ids, task_ids, input_mask, \
        mask_label, mask_pos, lm_weight, batch_mask, loss_mask, gather_idx = general_data

    ## build graph
    ernie = ErnieModel(
        src_ids=src_ids,
        position_ids=pos_ids,
        sentence_ids=sent_ids,
        task_ids=task_ids,
        input_mask=input_mask,
        config=ernie_config,
        weight_sharing=args.weight_sharing,
        use_fp16=args.use_amp)

    mask_lm_loss = ernie.get_lm_output(mask_label, mask_pos)
    checkpoints = ernie.get_checkpoints()
    total_loss = mask_lm_loss * lm_weight
    graph_vars = [mask_lm_loss, lm_weight]

    # task_params holds interleaved (label, weight) pairs; walk them in step.
    index = 0
    total_constract_loss = 0
    for task in task_group:
        task_labels = task_params[index]
        task_weight = task_params[index + 1]
        task_loss, task_acc = ernie.get_task_output(task, task_labels, gather_idx)
        total_loss += task_loss * task_weight * task["loss_weight"]
        if task["constart"]:
            # NOTE(review): "constart"/"constract" look like typos for
            # "contrast(ive)"; spellings kept as-is to match task configs.
            contract_loss = ernie.get_contrastive_loss(batch_mask, loss_mask)
            total_loss += contract_loss * task_weight
            total_constract_loss += contract_loss * task_weight
        graph_vars.extend([task_acc, task_weight])
        index += 2

    ## build output
    graph_vars.append(total_constract_loss)
    graph_vars.append(total_loss)
    # NOTE(review): unlike the other builder, the persistable flags are
    # disabled here — confirm the fetch path doesn't need them.
    #for var in graph_vars:
    #    var.persistable = True
    fetch_vars = {"graph_vars": graph_vars, "checkpoints": checkpoints}
    return fetch_vars, names