def const_model_inference_fn(features):
    """Builds the model graph with weights marked as constant.

    This improves TPU inference performance because it prevents the weights
    being transferred to the TPU every call to Session.run().

    Returns:
        (policy_output, value_output, logits) tuple of tensors.
    """
    def custom_getter(getter, name, *args, **kwargs):
        with tf.control_dependencies(None):
            return tf.guarantee_const(
                getter(name, *args, **kwargs), name=name + "/GuaranteeConst")

    with tf.variable_scope("", custom_getter=custom_getter):
        return dual_net.model_inference_fn(features, False)
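# A minimal sketch of running the constant-weight graph on a TPU using the
# TF 1.x tf.contrib.tpu.rewrite API. The TPU name, checkpoint path, batch
# contents, and the use of rewrite (rather than replicate) are assumptions,
# not taken from the function above. go and features_lib are the Minigo
# modules referenced elsewhere in this section.
import numpy as np
import tensorflow as tf
import features as features_lib
import go

features_ph = tf.placeholder(
    tf.float32, [None, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
    name='pos_tensor')
tpu_outputs = tf.contrib.tpu.rewrite(const_model_inference_fn, [features_ph])

tpu_grpc_url = tf.contrib.cluster_resolver.TPUClusterResolver(
    tpu='my-tpu').get_master()
batch = np.zeros(
    [8, go.N, go.N, features_lib.NEW_FEATURES_PLANES], dtype=np.float32)

with tf.Session(tpu_grpc_url) as sess:
    sess.run(tf.contrib.tpu.initialize_system())
    tf.train.Saver().restore(sess, 'path/to/checkpoint')  # assumed path
    policy, value, logits = sess.run(tpu_outputs, {features_ph: batch})
    sess.run(tf.contrib.tpu.shutdown_system())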
def _locked_load_model(self, path):
    tf.reset_default_graph()
    if path.endswith(".pb"):
        # Load a frozen GraphDef, remapping its input to the feature placeholder.
        graph_def = tf.GraphDef()
        with tf.gfile.FastGFile(path, 'rb') as f:
            graph_def.ParseFromString(f.read())
        with self._sess.graph.as_default():
            self._outputs = tf.import_graph_def(
                graph_def,
                input_map={'pos_tensor': self._feature_placeholder},
                return_elements=['policy_output:0', 'value_output:0'])
    else:
        # Rebuild the inference graph and restore weights from a checkpoint.
        with self._sess.graph.as_default():
            self._outputs = dual_net.model_inference_fn(
                self._feature_placeholder, training=False)
            tf.train.Saver().restore(self._sess, path)
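# A minimal sketch of the state _locked_load_model expects on self: a session,
# a feature placeholder named 'pos_tensor' (matching the input_map above), and
# a lock serializing model swaps. The class name, placeholder shape, and run()
# helper are illustrative, not taken from the snippet.
import threading
import tensorflow as tf
import features as features_lib
import go


class ModelSession(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._sess = tf.Session(graph=tf.Graph())
        with self._sess.graph.as_default():
            self._feature_placeholder = tf.placeholder(
                tf.float32,
                [None, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
                name='pos_tensor')
        self._outputs = None

    # Attach the module-level function defined above as a method.
    _locked_load_model = _locked_load_model

    def load_model(self, path):
        with self._lock:
            self._locked_load_model(path)

    def run(self, feature_batch):
        # Returns (policy, value) arrays for a batch of feature planes.
        with self._lock:
            return self._sess.run(
                self._outputs, {self._feature_placeholder: feature_batch})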
def main(argv):
    features, labels = dual_net.get_inference_input()
    tf_tensors = dual_net.model_inference_fn(features, False)
    if len(tf_tensors) != 4:
        print("oneoffs/embeddings.py requires you to modify")
        print("dual_net.model_inference_fn and add a fourth param")
        sys.exit(1)

    p_out, v_out, logits, shared = tf_tensors
    predictions = {'shared': shared}

    sess = tf.Session()
    tf.train.Saver().restore(sess, FLAGS.model)

    try:
        progress = tqdm(get_files())
        embeddings = []
        metadata = []
        for i, f in enumerate(progress):
            short_f = os.path.basename(f)
            short_f = short_f.replace('minigo-cc-evaluator', '')
            short_f = short_f.replace('-000', '-')
            progress.set_description('Processing %s' % short_f)

            processed = []
            for idx, p in enumerate(sgf_wrapper.replay_sgf_file(f)):
                if idx < FLAGS.first:
                    continue
                if idx > FLAGS.last:
                    break
                if idx % FLAGS.every != 0:
                    continue

                processed.append(features_lib.extract_features(p.position))
                metadata.append((f, idx))

            if len(processed) > 0:
                # If len(processed) gets too large, may have to chunk.
                res = sess.run(predictions, feed_dict={features: processed})
                for r in res['shared']:
                    embeddings.append(r.flatten())
    except:
        # Raise shows us the error, but only after the finally block executes.
        raise
    finally:
        with open(FLAGS.embedding_file, 'wb') as pickle_file:
            pickle.dump([metadata, np.array(embeddings)], pickle_file)
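# A minimal sketch of reading the file written in the finally block above.
# The file name and the nearest-neighbour query are illustrative.
import pickle
import numpy as np

with open('embeddings.pickle', 'rb') as f:
    metadata, embeddings = pickle.load(f)

# metadata[i] is the (sgf_path, move_index) pair for row i; embeddings[i] is
# the flattened shared-layer activation for that position.
query = embeddings[0]
dists = np.linalg.norm(embeddings - query, axis=1)
print(metadata[int(np.argsort(dists)[1])])  # most similar other position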
def main(unused_argv):
    in_path = FLAGS.in_path
    out_path = FLAGS.out_path

    assert tf.gfile.Exists(in_path)
    # TODO(amj): Why does ensure_dir_exists skip gs paths?
    # tf.gfile.MakeDirs(os.path.dirname(out_path))
    # assert tf.gfile.Exists(os.path.dirname(out_path))

    policy_err = []
    value_err = []

    print()
    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        ds_iter = preprocessing.get_input_tensors(
            FLAGS.batch_size, [in_path], shuffle_examples=False,
            random_rotation=False, filter_amount=1.0)

        with tf.Session() as sess:
            features, labels = ds_iter
            p_in = labels['pi_tensor']
            v_in = labels['value_tensor']

            p_out, v_out, logits = dual_net.model_inference_fn(
                features, False, FLAGS.flag_values_dict())
            tf.train.Saver().restore(sess, FLAGS.model)
            # TODO(seth): Add policy entropy.

            p_err = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=tf.stop_gradient(p_in))
            v_err = tf.square(v_out - v_in)

            for _ in tqdm(itertools.count(1)):
                try:
                    # Undo the cast in batch_parse_tf_example.
                    x_in = tf.cast(features, tf.int8)

                    x, pi, val, pi_err, val_err = sess.run(
                        [x_in, p_out, v_out, p_err, v_err])

                    for i, (x_i, pi_i, val_i) in enumerate(zip(x, pi, val)):
                        # NOTE: The teacher's policy has much higher entropy
                        # than the self-play policy labels, which are mostly
                        # zero, so expect the resulting file to be 3-5x larger.
                        r = preprocessing.make_tf_example(x_i, pi_i, val_i)
                        serialized = r.SerializeToString()
                        writer.write(serialized)

                    policy_err.extend(pi_err)
                    value_err.extend(val_err)
                except tf.errors.OutOfRangeError:
                    print()
                    print("Stopping: OutOfRangeError (end of dataset).")
                    break

    print("Counts", len(policy_err), len(value_err))
    test()

    plt.subplot(121)
    n, bins, patches = plt.hist(policy_err, 40)
    plt.title('Policy error histogram')

    plt.subplot(122)
    n, bins, patches = plt.hist(value_err, 40)
    plt.title('Value error histogram')

    plt.show()
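# A minimal sketch of the module-level pieces main() relies on. The flag names
# mirror those used in the body above; the ZLIB compression choice for OPTS is
# an assumption, not something stated in the snippet.
from absl import flags
import tensorflow as tf

FLAGS = flags.FLAGS
flags.DEFINE_string('in_path', None, 'tfrecord file of training examples.')
flags.DEFINE_string('out_path', None, 'Where to write the re-labelled examples.')
flags.DEFINE_string('model', None, 'Checkpoint of the teacher network.')
flags.DEFINE_integer('batch_size', 256, 'Examples per inference batch.')

OPTS = tf.python_io.TFRecordOptions(
    tf.python_io.TFRecordCompressionType.ZLIB)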
def loop_body(a, unused_b):
    """Loop body for the tf.while_loop op.

    Args:
        a: a constant 0.
        unused_b: a string placeholder (to satisfy the requirement that a
            while_loop's condition and body accept the same args as the loop
            returns).

    Returns:
        A TensorFlow subgraph.
    """
    # Request features.
    raw_response = tf.contrib.rpc.rpc(
        address=config.address,
        method=config.get_features_method,
        request="",
        protocol="grpc",
        fail_fast=True,
        timeout_in_ms=0,
        name="get_features")

    # Decode features from a proto to a flat tensor.
    _, (batch_id, flat_features) = decode_proto_op.decode_proto(
        bytes=raw_response,
        message_type='minigo.GetFeaturesResponse',
        field_names=['batch_id', 'features'],
        output_types=[dtypes.int32, dtypes.float32],
        descriptor_source=config.descriptor_path,
        name="decode_raw_features")

    # Reshape flat features.
    features = tf.reshape(
        flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
        name="unflatten_features")

    # Run inference.
    policy_output, value_output, _ = dual_net.model_inference_fn(
        features, False)

    # Flatten model outputs.
    flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy")
    flat_value = value_output  # value_output is already flat.

    # Encode outputs from flat tensors to a proto.
    request_tensors = encode_proto_op.encode_proto(
        message_type='minigo.PutOutputsRequest',
        field_names=['batch_id', 'policy', 'value'],
        sizes=[[1, policy_output_size, value_output_size]],
        values=[[batch_id], [flat_policy], [flat_value]],
        descriptor_source=config.descriptor_path,
        name="encode_outputs")

    # Send outputs.
    response = tf.contrib.rpc.rpc(
        address=config.address,
        method=config.put_outputs_method,
        request=request_tensors,
        protocol="grpc",
        fail_fast=True,
        timeout_in_ms=0,
        name="put_outputs")

    return a, response[0]
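# A minimal sketch of driving loop_body with tf.while_loop. The always-true
# condition, the initial loop variables, and parallel_iterations=1 are
# assumptions; loop_body matches the (a, unused_b) signature above, so the
# loop streams batches until the session is shut down.
import tensorflow as tf

loop_vars = [tf.constant(0), tf.constant("")]
inference_loop = tf.while_loop(
    lambda a, unused_b: tf.constant(True),  # loop until the session is closed
    loop_body,
    loop_vars,
    parallel_iterations=1,
    back_prop=False)

with tf.Session() as sess:
    sess.run(inference_loop)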