예제 #1
0
def segmented_document(lines, segmenter):
    """Build a new ``ld.Document`` whose segments come from ``lines``.

    Each call bumps the module-level ``docid`` counter and uses the new value
    to form the document id (``'doc#<docid>'``). The segments are whatever
    ``all_segments(lines, segmenter)`` returns.

    Args:
        lines: Input passed through to ``all_segments``.
        segmenter: Segmenter passed through to ``all_segments``.

    Returns:
        The populated ``ld.Document``.
    """
    global docid
    document = ld.Document()
    docid += 1
    document.document_id = 'doc#%s' % docid
    # Slice assignment replaces the (initially empty) segment list in place.
    document.segments[:] = all_segments(lines, segmenter)
    return document
예제 #2
0
 def parse_document(msg):
     """Deserialize ``msg`` into a fresh ``ld.Document`` and return it.

     Args:
         msg: Serialized document, as accepted by ``ParseFromString``.

     Returns:
         The ``ld.Document`` populated from ``msg``.
     """
     parsed = ld.Document()
     parsed.ParseFromString(msg)
     return parsed
예제 #3
0
def server(args, model, tokenizer, protobuf=False, verbose=1):
    """Serve labeling requests over Kafka, a protostream pipe, or plain stdin.

    The transport is selected from ``args``:

    * ``args.kafka``: consume serialized ``ld.Document`` messages from
      ``args.kafka_in_topic``, label each with ``labeldoc`` and publish the
      serialized result to ``args.kafka_out_topic`` keyed by the document id.
      Runs until Ctrl-C (``KeyboardInterrupt``), which ends the loop quietly.
    * ``args.proto``: read ``ld.Document`` records from binary stdin with
      ``protostream`` and write each labeled document to binary stdout.
    * otherwise: read text lines from stdin in batches, label each batch and
      emit one formatted line per label through ``outserver``.

    As a side effect ``args.eval_batch_size`` is set to
    ``per_gpu_eval_batch_size * max(1, n_gpu)``.

    Args:
        args: Parsed options; see the attributes read above and below.
        model: Model passed through to ``labeldoc``.
        tokenizer: Tokenizer passed through to ``labeldoc``.
        protobuf: Not read by this function; kept for caller compatibility.
        verbose: Not read by this function; kept for caller compatibility.

    Raises:
        Exception: In Kafka mode, any failure while waiting for a send to be
            acknowledged is logged and re-raised. The Kafka clients are closed
            before the exception propagates.
    """
    batchsz = int(args.per_gpu_eval_batch_size * max(1, args.n_gpu))
    args.eval_batch_size = batchsz
    if args.kafka:
        from kafka import KafkaProducer, KafkaConsumer
        producer = KafkaProducer(bootstrap_servers=makelist(args.kafka_bootstrap), api_version=args.kafka_api_version)
        consumer = None
        try:
            consumer = KafkaConsumer(args.kafka_in_topic, bootstrap_servers=makelist(args.kafka_bootstrap), api_version=args.kafka_api_version)
            sys.stderr.write('kafka server running until Ctrl-C\n')
            for msg in consumer:
                doc = ld.Document()
                doc.ParseFromString(msg.value)
                value = labeldoc(doc, args, model, tokenizer).SerializeToString()
                future = producer.send(args.kafka_out_topic, key=doc.document_id.encode('utf-8'), value=value)
                # Block on the future so each send is effectively synchronous.
                try:
                    record_metadata = future.get(timeout=10)
                except Exception as e:
                    log("exception: %s" % e)
                    raise
                else:
                    log("wrote " + str(record_metadata) + " value: %s" % value)
        except KeyboardInterrupt:
            # Ctrl-C is the documented way to stop the server; exit quietly.
            pass
        finally:
            # Release the clients on every exit path, not only on Ctrl-C.
            if consumer is not None:
                consumer.close()
            producer.close()
        return
    elif args.proto:
        import protostream
        # Re-open the standard streams in binary mode without taking
        # ownership of the underlying file descriptors.
        stdout = os.fdopen(sys.stdout.fileno(), "wb", closefd=False) # or sys.stdout.buffer?
        stdin = os.fdopen(sys.stdin.fileno(), "rb", closefd=False) # or sys.stdin.buffer?
        with protostream.open(mode='wb', fileobj=stdout) as ostream:
            for doc in protostream.parse(stdin, ld.Document):
                ostream.write(labeldoc(doc, args, model, tokenizer))
        return
    else:
        from kafka_args import label_gap, label_str, logits_str
        docid = 0
        eof = False
        while not eof:
            # Collect one batch: stop at a blank line, a full batch, or EOF.
            lines = []
            try:
                while True:
                    line = input().strip()
                    if not line:
                        break
                    lines.append(line)
                    if len(lines) >= batchsz:
                        break
            except EOFError:
                eof = True
            if lines:
                doc = ld.Document()
                doc.document_id = 'doc#%s' % docid
                doc.segments.extend(lines)
                ldoc = labeldoc(doc, args, model, tokenizer)
                for i, lab in enumerate(ldoc.labels):
                    label, gap = label_gap(lab.logits)
                    labelstr = label_str(label)
                    outserver('%s(+%s)[%s] %s' % (labelstr, rounded(gap), logits_str(lab.logits), with_explanation(lab.words, lines[i], labelstr)))
                docid += 1
            elif eof:
                # Input ended with nothing pending: emit a final empty line.
                outserver('')
예제 #4
0
def server(args, model, tokenizer):
    """Serve labeling requests over Kafka, a protostream pipe, or plain stdin.

    The model is moved to ``args.device`` and put in eval mode, and
    ``args.eval_batch_size`` is set to
    ``per_gpu_eval_batch_size * max(1, n_gpu)``. The transport is then
    selected from ``args``:

    * ``args.kafka``: consume serialized ``ld.Document`` messages from
      ``args.kafka_in_topic``, label each with ``labeldoc`` and publish the
      serialized result to ``args.kafka_out_topic`` keyed by the document id.
      Runs until Ctrl-C (``KeyboardInterrupt``), which ends the loop quietly.
    * ``args.proto``: read ``ld.Document`` records from binary stdin with
      ``protostream`` and write each labeled document to binary stdout.
    * otherwise: label every document yielded by
      ``send_document.stdin_docs(args)`` and emit one line per label through
      ``outserver``, followed by an empty line after each document. The line
      format is shorter when ``args.brief_explanation`` is set.

    Args:
        args: Parsed options; see the attributes read above and below.
        model: Model passed through to ``labeldoc``.
        tokenizer: Tokenizer passed through to ``labeldoc``.

    Raises:
        Exception: In Kafka mode, any failure while waiting for a send to be
            acknowledged is logged and re-raised. The Kafka clients are closed
            before the exception propagates.
    """
    model.to(args.device)
    model.eval()
    import explanation
    batchsz = int(args.per_gpu_eval_batch_size * max(1, args.n_gpu))
    args.eval_batch_size = batchsz
    if args.kafka:
        from kafka import KafkaProducer, KafkaConsumer
        producer = KafkaProducer(
            bootstrap_servers=makelist(args.kafka_bootstrap),
            api_version=args.kafka_api_version)
        consumer = None
        try:
            consumer = KafkaConsumer(
                args.kafka_in_topic,
                bootstrap_servers=makelist(args.kafka_bootstrap),
                api_version=args.kafka_api_version)
            sys.stderr.write('kafka server running until Ctrl-C\n')
            for msg in consumer:
                doc = ld.Document()
                doc.ParseFromString(msg.value)
                value = labeldoc(doc, args, model,
                                 tokenizer).SerializeToString()
                future = producer.send(args.kafka_out_topic,
                                       key=doc.document_id.encode('utf-8'),
                                       value=value)
                # Block on the future so each send is effectively synchronous.
                try:
                    record_metadata = future.get(timeout=10)
                except Exception as e:
                    log("exception: %s" % e)
                    raise
                else:
                    log("wrote " + str(record_metadata) + " value: %s" % value)
        except KeyboardInterrupt:
            # Ctrl-C is the documented way to stop the server; exit quietly.
            pass
        finally:
            # Release the clients on every exit path, not only on Ctrl-C.
            if consumer is not None:
                consumer.close()
            producer.close()
        return
    elif args.proto:
        import protostream
        # Re-open the standard streams in binary mode without taking
        # ownership of the underlying file descriptors.
        stdout = os.fdopen(sys.stdout.fileno(), "wb",
                           closefd=False)  # or sys.stdout.buffer?
        stdin = os.fdopen(sys.stdin.fileno(), "rb",
                          closefd=False)  # or sys.stdin.buffer?
        with protostream.open(mode='wb', fileobj=stdout) as ostream:
            for doc in protostream.parse(stdin, ld.Document):
                ostream.write(labeldoc(doc, args, model, tokenizer))
        return
    else:
        from kafka_args import label_gap, label_str, logits_str
        import send_document
        for doc in send_document.stdin_docs(args):
            ldoc = labeldoc(doc, args, model, tokenizer)
            for i, lab in enumerate(ldoc.labels):
                label, gap = label_gap(lab.logits)
                labelstr = label_str(label)
                explained = explanation.with_explanation(
                    lab.words, doc.segments[i], labelstr, args)
                if args.brief_explanation:
                    out = '%s %s\t%s' % (label, logits_str(
                        lab.logits), explained)
                else:
                    out = '%s(+%s)[%s] %s' % (labelstr, rounded(gap),
                                              logits_str(lab.logits), explained)
                outserver(out)
            # Empty line after each document's labels.
            outserver('')