Example #1
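Converts a checkpoint so that global_step is stored with the correct type: all weights except global_step are restored, the step value is read directly from the old checkpoint and reassigned, and a new checkpoint is written to dest_file.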
import tensorflow as tf
import dual_net

def convert(load_file, dest_file):
    from tensorflow.python.framework import meta_graph
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT)
    sess = tf.Session()

    # retrieve the global step as a python value
    ckpt = tf.train.load_checkpoint(load_file)
    global_step_value = ckpt.get_tensor('global_step')

    # restore all saved weights, except global_step
    meta_graph_def = meta_graph.read_meta_graph_file(load_file + '.meta')
    stored_var_names = set([
        n.name for n in meta_graph_def.graph_def.node if n.op == 'VariableV2'
    ])
    stored_var_names.remove('global_step')
    var_list = [
        v for v in tf.global_variables() if v.op.name in stored_var_names
    ]
    tf.train.Saver(var_list=var_list).restore(sess, load_file)

    # manually set the global step
    global_step_tensor = tf.train.get_or_create_global_step()
    assign_op = tf.assign(global_step_tensor, global_step_value)
    sess.run(assign_op)

    # export a new checkpoint with the correctly typed global step
    tf.train.Saver().save(sess, dest_file)
    sess.close()
    tf.reset_default_graph()
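
A minimal usage sketch (the model paths here are hypothetical):

convert('models/000001-first', 'models/000001-first-upgrade')  # hypothetical paths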
Example #2
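Runs the conversion from Example #1 over every model in the file-system database, skipping models whose names show they have already been upgraded.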
def backfill():
    models = [m[1] for m in fsdb.get_models()]

    import dual_net
    import tensorflow as tf
    from tqdm import tqdm
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                      dual_net.get_default_hyperparams())

    for model_name in tqdm(models):
        if model_name.endswith('-upgrade'):
            continue
        try:
            load_file = os.path.join(fsdb.models_dir(), model_name)
            # load and dest are the same path, so the model is converted
            # in place
            dest_file = os.path.join(fsdb.models_dir(), model_name)
            main.convert(load_file, dest_file)
        except Exception:
            print('failed on', model_name)
            continue
Example #3
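Builds a stochastic weight averaging (SWA) model: it restores several checkpoints into separate sessions, averages each variable across the models, and exports the result as a new checkpoint.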
def swa():
    path_base = fsdb.models_dir()
    model_names = [
        "000393-lincoln",
        "000390-indus",
        "000404-hannibal",
        "000447-hawke",
        "000426-grief",
        "000431-lion",
        "000428-invincible",
        "000303-olympus",
        "000291-superb",
        "000454-victorious",
    ]
    model_names = model_names[:FLAGS.count]

    model_paths = [os.path.join(path_base, m) for m in model_names]

    # construct the graph
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT)

    # restore all saved weights
    meta_graph_def = meta_graph.read_meta_graph_file(model_paths[0] + '.meta')
    stored_var_names = set([
        n.name for n in meta_graph_def.graph_def.node if n.op == 'VariableV2'
    ])

    var_list = [
        v for v in tf.global_variables() if v.op.name in stored_var_names
    ]
    var_list.sort(key=lambda v: v.op.name)

    print(stored_var_names)
    print(len(stored_var_names), len(var_list))

    sessions = [tf.Session() for _ in model_paths]
    saver = tf.train.Saver()
    for sess, model_path in zip(sessions, model_paths):
        saver.restore(sess, model_path)

    # Load all VariableV2s for each model.
    values = [sess.run(var_list) for sess in sessions]

    # Iterate over all variables, averaging each one's values across models.
    all_assign = []
    for var, vals in zip(var_list, zip(*values)):
        print("{}x {}".format(len(vals), var))
        if var.name == "global_step:0":
            # keep the largest global step seen across the models
            avg = vals[0]
            for val in vals:
                avg = tf.maximum(avg, val)
        else:
            avg = tf.add_n(vals) / len(vals)

        all_assign.append(tf.assign(var, avg))

    # Run all assign ops in an existing session (its graph already contains
    # the model's other ops).
    sess = sessions[0]
    sess.run(all_assign)

    # Export a new saved model.
    ensure_dir_exists(FLAGS.data_dir)
    dest_path = os.path.join(FLAGS.data_dir, "swa-" + str(FLAGS.count))
    saver.save(sess, dest_path)
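
Note the asymmetry in the averaging loop: weight tensors get the arithmetic mean, while global_step takes the maximum across the checkpoints, so the exported model reports the training progress of its newest ingredient.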
Example #4
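Initializes Horovod-distributed training: rank 0 logs the active flags, every rank builds the input pipeline and training op, restores weights from a checkpoint, and returns a TrainState holding the handles needed to drive training.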
def init_train(rank, tcomm, model_dir):
    """Train on examples and export the updated model weights."""
    # init hvd
    logging.info('hvd init at rank %d', rank)
    hvd.init(tcomm)

    # export the trained model to the given directory
    FLAGS.export_path = model_dir

    if rank == 0:
        logging.info('[ Train flags ] freeze              = %d', FLAGS.freeze)
        logging.info('[ Train flags ] window_size         = %d',
                     FLAGS.window_size)
        logging.info('[ Train flags ] use_trt             = %d', FLAGS.use_trt)
        logging.info('[ Train flags ] trt_max_batch_size  = %d',
                     FLAGS.trt_max_batch_size)
        logging.info('[ Train flags ] trt_precision       = %s',
                     FLAGS.trt_precision)
        logging.info('[ Train flags ] shuffle_buffer_size = %d',
                     FLAGS.shuffle_buffer_size)
        logging.info('[ Train flags ] shuffle_examples    = %d',
                     FLAGS.shuffle_examples)
        logging.info('[ Train flags ] export path         = %s',
                     FLAGS.export_path)
        logging.info('[ Train flags ] num_gpus_train      = %d', hvd.size())

        # From dual_net.py
        logging.info('[ d_net flags ] work_dir            = %s',
                     FLAGS.work_dir)
        logging.info('[ d_net flags ] train_batch_size    = %d',
                     FLAGS.train_batch_size)
        logging.info('[ d_net flags ] lr_rates            = %s',
                     FLAGS.lr_rates)
        logging.info('[ d_net flags ] lr_boundaries       = %s',
                     FLAGS.lr_boundaries)
        logging.info('[ d_net flags ] l2_strength         = %s',
                     FLAGS.l2_strength)
        logging.info('[ d_net flags ] conv_width          = %d',
                     FLAGS.conv_width)
        logging.info('[ d_net flags ] fc_width            = %d',
                     FLAGS.fc_width)
        logging.info('[ d_net flags ] trunk_layers        = %d',
                     FLAGS.trunk_layers)
        logging.info('[ d_net flags ] value_cost_weight   = %s',
                     FLAGS.value_cost_weight)
        logging.info('[ d_net flags ] summary_steps       = %d',
                     FLAGS.summary_steps)
        logging.info('[ d_net flags ] bool_features       = %d',
                     FLAGS.bool_features)
        logging.info('[ d_net flags ] input_features      = %s',
                     FLAGS.input_features)
        logging.info('[ d_net flags ] input_layout        = %s',
                     FLAGS.input_layout)

    # Training
    tf_records_ph = tf.placeholder(tf.string)
    data_iter = preprocessing.get_input_tensors(
        FLAGS.train_batch_size // hvd.size(),
        FLAGS.input_layout,
        tf_records_ph,
        filter_amount=FLAGS.filter_amount,
        shuffle_examples=FLAGS.shuffle_examples,
        shuffle_buffer_size=FLAGS.shuffle_buffer_size,
        random_rotation=True)

    features, labels = data_iter.get_next()
    train_op = dual_net.model_fn(features, labels, tf.estimator.ModeKeys.TRAIN,
                                 FLAGS.flag_values_dict(), True)
    sess = dual_net._get_session()

    # restore all variables from a checkpoint (the step number is hard-coded)
    tf.train.Saver().restore(sess,
                             os.path.join(FLAGS.work_dir, 'model.ckpt-5672'))

    return TrainState(sess, train_op, data_iter, tf_records_ph)
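
A minimal sketch of how the returned state might be driven, assuming TrainState is a namedtuple with fields named after the constructor arguments, data_iter is an initializable iterator, and the record path is hypothetical:

state = init_train(rank, tcomm, model_dir)
state.sess.run(state.data_iter.initializer,
               feed_dict={state.tf_records_ph: ['data/train-0000.tfrecord.zz']})  # hypothetical path
state.sess.run(state.train_op)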
Example #5
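A standalone script version of Example #1's conversion: it restores every saved weight except global_step from an old checkpoint and reassigns the step value manually.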
import tensorflow as tf
import dual_net

# checkpoint prefix (the restore APIs expect the prefix, not the
# .data-00000-of-00001 shard file)
save_file = '../epoch_12_step_6924'
dest_file = '../epoch_12_step_6924-upgrade'
features, labels = dual_net.get_inference_input()
dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                  dual_net.get_default_hyperparams())
sess = tf.Session()

# retrieve the global step as a python value
ckpt = tf.train.load_checkpoint(save_file)
global_step_value = ckpt.get_tensor('global_step')


# restore all saved weights, except global_step
from tensorflow.python.framework import meta_graph
meta_graph_def = meta_graph.read_meta_graph_file(save_file + '.meta')
stored_var_names = set([n.name
                        for n in meta_graph_def.graph_def.node
                        if n.op == 'VariableV2'])
print(stored_var_names)
stored_var_names.remove('global_step')
var_list = [v for v in tf.global_variables()
            if v.op.name in stored_var_names]
tf.train.Saver(var_list=var_list).restore(sess, save_file)

# manually set the global step
global_step_tensor = tf.train.get_or_create_global_step()
assign_op = tf.assign(global_step_tensor, global_step_value)
sess.run(assign_op)
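
The snippet defines dest_file but stops before exporting; presumably the final step mirrors Example #1 (a sketch, not part of the original):

# export the upgraded checkpoint (mirrors Example #1)
tf.train.Saver().save(sess, dest_file)
sess.close()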