Example #1
    def _worker_thread(self):
        # Size, in int8 elements, of one position's feature planes.
        num_board_features = go.N * go.N * features_lib.NEW_FEATURES_PLANES

        print("waiting for model")
        while self._running and not self.model_available.wait(1):
            pass

        print("running worker")
        while self._running:
            features_response = self.stub.GetFeatures(
                inference_service_pb2.GetFeaturesRequest())
            all_features = features_response.features

            # Slice the flat feature buffer into one input batch per TPU.
            features = []
            num_features = self.batch_size * num_board_features
            for i in range(FLAGS.parallel_tpus):
                begin = i * num_features
                end = begin + num_features
                x = np.frombuffer(
                    all_features, dtype=np.int8, count=num_features, offset=begin)
                x = x.reshape([self.batch_size, go.N, go.N,
                               features_lib.NEW_FEATURES_PLANES])
                features.append(x)

            try:
                self.lock.acquire_read()
                outputs = self.sess.run(self.outputs,
                                        {self.feature_placeholders: features})
                # Make a local copy of self.model_path while this worker has
                # the read lock.
                local_model_path = self.model_path
            finally:
                self.lock.release_read()

            # Flatten the per-TPU outputs into a single policy list and
            # value list.
            flat_policy = []
            value = []
            for x in outputs:
                flat_policy.extend(x[0])
                value.extend(x[1])

            put_outputs_request = inference_service_pb2.PutOutputsRequest(
                batch_id=features_response.batch_id,
                policy=np.concatenate(flat_policy), value=value,
                model_path=local_model_path)
            self.stub.PutOutputs(put_outputs_request)
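
The loop in this example slices one flat int8 buffer into a feature batch per TPU. The following is a minimal, standalone sketch of that np.frombuffer pattern; the sizes are made up and merely stand in for self.batch_size, go.N, features_lib.NEW_FEATURES_PLANES, and FLAGS.parallel_tpus from the surrounding class.

    import numpy as np

    # Hypothetical sizes standing in for the worker's real configuration.
    PARALLEL_TPUS = 2
    BATCH_SIZE = 4
    N = 9
    PLANES = 17

    num_board_features = N * N * PLANES
    num_features = BATCH_SIZE * num_board_features

    # One flat zero-filled buffer playing the role of
    # features_response.features.
    all_features = bytes(PARALLEL_TPUS * num_features)

    features = []
    for i in range(PARALLEL_TPUS):
        begin = i * num_features
        x = np.frombuffer(all_features, dtype=np.int8,
                          count=num_features, offset=begin)
        features.append(x.reshape([BATCH_SIZE, N, N, PLANES]))

    assert len(features) == PARALLEL_TPUS
    assert features[0].shape == (BATCH_SIZE, N, N, PLANES)

np.frombuffer creates views into the shared buffer rather than copies, which is why the worker can split one large gRPC payload into per-TPU batches without extra allocation.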
Example #2
    def _worker_thread(self, thread_id):
        print("waiting for model")
        while self._running and not self.sess.model_available.wait(1):
            pass

        print("running worker", thread_id)
        while self._running:
            features_response = self.stub.GetFeatures(
                inference_service_pb2.GetFeaturesRequest())

            policy, value, model_path = self.sess.run(
                features_response.features)

            put_outputs_request = inference_service_pb2.PutOutputsRequest(
                batch_id=features_response.batch_id,
                policy=np.concatenate(policy), value=value,
                model_path=model_path)

            self.stub.PutOutputs(put_outputs_request)

        print("stopping worker", thread_id)
Example #3
    def _worker_thread(self, thread_id):
        dbg("waiting for model")
        while self._running and not self.sess.model_available.wait(1):
            pass

        dbg("running worker", thread_id)
        # Don't start counting till we've warmed up.
        batches = 0
        batch_group_start = time.time()

        while self._running:
            features_response = self.stub.GetFeatures(
                inference_service_pb2.GetFeaturesRequest())

            if batches == 0:
                warm_start = time.time()

            policy, value, model_path = self.sess.run(
                features_response.features)

            put_outputs_request = inference_service_pb2.PutOutputsRequest(
                batch_id=features_response.batch_id,
                policy=np.concatenate(policy),
                value=value,
                model_path=model_path)

            self.stub.PutOutputs(put_outputs_request)

            batches += 1
            if batches % 100 == 0:
                end_time = time.time()
                batch_time = end_time - batch_group_start
                dbg("recent: {:.2f}s/{}, total: {:.2f}s/{} inferences".format(
                    batch_time, 100 * self.positions_per_inference,
                    end_time - warm_start,
                    batches * self.positions_per_inference))
                batch_group_start = end_time

        dbg("stopping worker", thread_id)