# Imports assumed for this snippet (the original shows only the function).
# DatasetReader is taken to be AllenNLP's dataset-reader base class.
import logging
import os
from multiprocessing import Queue, Value

from allennlp.data.dataset_readers import DatasetReader

logger = logging.getLogger(__name__)


def _worker(
    reader: DatasetReader,
    input_queue: Queue,
    output_queue: Queue,
    num_active_workers: Value,
    num_inflight_items: Value,
    worker_id: int,
) -> None:
    """
    A worker that pulls filenames off the input queue, uses the dataset reader
    to read them, and places the generated instances on the output queue.  When
    there are no filenames left on the input queue, it decrements
    num_active_workers to signal completion.
    """
    logger.info(f"Reader worker: {worker_id} PID: {os.getpid()}")
    # Keep going until you get a file_path that's None.
    while True:
        file_path = input_queue.get()
        if file_path is None:
            # It's important that we close and join the queue here before
            # decrementing num_active_workers. Otherwise our parent may join us
            # before the queue's feeder thread has passed all buffered items to
            # the underlying pipe resulting in a deadlock.
            #
            # See:
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
            # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
            output_queue.close()
            output_queue.join_thread()
            # Decrementing is not atomic.
            # See https://docs.python.org/2/library/multiprocessing.html#multiprocessing.Value.
            with num_active_workers.get_lock():
                num_active_workers.value -= 1
            logger.info(f"Reader worker {worker_id} finished")
            break

        logger.info(f"reading instances from {file_path}")
        for instance in reader.read(file_path):
            with num_inflight_items.get_lock():
                num_inflight_items.value += 1
            output_queue.put(instance)
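
A parent process might drive `_worker` roughly as follows. This is a sketch only: `read_in_parallel`, the 0.1-second poll timeout, and the per-worker poison pills are assumptions, not part of the original snippet.

from multiprocessing import Process, Queue, Value
from queue import Empty


def read_in_parallel(reader, file_paths, num_workers=4):
    """Hypothetical driver: fan filenames out to _worker processes and
    yield instances as they arrive."""
    input_queue = Queue()
    output_queue = Queue()
    num_active_workers = Value("i", num_workers)
    num_inflight_items = Value("i", 0)

    workers = []
    for worker_id in range(num_workers):
        p = Process(
            target=_worker,
            args=(reader, input_queue, output_queue,
                  num_active_workers, num_inflight_items, worker_id),
        )
        p.start()
        workers.append(p)

    for path in file_paths:
        input_queue.put(path)
    for _ in range(num_workers):
        input_queue.put(None)  # one poison pill per worker

    while True:
        try:
            instance = output_queue.get(timeout=0.1)
        except Empty:
            # Workers flush the output queue before decrementing the counter,
            # so "no active workers" plus an empty queue means we are done.
            with num_active_workers.get_lock():
                if num_active_workers.value == 0:
                    break
            continue
        with num_inflight_items.get_lock():
            num_inflight_items.value -= 1
        yield instance

    for p in workers:
        p.join()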
Example #2
# Imports assumed for this example (the original omits them).  translate,
# multi_translate, and clean_for_imdb are project-specific helpers defined
# elsewhere; a sketch of the assumed interfaces appears at the end.
import csv
import logging
import os
import pickle
from multiprocessing import Process, Queue, set_start_method

logger = logging.getLogger(__name__)


def main(args):
    if args.labels:
        data = []
        with open(args.data_dir, encoding="utf-8") as f:
            for line in csv.reader(f, delimiter="\t"):
                data.append(line)
        # Skip the header row; the remaining rows hold (text, label) columns.
        text, labels = zip(*data[1:])
    else:
        text = []
        with open(args.data_dir, encoding="utf-8") as f:
            for line in f.readlines():
                text.append(line.strip())
        labels = None

    if isinstance(text, tuple):
        text = list(text)

    if "imdb" in args.data_dir or "IMDB" in args.data_dir:
        text = [clean_for_imdb(t) for t in text]

    logger.info("Do back-translation for {} sentences".format(len(text)))

    if args.gpus is not None and len(args.gpus) > 1:
        logger.info("Using multiple GPUs: {}".format(", ".join(str(i) for i in args.gpus)))
        split_point = len(text) // len(args.gpus)

        # Split the text into one contiguous chunk per GPU.  The chunk index
        # must be the worker's position in args.gpus, not the GPU id itself;
        # indexing by gpu_id breaks for non-contiguous lists such as --gpus 2 3.
        text_splitted = []
        for i in range(len(args.gpus)):
            text_splitted.append(text[i * split_point : (i + 1) * split_point])
            if i == len(args.gpus) - 1:
                # The last chunk absorbs the remainder of the integer division.
                text_splitted[-1] += text[(i + 1) * split_point :]
        assert sum(len(s) for s in text_splitted) == len(text)
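        # e.g. len(text) == 10 with two GPUs gives split_point == 5 and chunks
        # text[0:5], text[5:10]; with three GPUs (split_point == 3) the last
        # chunk also takes the leftover item: text[0:3], text[3:6], text[6:10].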

        # CUDA contexts do not survive fork(), so worker processes are spawned.
        set_start_method("spawn")
        q = Queue()

        procs = []
        for i in range(len(args.gpus)):
            proc = Process(target=multi_translate, args=(args, i, text_splitted[i], q))
            procs.append(proc)
            proc.start()

        q_result = []
        for _ in procs:
            q_result.append(q.get())

        # Each worker puts (worker_index, docs); sort on the index so the
        # chunks are reassembled in their original order.
        back_translated_docs = []
        for doc_split in sorted(q_result, key=lambda r: r[0]):
            back_translated_docs += doc_split[1]

        q.close()
        q.join_thread()

        for proc in procs:
            proc.join()
    else:
        if args.gpus is not None:
            gpu = args.gpus[0]
            logger.info("Using a single GPU: {}".format(gpu))
            back_translated_docs = translate(args, text, gpu)[1]
        else:
            logger.info("Using CPU")
            back_translated_docs = translate(args, text)

    output_file_name = "bt_" + os.path.basename(args.data_dir)
    output_dir = os.path.join(args.output_dir, output_file_name)

    folder_name = os.path.dirname(output_dir)
    if not os.path.isdir(folder_name):
        os.makedirs(folder_name)

    if args.return_sentence_pair:
        # Save the (original, back-translated) sentence pairs as a pickle.
        filename, ext = os.path.splitext(output_path)
        with open(filename + ".pickle", "wb") as f:
            pickle.dump(back_translated_docs, f)

        # Keep only the back-translated side: zip(*d) transposes a doc's
        # (original, back_translated) pairs and [1] selects the back-translations.
        bt_doc = [" ".join(list(zip(*d))[1]) for d in back_translated_docs]
        with open(output_path, "wt") as f:
            if labels is not None:
                tsv_writer = csv.writer(f, delimiter="\t")
                tsv_writer.writerow(data[0])
                for line, label in zip(bt_doc, labels):
                    tsv_writer.writerow([line, label])
            else:
                for line in bt_doc:
                    f.write(line)
                    f.write('\n')

        # Save cross sentences
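        # Worked example: a doc of pairs [(o1, b1), (o2, b2), (o3, b3)] has its
        # odd-indexed pairs reversed, giving [(o1, b1), (b2, o2), (o3, b3)];
        # unzipping then yields pair1 = "o1 b2 o3" and pair2 = "b1 o2 b3", so
        # each output document mixes original and back-translated sentences.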
        new_back_translated_docs = []
        for doc in back_translated_docs:
            new_doc = []
            for j, sent in enumerate(doc):
                if j % 2 == 0:
                    new_doc.append(sent)
                else:
                    new_doc.append(sent[::-1])
            new_back_translated_docs.append(new_doc)
        new_docs1, new_docs2 = [], []
        for doc in new_back_translated_docs:
            n1, n2 = list(zip(*doc))
            new_docs1.append(" ".join(n1))
            new_docs2.append(" ".join(n2))
        
        filename, ext = os.path.splitext(output_path)
        with open(filename + "_pair1" + ext, "wt") as f:
            if labels is not None:
                tsv_writer = csv.writer(f, delimiter="\t")
                tsv_writer.writerow(data[0])
                for line, label in zip(new_docs1, labels):
                    tsv_writer.writerow([line, label])
            else:
                for line in new_docs1:
                    f.write(line)
                    f.write('\n')
        with open(filename + "_pair2" + ext, "wt") as f:
            if labels is not None:
                tsv_writer = csv.writer(f, delimiter="\t")
                tsv_writer.writerow(data[0])
                for line, label in zip(new_docs2, labels):
                    tsv_writer.writerow([line, label])
            else:
                for line in new_docs2:
                    f.write(line)
                    f.write('\n')
    else:
        with open(output_dir, "wt") as f:
            if labels is not None:
                tsv_writer = csv.writer(f, delimiter="\t")
                tsv_writer.writerow(data[0])
                for line, label in zip(back_translated_docs, labels):
                    tsv_writer.writerow([line, label])
            else:
                for line in back_translated_docs:
                    f.write(line)
                    f.write('\n')

    logger.info("Translated documents are saved in {}".format(output_dir))