Example #1
def convert(load_file, dest_file):
    from tensorflow.python.framework import meta_graph
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT)
    sess = tf.Session()

    # retrieve the global step as a python value
    ckpt = tf.train.load_checkpoint(load_file)
    global_step_value = ckpt.get_tensor('global_step')

    # restore all saved weights, except global_step
    meta_graph_def = meta_graph.read_meta_graph_file(load_file + '.meta')
    stored_var_names = set([
        n.name for n in meta_graph_def.graph_def.node if n.op == 'VariableV2'
    ])
    stored_var_names.remove('global_step')
    var_list = [
        v for v in tf.global_variables() if v.op.name in stored_var_names
    ]
    tf.train.Saver(var_list=var_list).restore(sess, load_file)

    # manually set the global step
    global_step_tensor = tf.train.get_or_create_global_step()
    assign_op = tf.assign(global_step_tensor, global_step_value)
    sess.run(assign_op)

    # export a new checkpoint that has the right global step type
    tf.train.Saver().save(sess, dest_file)
    sess.close()
    tf.reset_default_graph()
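
A hedged usage sketch: convert() takes checkpoint prefixes (the path without the .meta / .data-* suffixes), as in Examples #3 and #5; the model name below is a placeholder.

load_file = os.path.join(fsdb.models_dir(), '000123-example')  # hypothetical model name
dest_file = load_file + '-upgrade'
convert(load_file, dest_file)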
Example #2
def main(argv):
    features, labels = dual_net.get_inference_input()
    tf_tensors = dual_net.model_inference_fn(features, False)
    if len(tf_tensors) != 4:
        print("oneoffs/embeddings.py requires you modify")
        print("dual_net.model_inference_fn and add a fourth param")
        sys.exit(1)

    p_out, v_out, logits, shared = tf_tensors
    predictions = {'shared': shared}

    sess = tf.Session()
    tf.train.Saver().restore(sess, FLAGS.model)

    try:
        progress = tqdm(get_files())
        embeddings = []
        metadata = []
        for i, f in enumerate(progress):
            short_f = os.path.basename(f)
            short_f = short_f.replace('minigo-cc-evaluator', '')
            short_f = short_f.replace('-000', '-')
            progress.set_description('Processing %s' % short_f)

            processed = []
            for idx, p in enumerate(sgf_wrapper.replay_sgf_file(f)):
                if idx < FLAGS.first: continue
                if idx > FLAGS.last: break
                if idx % FLAGS.every != 0: continue

                processed.append(features_lib.extract_features(p.position))
                metadata.append((f, idx))

            if len(processed) > 0:
                # If len(processed) gets too large, we may have to chunk it.
                res = sess.run(predictions, feed_dict={features: processed})
                for r in res['shared']:
                    embeddings.append(r.flatten())
    except:
        # Raise shows us the error but only after the finally block executes.
        raise
    finally:
        with open(FLAGS.embedding_file, 'wb') as pickle_file:
            pickle.dump([metadata, np.array(embeddings)], pickle_file)
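
The dump written in the finally block can be read back directly; a minimal sketch, assuming the same [metadata, embeddings] layout and a hypothetical file name:

import pickle

with open('embeddings.pickle', 'rb') as pickle_file:  # hypothetical path; use FLAGS.embedding_file in practice
    metadata, embeddings = pickle.load(pickle_file)
# metadata[i] is an (sgf_file, move_index) pair describing embeddings[i]
print(len(metadata), embeddings.shape)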
Example #3
def backfill():
    models = [m[1] for m in fsdb.get_models()]

    import dual_net
    import tensorflow as tf
    from tqdm import tqdm
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                      dual_net.get_default_hyperparams())

    for model_name in tqdm(models):
        if model_name.endswith('-upgrade'):
            continue
        try:
            load_file = os.path.join(fsdb.models_dir(), model_name)
            dest_file = os.path.join(fsdb.models_dir(), model_name)
            main.convert(load_file, dest_file)
        except:
            print('failed on', model_name)
            continue
Example #4
def swa():
    path_base = fsdb.models_dir()
    model_names = [
        "000393-lincoln",
        "000390-indus",
        "000404-hannibal",
        "000447-hawke",
        "000426-grief",
        "000431-lion",
        "000428-invincible",
        "000303-olympus",
        "000291-superb",
        "000454-victorious",
    ]
    model_names = model_names[:FLAGS.count]

    model_paths = [os.path.join(path_base, m) for m in model_names]

    # construct the graph
    features, labels = dual_net.get_inference_input()
    dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT)

    # restore all saved weights
    meta_graph_def = meta_graph.read_meta_graph_file(model_paths[0] + '.meta')
    stored_var_names = set([
        n.name for n in meta_graph_def.graph_def.node if n.op == 'VariableV2'
    ])

    var_list = [
        v for v in tf.global_variables() if v.op.name in stored_var_names
    ]
    var_list.sort(key=lambda v: v.op.name)

    print(stored_var_names)
    print(len(stored_var_names), len(var_list))

    sessions = [tf.Session() for _ in model_paths]
    saver = tf.train.Saver()
    for sess, model_path in zip(sessions, model_paths):
        saver.restore(sess, model_path)

    # Load all VariableV2s for each model.
    values = [sess.run(var_list) for sess in sessions]

    # Iterate over all variables, averaging their values across all models.
    all_assign = []
    for var, vals in zip(var_list, zip(*values)):
        print("{}x {}".format(len(vals), var))
        if var.name == "global_step:0":
            avg = vals[0]
            for val in vals:
                avg = tf.maximum(avg, val)
        else:
            avg = tf.add_n(vals) / len(vals)

        all_assign.append(tf.assign(var, avg))

    # Run all assign ops in an existing session (whose graph already contains the other ops).
    sess = sessions[0]
    sess.run(all_assign)

    # Export a new saved model.
    ensure_dir_exists(FLAGS.data_dir)
    dest_path = os.path.join(FLAGS.data_dir, "swa-" + str(FLAGS.count))
    saver.save(sess, dest_path)
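
The heart of the loop above is an elementwise average of each variable across all restored models (global_step instead takes the maximum). A minimal numpy sketch of that averaging step, independent of the TensorFlow graph:

import numpy as np

def average_weights(values):
    # values[i][j] is the value of var_list[j] loaded from model i (as produced by sess.run above);
    # zip(*values) regroups them so each vals tuple holds one variable across all models
    return [np.mean(np.stack(vals), axis=0) for vals in zip(*values)]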
Example #5
import tensorflow as tf
import dual_net

save_file = '../epoch_12_step_6924'            # checkpoint prefix (no .meta / .data-* suffix)
dest_file = '../epoch_12_step_6924-upgrade'
features, labels = dual_net.get_inference_input()
dual_net.model_fn(features, labels, tf.estimator.ModeKeys.PREDICT,
                  dual_net.get_default_hyperparams())
sess = tf.Session()

# retrieve the global step as a python value
ckpt = tf.train.load_checkpoint(save_file)
global_step_value = ckpt.get_tensor('global_step')


# restore all saved weights, except global_step
from tensorflow.python.framework import meta_graph
meta_graph_def = meta_graph.read_meta_graph_file(save_file + '.meta')
stored_var_names = set([n.name
                        for n in meta_graph_def.graph_def.node
                        if n.op == 'VariableV2'])
print(stored_var_names)
stored_var_names.remove('global_step')
var_list = [v for v in tf.global_variables()
            if v.op.name in stored_var_names]
tf.train.Saver(var_list=var_list).restore(sess, save_file)

# manually set the global step
global_step_tensor = tf.train.get_or_create_global_step()
assign_op = tf.assign(global_step_tensor, global_step_value)
sess.run(assign_op)
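
The snippet stops before exporting; a plausible continuation, mirroring Example #1, writes the fixed-up checkpoint to the dest_file defined above:

# export the upgraded checkpoint (same pattern as Example #1)
tf.train.Saver().save(sess, dest_file)
sess.close()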