Code example #1
import os

import tensorflow as tf

# `Params`, `BertMultiTask` and `modeling` (BERT's modeling utilities, which
# provide get_assignment_map_from_checkpoint) come from the surrounding project.


def optimize_graph(params: Params):

    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # input placeholders, not sure if they are friendly to XLA
    input_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                               'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                'input_mask')
    input_type_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                    'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

    with jit_scope():
        features = {}
        features['input_ids'] = input_ids
        features['input_mask'] = input_mask
        features['segment_ids'] = input_type_ids
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        output_tensors = [pred[k] for k in pred]

        tvars = tf.trainable_variables()

        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])
    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file
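
A minimal usage sketch for the frozen graph written above, assuming TensorFlow 1.x. The file path, sequence length and output tensor name are placeholders: the real path is whatever optimize_graph returns, the sequence length must match params.max_seq_len, and the output names depend on the problems configured in params (code example #2 shows the naming pattern).

import numpy as np
import tensorflow as tf

# Placeholder path; use the value returned by optimize_graph.
graph_path = '/path/to/ckpt_dir/export_model'
seq_len = 128  # must match params.max_seq_len used when freezing

# Parse the serialized GraphDef and import it into a fresh graph.
graph_def = tf.GraphDef()
with tf.gfile.GFile(graph_path, 'rb') as f:
    graph_def.ParseFromString(f.read())
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    feed = {
        'input_ids:0': np.zeros((1, seq_len), dtype=np.int32),
        'input_mask:0': np.ones((1, seq_len), dtype=np.int32),
        'segment_ids:0': np.zeros((1, seq_len), dtype=np.int32),
    }
    # Hypothetical output tensor name; real names follow the
    # '<problem>_top/<problem>_predict' pattern used in code example #2.
    predictions = sess.run('example_problem_top/example_problem_predict:0',
                           feed_dict=feed)
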
Code example #2
import os

import tensorflow as tf
from tensorflow.tools.graph_transforms import TransformGraph

# `BertMultiTask` and `modeling` (BERT's modeling utilities) come from the
# surrounding project.


def optimize_graph(params):

    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # input placeholders, not sure if they are friendly to XLA
    input_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                               'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                'input_mask')
    input_type_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                    'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

    with jit_scope():
        features = {}
        features['input_ids'] = input_ids
        features['input_mask'] = input_mask
        features['segment_ids'] = input_type_ids
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        output_tensors = [pred[k] for k in pred]

        tvars = tf.trainable_variables()

        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    input_node_names = ['input_ids', 'input_mask', 'segment_ids']
    output_node_names = [
        '%s_top/%s_predict' %
        (params.share_top[problem], params.share_top[problem])
        for problem in params.problem_list
    ]

    transforms = [
        'remove_nodes(op=Identity)',
        'fold_constants(ignore_errors=true)',
        'fold_batch_norms',
        # 'quantize_weights',
        # 'quantize_nodes',
        'merge_duplicate_nodes',
        'strip_unused_nodes',
        'sort_by_execution_order'
    ]

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])
        tmp_g = TransformGraph(tmp_g, input_node_names, output_node_names,
                               transforms)
    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file
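
Code example #2 differs from #1 by additionally running the Graph Transform Tool (TransformGraph) over the frozen GraphDef, which needs explicit input and output node names. A small worked example below, with an entirely hypothetical params stand-in, shows what the output_node_names list comprehension expands to; these names, together with the input node names, tell strip_unused_nodes which parts of the graph to keep.

# Hypothetical params stand-in with two problems, each using its own top layer.
class FakeParams:
    problem_list = ['problem_a', 'problem_b']
    share_top = {'problem_a': 'problem_a', 'problem_b': 'problem_b'}

params = FakeParams()
output_node_names = [
    '%s_top/%s_predict' % (params.share_top[problem],
                           params.share_top[problem])
    for problem in params.problem_list
]
print(output_node_names)
# ['problem_a_top/problem_a_predict', 'problem_b_top/problem_b_predict']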