예제 #1
0
def _get_parser():
    """Build the command-line parser used by ``train.py``.

    Registers two script-specific options (``--teacher_model_path`` and
    ``--word_sampling``), then the shared config, model and training option
    groups from ``opts``.

    Returns:
        The populated ``ArgumentParser``.
    """
    train_parser = ArgumentParser(description='train.py')

    # Where to find the teacher model.
    train_parser.add_argument(
        '--teacher_model_path',
        dest='teacher_model_path',
        action='store',
        help='the path direct to the teacher model path',
    )

    # NOTE(review): no ``type=`` is given, so a value supplied on the command
    # line arrives as a string; e.g. ``--word_sampling False`` yields the
    # truthy string 'False'. Only the default is an actual bool.
    train_parser.add_argument(
        "--word_sampling",
        default=False,
        action="store",
        help="optional arg",
    )

    # Shared option groups, registered in order: config, model, train.
    for register_opts in (opts.config_opts, opts.model_opts, opts.train_opts):
        register_opts(train_parser)
    return train_parser
예제 #2
0
def _get_parser():
    """Build the command-line parser used by ``translate.py``.

    Adds a ``--pivot_vocab`` option, then registers the shared config and
    translation option groups from ``opts``.

    Returns:
        The populated ``ArgumentParser``.
    """
    translate_parser = ArgumentParser(description='translate.py')
    translate_parser.add_argument(
        '--pivot_vocab',
        dest='pivot_vocab',
        action='store',
    )

    # Shared option groups, registered in order: config, translate.
    for register_opts in (opts.config_opts, opts.translate_opts):
        register_opts(translate_parser)
    return translate_parser
예제 #3
0
    def __init__(self, model_dir):
        """Load an extended translation model description from ``model_dir``.

        The directory must contain a ``config.json`` file and a
        ``translation_model`` sub-directory holding exactly one ``.pt`` file.
        The JSON config supplies the source/target languages, an optional
        online-learning section, and an ``opts`` mapping that is converted
        into command-line style arguments and parsed into ``self._opts``.

        Args:
            model_dir: Path to the extended model directory.

        Raises:
            ValueError: If ``translation_model`` does not contain exactly one
                ``.pt`` file, or if ``config.json`` is not valid JSON.
            OSError: If the ``translation_model`` directory or
                ``config.json`` cannot be read.
            KeyError: If ``config.json`` lacks the language keys or ``opts``.
        """

        # Model dir
        self._model_dir = model_dir

        # Locate the single .pt checkpoint inside <model_dir>/translation_model.
        translation_dir = os.path.join(model_dir, 'translation_model')
        model_files = [os.path.join(translation_dir, name)
                       for name in os.listdir(translation_dir)
                       if name.endswith(".pt")]
        if len(model_files) != 1:
            msg = (f"Extended model {self._model_dir} should have exactly one "
                   f".pt file in {translation_dir}, found {len(model_files)}")
            raise ValueError(msg)
        model_file = model_files[0]

        # Load config from json file. JSON is UTF-8 by specification, so the
        # encoding is explicit rather than the platform's locale default.
        config_path = os.path.join(model_dir, "config.json")
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)

        # Langs
        self._src_lang = config[self.SRC_LANG]
        self._tgt_lang = config[self.TGT_LANG]

        # Online learning: use the config section when present, otherwise
        # fall back to the helper's defaults.
        if self.ONLINE_LEARNING in config:
            self._online_learning = _OnlineLearningConfig(
                config[self.ONLINE_LEARNING])
        else:
            self._online_learning = _OnlineLearningConfig()

        # Create a parser for train and translate options. With argparse
        # semantics, conflict_handler='resolve' lets the --src/--data/--seed
        # definitions below replace any same-named options registered by the
        # helper functions.
        parser = ArgumentParser(description='pangeanmt options',
                                conflict_handler='resolve')

        # Parser for translate options
        _translate_opts(parser)

        # Parser for training options
        _train_opts(parser)

        # Parser for model options
        _model_opts(parser)

        # --src argument is not used
        parser.add_argument('--src',
                            '-src',
                            required=False,
                            help="This argument isn't used!")

        # --data argument is not used
        parser.add_argument('--data',
                            '-data',
                            required=False,
                            help="This argument isn't used!")

        # --seed Overwrite default. type=int keeps the parsed seed an int even
        # when config['opts'] supplies one: every config value is passed
        # through str() below, so without a type it would arrive as a string
        # while the default stays an int.
        parser.add_argument('--seed',
                            '-seed',
                            required=False,
                            type=int,
                            default=829,
                            help="Seed")

        # Create opts from config: each key becomes "--<key>", followed by its
        # stringified value unless the value is None (flag-style option).
        # A 'model' entry is resolved relative to model_dir; since it comes
        # after the leading '--model', it takes precedence when argparse uses
        # its default 'store' action.
        args = ['--model', model_file]
        for k, v in config['opts'].items():
            args.append('--' + k)
            if v is None:
                continue
            if k == 'model':
                v = os.path.join(model_dir, v)
            args.append(str(v))
        self._opts = parser.parse_args(args)