Example #1
def const_model_inference_fn(features):
    """Builds the model graph with weights marked as constant.

    This improves TPU inference performance because it prevents the weights
    from being transferred to the TPU on every call to Session.run().

    Returns:
        (policy_output, value_output, logits) tuple of tensors.
    """
    def custom_getter(getter, name, *args, **kwargs):
        # Wrap each variable in guarantee_const so TensorFlow can treat the
        # weights as constants when compiling the TPU computation.
        with tf.control_dependencies(None):
            return tf.guarantee_const(
                getter(name, *args, **kwargs), name=name + "/GuaranteeConst")
    with tf.variable_scope("", custom_getter=custom_getter):
        return dual_net.model_inference_fn(features, False)
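
For context, a minimal driver sketch (not part of the original source) of how a constant-weight model function like this could be compiled for TPU with tf.contrib.tpu.rewrite in TF 1.x. The feature shape, TPU worker address, and plain variable initializer are illustrative assumptions; real code would restore trained weights from a checkpoint.

import numpy as np
import tensorflow as tf

# Hypothetical driver: compile the constant-weight model for TPU execution.
# The 19x19x17 feature shape and the TPU address below are assumptions.
features = tf.placeholder(tf.float32, [1, 19, 19, 17], name='pos_tensor')
tpu_outputs = tf.contrib.tpu.rewrite(const_model_inference_fn, [features])

with tf.Session('grpc://10.240.1.2:8470') as sess:  # assumed TPU worker
    sess.run(tf.contrib.tpu.initialize_system())
    # A real driver would restore trained weights with tf.train.Saver here.
    sess.run(tf.global_variables_initializer())
    policy, value, logits = sess.run(
        tpu_outputs,
        feed_dict={features: np.zeros([1, 19, 19, 17], np.float32)})
    sess.run(tf.contrib.tpu.shutdown_system())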
Example #2
    def _locked_load_model(self, path):
        tf.reset_default_graph()

        if path.endswith(".pb"):
            # Load a frozen GraphDef and splice it into the session's graph,
            # remapping its input to our feature placeholder.
            graph_def = tf.GraphDef()
            with tf.gfile.FastGFile(path, 'rb') as f:
                graph_def.ParseFromString(f.read())
            with self._sess.graph.as_default():
                self._outputs = tf.import_graph_def(
                    graph_def,
                    input_map={'pos_tensor': self._feature_placeholder},
                    return_elements=['policy_output:0', 'value_output:0'])
        else:
            # Rebuild the model graph and restore the weights from a
            # training checkpoint.
            with self._sess.graph.as_default():
                self._outputs = dual_net.model_inference_fn(
                    self._feature_placeholder, training=False)
                tf.train.Saver().restore(self._sess, path)
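
A brief usage sketch, not from the original source: once loading completes, inference runs through the session, placeholder, and output tensors wired up above. The player object name, batch size, and 19x19x17 feature shape are hypothetical.

import numpy as np

# Hypothetical caller: `player` stands in for the object whose
# _locked_load_model method is shown above.
batch = np.zeros([8, 19, 19, 17], dtype=np.float32)
policy, value = player._sess.run(
    player._outputs,
    feed_dict={player._feature_placeholder: batch})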
Example #3
def main(argv):
    features, labels = dual_net.get_inference_input()
    tf_tensors = dual_net.model_inference_fn(features, False)
    if len(tf_tensors) != 4:
        print("oneoffs/embeddings.py requires you modify")
        print("dual_net.model_inference_fn and add a fourth param")
        sys.exit(1)

    p_out, v_out, logits, shared = tf_tensors
    predictions = {'shared': shared}

    sess = tf.Session()
    tf.train.Saver().restore(sess, FLAGS.model)

    try:
        progress = tqdm(get_files())
        embeddings = []
        metadata = []
        for i, f in enumerate(progress):
            short_f = os.path.basename(f)
            short_f = short_f.replace('minigo-cc-evaluator', '')
            short_f = short_f.replace('-000', '-')
            progress.set_description('Processing %s' % short_f)

            processed = []
            for idx, p in enumerate(sgf_wrapper.replay_sgf_file(f)):
                if idx < FLAGS.first: continue
                if idx > FLAGS.last: break
                if idx % FLAGS.every != 0: continue

                processed.append(features_lib.extract_features(p.position))
                metadata.append((f, idx))

            if processed:
                # If len(processed) gets too large, this may need chunking.
                res = sess.run(predictions, feed_dict={features: processed})
                for r in res['shared']:
                    embeddings.append(r.flatten())
    finally:
        # The finally block runs before any exception propagates, so partial
        # results are always written, even if processing fails midway.
        with open(FLAGS.embedding_file, 'wb') as pickle_file:
            pickle.dump([metadata, np.array(embeddings)], pickle_file)
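
A short consumer sketch, assumed rather than taken from the original source: the pickle written above holds the (sgf_path, move_index) metadata list alongside the embedding matrix. The file path is an illustrative stand-in for FLAGS.embedding_file.

import pickle

# Hypothetical reader for the dump written in the finally block above.
with open('embeddings.pickle', 'rb') as f:  # assumed path
    metadata, embeddings = pickle.load(f)
print('%d positions, embedding matrix %s' % (len(metadata), embeddings.shape))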
Example #4
def main(unused_argv):
    in_path = FLAGS.in_path
    out_path = FLAGS.out_path

    assert tf.gfile.Exists(in_path)
    # TODO(amj): Why does ensure_dir_exists skip gs paths?
    #tf.gfile.MakeDirs(os.path.dirname(out_path))
    #assert tf.gfile.Exists(os.path.dirname(out_path))

    policy_err = []
    value_err = []

    print()
    with tf.python_io.TFRecordWriter(out_path, OPTS) as writer:
        ds_iter = preprocessing.get_input_tensors(FLAGS.batch_size, [in_path],
                                                  shuffle_examples=False,
                                                  random_rotation=False,
                                                  filter_amount=1.0)

        with tf.Session() as sess:
            features, labels = ds_iter
            p_in = labels['pi_tensor']
            v_in = labels['value_tensor']

            p_out, v_out, logits = dual_net.model_inference_fn(
                features, False, FLAGS.flag_values_dict())
            tf.train.Saver().restore(sess, FLAGS.model)

            # TODO(seth): Add policy entropy.

            p_err = tf.nn.softmax_cross_entropy_with_logits_v2(
                logits=logits, labels=tf.stop_gradient(p_in))
            v_err = tf.square(v_out - v_in)

            # Undo cast in batch_parse_tf_example. Build this op once, outside
            # the loop, so each iteration doesn't add a new node to the graph.
            x_in = tf.cast(features, tf.int8)

            for _ in tqdm(itertools.count(1)):
                try:
                    x, pi, val, pi_err, val_err = sess.run(
                        [x_in, p_out, v_out, p_err, v_err])

                    for x_i, pi_i, val_i in zip(x, pi, val):
                        # NOTE: The teacher's policy has much higher entropy
                        # than the self-play policy labels, which are mostly 0;
                        # expect the resulting file to be 3-5x larger.

                        r = preprocessing.make_tf_example(x_i, pi_i, val_i)
                        serialized = r.SerializeToString()
                        writer.write(serialized)

                    policy_err.extend(pi_err)
                    value_err.extend(val_err)

                except tf.errors.OutOfRangeError:
                    print()
                    print("Breaking OutOfRangeError")
                    break

    print("Counts", len(policy_err), len(value_err))
    test()

    plt.subplot(121)
    plt.hist(policy_err, 40)
    plt.title('Policy Error histogram')

    plt.subplot(122)
    plt.hist(value_err, 40)
    plt.title('Value Error histogram')

    plt.show()
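
To make the two error terms concrete, here is a small NumPy sketch, with made-up numbers, of what the graph above computes per example: softmax cross entropy against the target policy and squared error against the target value.

import numpy as np

# Made-up logits and targets, purely for illustration.
logits = np.array([2.0, 0.5, -1.0])
pi_target = np.array([0.7, 0.2, 0.1])

# softmax_cross_entropy_with_logits_v2: -sum(target * log_softmax(logits)).
log_softmax = logits - np.log(np.sum(np.exp(logits)))
policy_err = -np.sum(pi_target * log_softmax)

# tf.square(v_out - v_in), with assumed scalar prediction and target.
value_err = (0.25 - (-0.5)) ** 2

print(policy_err, value_err)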
Example #5
    def loop_body(a, unused_b):
        """Loop body for the tf.while_loop op.

        Args:
            a: a constant 0
            unused_b: a string placeholder (to satisfy the requirement that a
                      while_loop's condition and body accept the same args as
                      the loop returns).

        Returns:
            A TensorFlow subgraph.
        """

        # Request features from the server.
        raw_response = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.get_features_method,
            request="",
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="get_features")

        # Decode features from a proto to a flat tensor.
        _, (batch_id, flat_features) = decode_proto_op.decode_proto(
            bytes=raw_response,
            message_type='minigo.GetFeaturesResponse',
            field_names=['batch_id', 'features'],
            output_types=[dtypes.int32, dtypes.float32],
            descriptor_source=config.descriptor_path,
            name="decode_raw_features")

        # Reshape flat features.
        features = tf.reshape(
            flat_features, [-1, go.N, go.N, features_lib.NEW_FEATURES_PLANES],
            name="unflatten_features")

        # Run inference.
        policy_output, value_output, _ = dual_net.model_inference_fn(
            features, False)

        # Flatten model outputs.
        flat_policy = tf.reshape(policy_output, [-1], name="flatten_policy")
        flat_value = value_output  # value_output is already flat.

        # Encode outputs from flat tensors to a proto.
        request_tensors = encode_proto_op.encode_proto(
            message_type='minigo.PutOutputsRequest',
            field_names=['batch_id', 'policy', 'value'],
            sizes=[[1, policy_output_size, value_output_size]],
            values=[[batch_id], [flat_policy], [flat_value]],
            descriptor_source=config.descriptor_path,
            name="encode_outputs")

        # Send outputs.
        response = tf.contrib.rpc.rpc(
            address=config.address,
            method=config.put_outputs_method,
            request=request_tensors,
            protocol="grpc",
            fail_fast=True,
            timeout_in_ms=0,
            name="put_outputs")

        return a, response[0]
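
For orientation, a hedged sketch (not from the original source) of how loop_body might be driven: a tf.while_loop whose condition never becomes false, so the worker keeps fetching features and pushing outputs over RPC. The initial loop variables follow the docstring above; the condition, parallel_iterations, and checkpoint path are assumptions.

import tensorflow as tf

# Assumed driver: `a` starts at 0 and is returned unchanged, so a < 1 holds
# forever and the loop keeps servicing inference requests.
loop = tf.while_loop(
    lambda a, unused_b: a < 1,
    loop_body,
    [tf.constant(0), tf.constant('')],
    parallel_iterations=1,
    name='inference_loop')

with tf.Session() as sess:
    tf.train.Saver().restore(sess, 'path/to/model')  # assumed checkpoint
    sess.run(loop)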