Пример #1
0
 def _classify(params):
   """Validate a classification request and start (or reuse) a worker for it.

   :param dict params: request values keyed by stream name. The key 'data'
     is the input; every other key must be a target listed in
     ``network.n_out``. Each value is converted in place to a float32 numpy
     array, so the caller's dict is mutated.
   :returns: a dict holding either ``'error'`` (a message string) or
     ``'result'`` = ``{'hash': <hex digest>}``. The digest is the key of the
     worker in the module-level ``workers`` registry.
   """
   ret = {}
   output_dim = {}
   # Digest the request as received (before numpy conversion) so identical
   # requests reuse the same worker. json.dumps returns str; hashlib needs bytes.
   # NOTE(review): json.dumps keeps dict insertion order, so the same values
   # sent with a different key order yield a different digest.
   # NOTE(review): 'ripemd160' is only available if the linked OpenSSL
   # provides it; hashlib.new raises ValueError otherwise.
   hasher = hashlib.new('ripemd160')
   hasher.update(json.dumps(params).encode('utf-8'))
   key = hasher.hexdigest()
   for k in params:
     try:
       params[k] = numpy.asarray(params[k], dtype='float32')
       if k != 'data':
         output_dim[k] = network.n_out[k]
     except Exception:
       # Distinguish an unknown target name from a value that is not array-like.
       if k != 'data' and k not in network.n_out:
         ret['error'] = 'unknown target: %s' % k
       else:
         ret['error'] = 'unable to convert %s to an array from value %s' % (k, str(params[k]))
       break
   if 'error' in ret:
     return ret
   # Build the dataset exactly once, inside the try, so construction errors
   # are reported to the client instead of escaping.
   try:
     data = StaticDataset(data=[params], output_dim=output_dim)
     data.init_seq_order()
   except Exception:
     ret['error'] = "invalid data: %s" % params
     return ret
   # The whole request goes into a single batch containing at most one sequence.
   batches = data.generate_batches(recurrent_net=network.recurrent,
                                   batch_size=sys.maxsize, max_seqs=1)
   if key not in workers:
     workers[key] = ClassificationTaskThread(network, devices, data, batches)
     workers[key].json_params = params
     print("worker started:", key, file=log.v3)
   ret['result'] = {'hash': key}
   return ret
Пример #2
0
    def classify_in_background(self):
        """Serve classification requests from ``self.classification_queue`` forever.

        Generator-based coroutine: each ``yield`` hands a future to the driving
        event loop (presumably Tornado -- TODO confirm at the call site).

        Each iteration:

        1. waits for one request;
        2. drains every other request already queued;
        3. runs them together as one ``StaticDataset`` through a
           ``ForwardTaskThread`` while holding ``self.lock``;
        4. resolves each request's ``future`` with its entry of ``ctt.result``.

        If anything fails for a batch, every request future not yet resolved
        receives the exception, and the loop keeps serving later batches.
        ``task_done()`` is always called once per dequeued request.
        """
        while True:
            requests = []
            # Block until at least one request is available.
            first = yield self.classification_queue.get()
            requests.append(first)
            # Grab everything else already waiting so it shares this forward pass.
            try:
                while True:
                    requests.append(self.classification_queue.get_nowait())
            except QueueEmpty:
                pass

            try:
                output_dim = {}
                dataset = StaticDataset(data=[req.data for req in requests],
                                        output_dim=output_dim)
                dataset.init_seq_order()
                batches = dataset.generate_batches(
                    recurrent_net=self.engine.network.recurrent,
                    batch_size=self.batch_size,
                    max_seqs=self.max_seqs)

                # Serialize access to the network/devices across batches.
                with (yield self.lock.acquire()):
                    ctt = ForwardTaskThread(self.engine.network, self.devices,
                                            dataset, batches)
                    yield ctt.join()

                # Sequence i of the dataset came from requests[i].
                for i, req in enumerate(requests):
                    req.future.set_result(ctt.result[i])
            except Exception as e:
                print('exception', e)
                # Fail every request not yet answered so its caller does not
                # wait forever, then keep the background loop alive.
                for req in requests:
                    if not req.future.done():
                        req.future.set_exception(e)
            finally:
                # Balance every get()/get_nowait() above regardless of outcome,
                # so classification_queue.join() cannot block forever.
                for _ in requests:
                    self.classification_queue.task_done()