예제 #1
0
    def __init__(self, model_dir=_default_model_dir, use_gpu=True, batch_size=256,
                 gpu_mem_frac=0.1, beam_width=10, num_top=1, maximum_iterations=1000,
                 cpu_threads=5, emb_activation=None):
        """Set up hyperparameters and build the encode/decode models for inference.

        Args:
            model_dir: Path to the model directory. It is passed as ``save_dir``
                together with ``hparams_from_file=True`` when creating the
                hyperparameters.
            use_gpu: Flag for GPU usage. It is stored as ``self.use_gpu`` and is
                not otherwise read in this constructor.
            batch_size: Number of samples to process per step.
            gpu_mem_frac: If a GPU is used, the fraction of its memory to use.
            beam_width: Width of the window used by the beam search decoder.
            num_top: Number of most probable sequences returned by the beam
                search decoder.
            maximum_iterations: Stored as ``self.maximum_iterations``. It is
                presumably the decoding step limit; confirm against the decode
                methods.
            cpu_threads: Value written to the ``cpu_threads`` hyperparameter.
            emb_activation: Intended activation function for the bottleneck
                layer. NOTE(review): this argument is not read anywhere in this
                constructor, so passing it currently has no effect here.
        """
        self.num_top = num_top
        self.use_gpu = use_gpu

        # Parse an empty argv to get the default flags. Then point them at
        # model_dir and request file-based hyperparameters (presumably the
        # ones saved at training time).
        cli_parser = argparse.ArgumentParser()
        add_arguments(cli_parser)
        default_flags = cli_parser.parse_args([])
        default_flags.hparams_from_file = True
        default_flags.save_dir = model_dir
        hparams = create_hparams(default_flags)
        self.hparams = hparams

        # Override run-time settings on top of the created hyperparameters.
        for key, value in (("save_dir", model_dir),
                           ("batch_size", batch_size),
                           ("gpu_mem_frac", gpu_mem_frac)):
            hparams.set_hparam(key, value)
        # beam_width is registered as a new hyperparameter (add_hparam), unlike
        # the others, which overwrite existing entries (set_hparam).
        hparams.add_hparam("beam_width", beam_width)
        hparams.set_hparam("cpu_threads", cpu_threads)

        self.encode_model, self.decode_model = build_models(
            hparams, modes=["ENCODE", "DECODE"])
        self.maximum_iterations = maximum_iterations
예제 #2
0
파일: train.py 프로젝트: zchwang/cddd
def main(unused_argv):
    """Train and evaluate the translation model.

    Hyperparameters are created from the module-level ``FLAGS``. The train,
    eval and encode models are then built and handed to the training loop.

    Args:
        unused_argv: Command-line arguments. They are ignored; all settings
            come from ``FLAGS``.
    """
    hparams = create_hparams(FLAGS)
    # Select the configured CUDA device(s) before any model is built.
    visible_devices = str(hparams.device)
    os.environ.update(CUDA_VISIBLE_DEVICES=visible_devices)
    train_model, eval_model, encode_model = build_models(hparams)
    train_loop(train_model, eval_model, encode_model, hparams)