class TimblClassifier(Classifier):
    """
    Classifier which delegates classification to a Timbl server that it
    starts itself and queries through a Timbl client.
    """

    def __init__(self, descriptor, inst_fname=None, inst_base_fname=None,
                 options="", weight_func=None, server_log_fname=None):
        """
        Create a new TimblClassifier instance

        @param descriptor: Descriptor instance

        @keyword inst_fname: name of file containing Timbl instances

        @keyword inst_base_fname: name of file containing Timbl instance base

        @keyword options: additional Timbl options,
        excluding -f, -m, +vo, +vdb, +vdi

        @keyword weight_func: weight function; defaults to entropy_weight

        @keyword server_log_fname: filename for Timbl server log
        """
        self.no_rel = descriptor.no_rel
        # the server has to be running before a client can connect to it
        self._init_server(descriptor,
                          inst_fname,
                          inst_base_fname,
                          options,
                          server_log_fname)
        self._init_client()

        if weight_func:
            self.weight_func = weight_func
        else:
            self.weight_func = entropy_weight

    def _init_server(self, descriptor, inst_fname, inst_base_fname, options,
                     server_log_fname):
        """
        Build the Timbl option string and start a Timbl server with it
        """
        timbl_opts = timbl_options_string(descriptor,
                                          inst_fname=inst_fname,
                                          inst_base_fname=inst_base_fname,
                                          other=options)
        # Hold on to the TimblServer object: the Timbl server terminates
        # automatically as soon as that object dies
        server = TimblServer(timbl_opts=timbl_opts,
                             server_log_fname=server_log_fname)
        self._server = server
        server.start()

    def _init_client(self):
        """
        Connect a Timbl client to the port of the running server
        """
        client = TimblClient(self._server.port)
        self._client = client
        client.connect()

    def classify(self, instances):
        """
        adds predicted class and associated weight to instances

        @param instances: numpy.ndarray instance
        """
        for inst in instances:
            # Assumes that last field in instance is the true class
            fields = [self._to_str(value) for value in inst]
            result = self._client.classify("\t".join(fields))

            category = result["CATEGORY"]
            inst["pred_relation"] = category

            # The Timbl client is lazy: it hands back the distribution as
            # an unparsed string, which parse_distrib turns into an
            # iterator over (class, count) pairs
            distribution = parse_distrib(result["DISTRIBUTION"])
            inst["pred_weight"] = self.weight_func(category=category,
                                                   distribution=distribution)

    def _to_str(self, value):
        """
        Return value as a string; value can be a bool, number, ascii string
        or unicode string
        """
        try:
            text = str(value)
        except UnicodeEncodeError:
            # str() could not encode the value, so encode it explicitly
            text = value.encode("utf-8")
        return text
class Test_TimblClient(unittest.TestCase):
    """
    Tests for TimblClient, run against the module-level Timbl server SERVER
    """

    def setUp(self):
        # every test gets its own connected client; the server is started
        # only if there is none yet (start_timbl_server is presumably what
        # sets SERVER)
        if not SERVER:
            start_timbl_server()
        self.client = TimblClient(SERVER.port)
        self.client.connect()

    def _read_train_instances(self, count):
        """
        Return the first count lines of the dimin training data

        The file is closed again before returning, so no file handle
        outlives the call.
        """
        with open(DATA_DIR + "/dimin.train") as train_file:
            return train_file.readlines()[:count]

    def test_disconnect(self):
        self.client.disconnect()
        self.assertRaises(TimblClientError, self.client.query)
        self.assertFalse(self.client.socket)

    def test_reconnect(self):
        self.client.reconnect()
        self.client.query()

    def test_connection_timeout(self):
        # send incomplete command so server does not reply
        self.client.socket.settimeout(1)
        self.assertRaises(TimblClientError, self.client.set, "-k")
        self.client.socket.settimeout(10)

    def test_query(self):
        # repeat multiple times, because recv in multiple parts occurs rarely
        for _ in range(25):
            status = self.client.query()
            self.assertEqual(status["NEIGHBORS"], "1")

    def test_set(self):
        try:
            self.client.set("-k 10 -d IL")
            status = self.client.query()
            self.assertEqual(status["NEIGHBORS"], "10")
            self.assertEqual(status["DECAY"], "IL")
        finally:
            # undo the changes even when an assertion fails, in case the
            # settings are visible to later tests (test_query expects k=1)
            self.client.set("-k 1 -d Z")

    def test_set_error(self):
        self.assertRaises(TimblClientError, self.client.set, "-w 1")

    def test_classify(self):
        """
        Exhaustively test classification with any combination of the verbose
        output options +/-vdb (distribution), +/-vdi (distance) and +/-vn
        (neighbours).

        The +/-vk seems to be unsupported, as it cannot be "set" through the
        server
        """
        # the instances are the same for every option combination, so read
        # them once instead of re-opening the file in the innermost loop
        instances = self._read_train_instances(11)

        self.client.set("-k10")
        try:
            for db in ("+vdb", "-vdb"):
                for di in ("+vdi", "-vdi"):
                    for vn in ("+vn", "-vn"):
                        self.client.set(db + " " + di + " " + vn)
                        # each switched-on option must add its key to the
                        # result, each switched-off option must leave it out
                        toggles = ((db, "DISTRIBUTION"),
                                   (di, "DISTANCE"),
                                   (vn, "NEIGHBOURS"))

                        for inst in instances:
                            result = self.client.classify(inst)
                            self.assertTrue("CATEGORY" in result)

                            for option, key in toggles:
                                if option.startswith("+"):
                                    self.assertTrue(
                                        key in result,
                                        "%s missing with %s" % (key, option))
                                else:
                                    self.assertFalse(
                                        key in result,
                                        "%s present with %s" % (key, option))
        finally:
            # undo the changes even when an assertion fails, in case the
            # settings are visible to later tests (test_query expects k=1)
            self.client.set("-k1 -vdb -vdi -vn")

    def test_classify_error(self):
        self.assertRaises(TimblClientError, self.client.classify,
                          "x, x, x, x")

    def test_log(self):
        # quick & global config of logging system so output of loggers
        # goes to stdout
        logging.basicConfig(level=logging.DEBUG,
                            format="%(levelname)-8s <%(name)s> :: %(message)s")

        # close the client made by setUp before replacing it; otherwise its
        # connection would never be disconnected
        self.client.disconnect()
        self.client = TimblClient(SERVER.port, log_tag="test_log_client")
        self.client.connect()

        try:
            instances = self._read_train_instances(2)

            for inst in instances:
                self.client.classify(inst)

            self.client.query()
            self.client.set("+vdb +vdi +vn")

            for inst in instances:
                self.client.classify(inst)

            # the two failing calls below are made on purpose, so the
            # client's error paths produce log output as well
            try:
                self.client.classify("x, x")
            except TimblClientError:
                pass

            try:
                self.client.set("-w 1")
            except TimblClientError:
                pass

            self.client.disconnect()
        finally:
            # global reset of logging level, also when the test fails
            logging.getLogger().setLevel(logging.CRITICAL)

    def tearDown(self):
        self.client.disconnect()