Пример #1
0
def translate_lines(
        output_handler: sockeye.output_handler.OutputHandler,
        source_data: Iterable[str],
        translator: sockeye.inference.Translator) -> Tuple[int, float]:
    """
    Translate source_data one line at a time, passing every result to the output handler.

    Only the ``translator.translate`` call is timed; building the input and
    handling the output are excluded from the reported time.

    :param output_handler: Receives each (input, output) pair via ``handle`` as soon as it is produced.
    :param source_data: An iterable of source sentences to translate; sentence ids start at 1.
    :param translator: The translator used to build each input and translate it.
    :return: A tuple of (number of lines translated, summed translation wall-clock time in seconds).
    """
    line_count = 0
    elapsed = 0.0
    for line_count, sentence in enumerate(source_data, start=1):
        translator_input = translator.make_input(line_count, sentence)
        logger.debug(" IN: %s", translator_input)

        start = time.time()
        translator_output = translator.translate(translator_input)
        duration = time.time() - start
        elapsed += duration

        logger.debug("OUT: %s", translator_output)
        logger.debug("OUT: time=%.2f", duration)
        output_handler.handle(translator_input, translator_output)
    # line_count keeps its initial 0 when source_data is empty.
    return line_count, elapsed
Пример #2
0
def translate(output_handler: sockeye.output_handler.OutputHandler,
              source_data: Iterable[str],
              translator: sockeye.inference.Translator,
              chunk_id: int = 0) -> float:
    """
    Translate all lines of source_data as one batch, then pass each result to the output handler.

    :param output_handler: A handler called once per sentence with (input, output, per-sentence time).
    :param source_data: An iterable of source sentences, translated together in a single call.
    :param translator: The translator used to build the inputs and translate the batch.
    :param chunk_id: Offset for sentence ids: the sentences get ids chunk_id + 1, chunk_id + 2, ...
    :return: Total wall-clock time in seconds spent building the inputs and translating the batch
             (0.0 when source_data is empty).
    """
    tic = time.time()
    trans_inputs = [
        translator.make_input(i, line)
        for i, line in enumerate(source_data, chunk_id + 1)
    ]
    if not trans_inputs:
        # Nothing to translate: don't hand an empty batch to the translator, and
        # avoid the ZeroDivisionError the per-sentence average below would raise.
        return 0.0
    trans_outputs = translator.translate(trans_inputs)
    total_time = time.time() - tic
    # The batch is translated in one call, so each sentence is attributed the average time.
    batch_time = total_time / len(trans_inputs)
    for trans_input, trans_output in zip(trans_inputs, trans_outputs):
        output_handler.handle(trans_input, trans_output, batch_time)
    return total_time
Пример #3
0
def translate_lines(
        output_handler: sockeye.output_handler.OutputHandler,
        source_data: Iterable[str],
        translator: sockeye.inference.Translator) -> Tuple[int, float]:
    """
    Translates each line from source_data, calling output handler for each result.

    Each input line must contain a surface sentence and its graph separated by exactly
    one tab character. The edge label vocabulary is taken from ``translator.vocab_edge``.

    :param output_handler: A handler that will be called once with the output of each translation,
                           together with the wall-clock time spent on that line.
    :param source_data: An iterable of tab-separated "sentence<TAB>graph" lines to translate;
                        sentence ids start at 1.
    :param translator: The translator that will be used for each line of input.
    :return: The number of lines translated, and the total time taken (the per-line time includes
             splitting the line, building the input and translating it).
    :raises ValueError: If a line does not contain exactly one tab separator.
    """

    i = 0
    total_time = 0.0
    for i, line in enumerate(source_data, 1):
        tic = time.time()

        # GCN hack: the surface sentence and its graph are concatenated into a
        # single tab-separated line, so split them apart here.
        fields = line.split('\t')
        if len(fields) != 2:
            raise ValueError(
                "Line %d: expected a sentence and a graph separated by exactly one tab, "
                "found %d tab-separated field(s)" % (i, len(fields)))
        # NOTE(review): `graph` keeps any trailing newline present in `line`;
        # presumably make_input tolerates it — confirm.
        surface, graph = fields

        trans_input = translator.make_input(i, surface, graph,
                                            translator.vocab_edge)
        logger.debug(" IN: %s", trans_input)
        trans_output = translator.translate(trans_input)
        trans_wall_time = time.time() - tic
        total_time += trans_wall_time
        logger.debug("OUT: %s", trans_output)
        logger.debug("OUT: time=%.2f", trans_wall_time)
        output_handler.handle(trans_input, trans_output, trans_wall_time)
    return i, total_time
Пример #4
0
def main():
    """
    Command-line entry point: translate sentences read from STDIN and write the results to STDOUT.

    Parses inference and device arguments, loads the model(s) onto a CPU or single-GPU
    context, translates STDIN line by line (sentence ids start at 1), and logs throughput
    statistics at the end.

    :raises RuntimeError: If GPU execution is requested but no GPU is found.
    """
    params = argparse.ArgumentParser(
        description='Translate from STDIN to STDOUT')
    params = arguments.add_inference_args(params)
    params = arguments.add_device_args(params)
    args = params.parse_args()

    logger = setup_main_logger(__name__, file_logging=False)

    # Explicit checks instead of `assert`: assertions are stripped under `python -O`.
    # `params.error` prints the usage message and exits.
    if args.beam_size <= 0:
        params.error("Beam size must be 1 or greater.")
    if args.checkpoints is not None and len(args.checkpoints) != len(args.models):
        params.error("must provide checkpoints for each model")

    logger.info("Command: %s", " ".join(sys.argv))
    logger.info("Arguments: %s", args)

    output_stream = sys.stdout
    output_handler = sockeye.output_handler.get_output_handler(
        args.output_type, output_stream, args.align_plot_prefix,
        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        if args.use_cpu:
            context = mx.cpu()
        else:
            num_gpus = get_num_gpus()
            if num_gpus <= 0:
                raise RuntimeError(
                    "No GPUs found, consider running on the CPU with --use-cpu "
                    "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                    "binary isn't on the path).")
            if len(args.device_ids) != 1:
                params.error("cannot run on multiple devices for now")
            gpu_id = args.device_ids[0]
            if gpu_id < 0:
                # get a gpu id automatically; it is held until the ExitStack block exits
                gpu_id = exit_stack.enter_context(acquire_gpu())
            context = mx.gpu(gpu_id)

        translator = sockeye.inference.Translator(
            context, args.ensemble_mode,
            *sockeye.inference.load_models(context, args.max_input_len,
                                           args.beam_size, args.models,
                                           args.checkpoints,
                                           args.softmax_temperature))
        total_time = 0.0
        i = 0
        for i, line in enumerate(sys.stdin, 1):
            trans_input = translator.make_input(i, line)
            logger.debug(" IN: %s", trans_input)

            # Only the translate call itself is timed.
            tic = time.time()
            trans_output = translator.translate(trans_input)
            trans_wall_time = time.time() - tic
            total_time += trans_wall_time

            logger.debug("OUT: %s", trans_output)
            logger.debug("OUT: time=%.2f", trans_wall_time)

            output_handler.handle(trans_input, trans_output)

        if i != 0:
            # Guard against a zero total (every measured interval rounded to 0),
            # which would otherwise raise ZeroDivisionError in the rate.
            sent_per_sec = i / total_time if total_time > 0 else float('inf')
            logger.info(
                "Processed %d lines. Total time: %.4f sec/sent: %.4f sent/sec: %.4f",
                i, total_time, total_time / i, sent_per_sec)
        else:
            logger.info("Processed 0 lines.")