Example #1
def optimize_graph(params: Params):

    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # Input placeholders; unclear whether they are XLA-friendly
    input_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                               'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                'input_mask')
    input_type_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                    'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

    with jit_scope():
        features = {}
        features['input_ids'] = input_ids
        features['input_mask'] = input_mask
        features['segment_ids'] = input_type_ids
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        output_tensors = [pred[k] for k in pred]

        tvars = tf.trainable_variables()

        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])
    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file
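
For reference, here is a minimal TF 1.x sketch of loading the frozen GraphDef that optimize_graph writes; load_frozen_graph is a hypothetical helper, and looking up input/output tensors by name is left to the caller.

import tensorflow as tf

def load_frozen_graph(graph_path):
    # Parse the serialized GraphDef written by optimize_graph.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    # Import it into a fresh graph; node names are kept as exported.
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph
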
Example #2
def train_problem(params, problem, gpu=4, base='baseline'):
    tf.keras.backend.clear_session()

    if not os.path.exists('tmp'):
        os.mkdir('tmp')

    base = os.path.join('tmp', base)
    params.assign_problem(problem, gpu=int(gpu), base_dir=base)

    create_path(params.ckpt_dir)

    tf.logging.info('Checkpoint dir: %s' % params.ckpt_dir)
    time.sleep(3)

    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    # Distribute training across GPUs with mirrored variables and NCCL all-reduce.
    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(gpu),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            'nccl', num_packs=int(gpu)))

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    # ws = make_warm_start_setting(params)

    estimator = Estimator(model_fn,
                          model_dir=params.ckpt_dir,
                          params=params,
                          config=run_config)
    train_hook = RestoreCheckpointHook(params)

    def train_input_fn():
        return train_eval_input_fn(params)

    estimator.train(train_input_fn,
                    max_steps=params.train_steps,
                    hooks=[train_hook])

    return estimator
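
A hypothetical driver for train_problem; Params is the project's configuration class used above, and 'your_problem' is a placeholder problem name rather than one taken from the project.

if __name__ == '__main__':
    params = Params()  # project-specific configuration object
    # Placeholder problem name; replace with a problem registered in Params.
    trained_estimator = train_problem(params, 'your_problem', gpu=4, base='baseline')
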
Example #3
def main(_):

    if not os.path.exists('tmp'):
        os.mkdir('tmp')

    if FLAGS.model_dir:
        base_dir, dir_name = os.path.split(FLAGS.model_dir)
    else:
        base_dir, dir_name = None, None

    params = Params()
    params.assign_problem(FLAGS.problem,
                          gpu=int(FLAGS.gpu),
                          base_dir=base_dir,
                          dir_name=dir_name)

    tf.logging.info('Checkpoint dir: %s' % params.ckpt_dir)
    time.sleep(3)

    model = BertMultiTask(params=params)
    model_fn = model.get_model_fn(warm_start=False)

    # Distribute training across GPUs with mirrored variables and NCCL all-reduce.
    dist_strategy = tf.contrib.distribute.MirroredStrategy(
        num_gpus=int(FLAGS.gpu),
        cross_tower_ops=tf.contrib.distribute.AllReduceCrossTowerOps(
            'nccl', num_packs=int(FLAGS.gpu)))

    run_config = tf.estimator.RunConfig(
        train_distribute=dist_strategy,
        eval_distribute=dist_strategy,
        log_step_count_steps=params.log_every_n_steps)

    # ws = make_warm_start_setting(params)

    estimator = Estimator(model_fn,
                          model_dir=params.ckpt_dir,
                          params=params,
                          config=run_config)

    if FLAGS.schedule == 'train':
        train_hook = RestoreCheckpointHook(params)

        def train_input_fn():
            return train_eval_input_fn(params)

        estimator.train(train_input_fn,
                        max_steps=params.train_steps,
                        hooks=[train_hook])

        def input_fn():
            return train_eval_input_fn(params, mode='eval')

        estimator.evaluate(input_fn=input_fn)

    elif FLAGS.schedule == 'eval':

        evaluate_func = getattr(metrics, FLAGS.eval_scheme + '_evaluate')
        print(evaluate_func(FLAGS.problem, estimator, params))

    elif FLAGS.schedule == 'predict':

        def input_fn():
            return predict_input_fn([
                '''兰心餐厅\n作为一个无辣不欢的妹子,对上海菜的偏清淡偏甜真的是各种吃不惯。
            每次出门和闺蜜越饭局都是避开本帮菜。后来听很多朋友说上海有几家特别正宗味道做
            的很好的餐厅于是这周末和闺蜜们准备一起去尝一尝正宗的本帮菜。\n进贤路是我在上
            海比较喜欢的一条街啦,这家餐厅就开在这条路上。已经开了三十多年的老餐厅了,地
            方很小,就五六张桌子。但是翻桌率比较快。二楼之前的居民间也改成了餐厅,但是在
            上海的名气却非常大。烧的就是家常菜,普通到和家里烧的一样,生意非常好,外面排
            队的比里面吃的人还要多。'''
            ],
                                    params,
                                    mode='predict')

        pred = estimator.predict(input_fn=input_fn)
        for p in pred:
            print(p)
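
main(_) reads a module-level FLAGS object; below is a sketch of how such a script is typically wired with TF 1.x flags and tf.app.run(). The flag names mirror the attributes main() accesses, but the defaults and help strings are assumptions, not taken from the project.

import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

# Flag names mirror what main() reads; defaults here are placeholders.
flags.DEFINE_string('problem', '', 'problem(s) to train or evaluate')
flags.DEFINE_string('model_dir', '', 'directory holding model checkpoints')
flags.DEFINE_integer('gpu', 1, 'number of GPUs to use')
flags.DEFINE_string('schedule', 'train', 'one of: train, eval, predict')
flags.DEFINE_string('eval_scheme', '', 'evaluation scheme used when schedule is eval')

if __name__ == '__main__':
    tf.app.run()  # parses flags and calls main(_)
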
Example #4
def optimize_graph(params):

    config = tf.ConfigProto(device_count={'GPU': 0}, allow_soft_placement=True)

    init_checkpoint = params.ckpt_dir

    tf.logging.info('build graph...')
    # Input placeholders; unclear whether they are XLA-friendly
    input_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                               'input_ids')
    input_mask = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                'input_mask')
    input_type_ids = tf.placeholder(tf.int32, (None, params.max_seq_len),
                                    'segment_ids')

    jit_scope = tf.contrib.compiler.jit.experimental_jit_scope

    with jit_scope():
        features = {}
        features['input_ids'] = input_ids
        features['input_mask'] = input_mask
        features['segment_ids'] = input_type_ids
        model = BertMultiTask(params)
        hidden_feature = model.body(features, tf.estimator.ModeKeys.PREDICT)
        pred = model.top(features, hidden_feature,
                         tf.estimator.ModeKeys.PREDICT)

        output_tensors = [pred[k] for k in pred]

        tvars = tf.trainable_variables()

        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)

        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tmp_g = tf.get_default_graph().as_graph_def()

    input_node_names = ['input_ids', 'input_mask', 'segment_ids']
    output_node_names = [
        '%s_top/%s_predict' %
        (params.share_top[problem], params.share_top[problem])
        for problem in params.problem_list
    ]

    transforms = [
        'remove_nodes(op=Identity)',
        'fold_constants(ignore_errors=true)',
        'fold_batch_norms',
        # 'quantize_weights',
        # 'quantize_nodes',
        'merge_duplicate_nodes',
        'strip_unused_nodes',
        'sort_by_execution_order'
    ]

    with tf.Session(config=config) as sess:
        tf.logging.info('load parameters from checkpoint...')
        sess.run(tf.global_variables_initializer())
        tf.logging.info('freeze...')
        tmp_g = tf.graph_util.convert_variables_to_constants(
            sess, tmp_g, [n.name[:-2] for n in output_tensors])
        tmp_g = TransformGraph(tmp_g, input_node_names, output_node_names,
                               transforms)
    tmp_file = os.path.join(params.ckpt_dir, 'export_model')
    tf.logging.info('write graph to a tmp file: %s' % tmp_file)
    with tf.gfile.GFile(tmp_file, 'wb') as f:
        f.write(tmp_g.SerializeToString())
    return tmp_file
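
The snippet above assumes an import block along the following lines; TransformGraph comes from tensorflow.tools.graph_transforms, while the project-specific modules are only noted, since their import paths depend on the package layout.

import os

import tensorflow as tf
# Graph Transform Tool entry point used by the snippet above.
from tensorflow.tools.graph_transforms import TransformGraph

# BertMultiTask, Params and modeling come from the surrounding project;
# their import paths depend on the package layout and are omitted here.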