예제 #1
0
    def update_vocab(self,
                     output_dir,
                     src_vocab,
                     tgt_vocab,
                     shared_vocab,
                     new_src_vocab,
                     new_tgt_vocab,
                     new_shared_vocab,
                     mode,
                     init,
                     checkpoint_path=None):
        """Forward a vocabulary update request to ``checkpoint.update_vocab``.

        All arguments are handed over positionally and unmodified, with
        ``checkpoint_path`` placed first. The call is made with a TensorFlow
        session configuration whose GPU device count is 0.

        The meaning of each argument is defined by ``checkpoint.update_vocab``,
        which is not visible from this method; the descriptions below are
        inferred from the parameter names and should be confirmed there.

        Args:
          output_dir: Presumably where the updated checkpoint is written.
          src_vocab: Presumably the current source vocabulary.
          tgt_vocab: Presumably the current target vocabulary.
          shared_vocab: Presumably the current shared vocabulary.
          new_src_vocab: Presumably the replacement source vocabulary.
          new_tgt_vocab: Presumably the replacement target vocabulary.
          new_shared_vocab: Presumably the replacement shared vocabulary.
          mode: Presumably the vocabulary update mode.
          init: Forwarded as-is; semantics not visible here.
          checkpoint_path: Forwarded as the first positional argument;
            defaults to ``None``.

        Returns:
          None. The value returned by ``checkpoint.update_vocab`` is discarded.
        """
        # Session configuration with the GPU device count set to zero.
        no_gpu_session = tf.ConfigProto(device_count={"GPU": 0})

        # Positional order expected by the callee: checkpoint first, then the
        # arguments in the same order as this method's signature.
        forwarded = (
            checkpoint_path,
            output_dir,
            src_vocab,
            tgt_vocab,
            shared_vocab,
            new_src_vocab,
            new_tgt_vocab,
            new_shared_vocab,
            mode,
            init,
        )
        checkpoint.update_vocab(*forwarded, session_config=no_gpu_session)
예제 #2
0
 def train(self,
           config,
           src_file,
           tgt_file,
           src_vocab_info,
           tgt_vocab_info,
           align_file=None,
           model_path=None,
           gpuid=0):
     """Build a run configuration, train with ``onmt.Runner`` and list the files.

     Steps, in order:
       1. If either vocabulary is flagged as changed, call
          ``checkpoint.update_vocab`` in ``'replace'`` mode and use its return
          value as the new ``model_path``.
       2. Load the model through ``self._load_model``.
       3. Deep-copy the user configuration and fill in the data files,
          vocabularies and training defaults that are not already set.
       4. Run training and, when the runner reports a directory different
          from the model directory, copy ``model_description.py`` into it.

     Args:
       config: Mapping with an ``'options'`` entry; the keys ``'model_type'``,
         ``'model'``, ``'config'`` and ``'auto_config'`` are read from it.
       src_file: Stored as ``data.train_features_file``.
       tgt_file: Stored as ``data.train_labels_file``.
       src_vocab_info: Mapping read for ``'changed'``, ``'model'`` and
         ``'current'``.
       tgt_vocab_info: Same layout as ``src_vocab_info``, for the target side.
       align_file: Optional path; only used when it is not ``None`` and
         exists on disk.
       model_path: Passed to ``checkpoint.update_vocab`` (when a vocabulary
         changed) and to ``self._load_model``.
       gpuid: Passed to ``utils.count_devices`` to obtain ``num_devices``.

     Returns:
       The result of ``self._list_model_files`` for the directory returned by
       ``runner.train()``.
     """
     if src_vocab_info['changed'] or tgt_vocab_info['changed']:
         updated_checkpoint_dir = os.path.join(self._output_dir,
                                               'new_vocab_checkpoint')
         previous_src = src_vocab_info['model']
         previous_tgt = tgt_vocab_info['model']

         # Only the side that actually changed gets a replacement vocabulary.
         replacement_src = None
         if src_vocab_info['changed']:
             replacement_src = src_vocab_info['current']
         replacement_tgt = None
         if tgt_vocab_info['changed']:
             replacement_tgt = tgt_vocab_info['current']

         model_path = checkpoint.update_vocab(
             model_path,
             updated_checkpoint_dir,
             previous_src,
             previous_tgt,
             new_src_vocab=replacement_src,
             new_tgt_vocab=replacement_tgt,
             mode='replace',
             session_config=tf.ConfigProto(device_count={'GPU': 0}))

     options = config['options']
     model_dir, model = self._load_model(
         model_type=options.get('model_type'),
         model_file=options.get('model'),
         model_path=model_path)

     # Work on a copy so the caller's configuration is left untouched.
     run_config = copy.deepcopy(options.get('config', {}))
     run_config['model_dir'] = model_dir
     data_section = run_config.setdefault('data', {})
     train_section = run_config.setdefault('train', {})

     data_section['source_words_vocabulary'] = src_vocab_info['current']
     data_section['target_words_vocabulary'] = tgt_vocab_info['current']
     data_section['train_features_file'] = src_file
     data_section['train_labels_file'] = tgt_file

     if align_file is not None and os.path.exists(align_file):
         data_section['train_alignments'] = align_file
         # Keep a user-provided alignment type; otherwise fall back to "ce".
         params_section = run_config.setdefault("params", {})
         params_section.setdefault("guided_alignment_type", "ce")

     # Without an explicit step count, request a single pass over the data.
     if 'train_steps' not in train_section:
         train_section['single_pass'] = True
         train_section['train_steps'] = None
     train_section.setdefault('sample_buffer_size', -1)
     train_section.setdefault('average_last_checkpoints', 0)

     device_count = utils.count_devices(gpuid)
     use_auto_config = options.get('auto_config', False)
     runner = onmt.Runner(model,
                          run_config,
                          num_devices=device_count,
                          auto_config=use_auto_config)

     output_dir = runner.train()
     if output_dir != model_dir:
         description_file = os.path.join(model_dir, "model_description.py")
         shutil.copy(description_file, output_dir)
     return self._list_model_files(output_dir)
예제 #3
0
def main():
    """Parse command-line options and update a checkpoint's vocabularies.

    Sets TensorFlow logging to INFO, reads the options listed below from the
    command line, and passes them to ``checkpoint.update_vocab`` together with
    a session configuration whose GPU device count is 0.

    Options:
      --model_dir, --output_dir, --src_vocab, --tgt_vocab: required.
      --new_src_vocab, --new_tgt_vocab: optional, default ``None``.
      --mode: ``"merge"`` (default) or ``"replace"``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Options are registered in this order so that --help lists them the
    # same way: required paths first, then the optional new vocabularies.
    required_options = (
        ("--model_dir", "The model directory containing the checkpoints."),
        ("--output_dir",
         "The output directory where the updated checkpoint will be saved."),
        ("--src_vocab", "Path to the current source vocabulary."),
        ("--tgt_vocab", "Path to the current target vocabulary."),
    )
    for flag, description in required_options:
        parser.add_argument(flag, required=True, help=description)

    optional_options = (
        ("--new_src_vocab", "Path to the new source vocabulary."),
        ("--new_tgt_vocab", "Path to the new target vocabulary."),
    )
    for flag, description in optional_options:
        parser.add_argument(flag, default=None, help=description)

    parser.add_argument("--mode",
                        default="merge",
                        choices=["merge", "replace"],
                        help="Vocabulary update mode.")

    cli = parser.parse_args()

    # Session configuration with the GPU device count set to zero.
    no_gpu_session = tf.ConfigProto(device_count={"GPU": 0})
    checkpoint.update_vocab(
        cli.model_dir,
        cli.output_dir,
        cli.src_vocab,
        cli.tgt_vocab,
        new_src_vocab=cli.new_src_vocab,
        new_tgt_vocab=cli.new_tgt_vocab,
        mode=cli.mode,
        session_config=no_gpu_session)
예제 #4
0
 def train(self,
           config,
           src_file,
           tgt_file,
           src_vocab_info,
           tgt_vocab_info,
           model_path=None,
           gpuid=0):
     """Build a run configuration, train with ``onmt.Runner`` and list the files.

     Steps, in order:
       1. If either vocabulary is flagged as changed, call
          ``checkpoint.update_vocab`` in ``'replace'`` mode and use its return
          value as the new ``model_path``.
       2. Load the model through ``self._load_model``.
       3. Deep-copy the user configuration and fill in the data files,
          vocabularies and training defaults that are not already set.
       4. Run training; the value returned by ``runner.train()`` is ignored.

     Args:
       config: Mapping with an ``'options'`` entry; the keys ``'model_type'``,
         ``'model'``, ``'config'`` and ``'auto_config'`` are read from it.
       src_file: Stored as ``data.train_features_file``.
       tgt_file: Stored as ``data.train_labels_file``.
       src_vocab_info: Mapping read for ``'changed'``, ``'model'`` and
         ``'current'``.
       tgt_vocab_info: Same layout as ``src_vocab_info``, for the target side.
       model_path: Passed to ``checkpoint.update_vocab`` (when a vocabulary
         changed) and to ``self._load_model``.
       gpuid: Passed to ``utils.count_devices`` to obtain ``num_devices``.

     Returns:
       The result of ``self._list_model_files`` for the model directory
       returned by ``self._load_model``.
     """
     if src_vocab_info['changed'] or tgt_vocab_info['changed']:
         updated_checkpoint_dir = os.path.join(self._output_dir,
                                               'new_vocab_checkpoint')
         previous_src = src_vocab_info['model']
         previous_tgt = tgt_vocab_info['model']

         # Only the side that actually changed gets a replacement vocabulary.
         replacement_src = None
         if src_vocab_info['changed']:
             replacement_src = src_vocab_info['current']
         replacement_tgt = None
         if tgt_vocab_info['changed']:
             replacement_tgt = tgt_vocab_info['current']

         model_path = checkpoint.update_vocab(
             model_path,
             updated_checkpoint_dir,
             previous_src,
             previous_tgt,
             new_src_vocab=replacement_src,
             new_tgt_vocab=replacement_tgt,
             mode='replace',
             session_config=tf.ConfigProto(device_count={'GPU': 0}))

     options = config['options']
     model_dir, model = self._load_model(
         model_type=options.get('model_type'),
         model_file=options.get('model'),
         model_path=model_path)

     # Work on a copy so the caller's configuration is left untouched.
     run_config = copy.deepcopy(options.get('config', {}))
     run_config['model_dir'] = model_dir
     data_section = run_config.setdefault('data', {})
     train_section = run_config.setdefault('train', {})

     data_section['source_words_vocabulary'] = src_vocab_info['current']
     data_section['target_words_vocabulary'] = tgt_vocab_info['current']
     data_section['train_features_file'] = src_file
     data_section['train_labels_file'] = tgt_file

     # Without an explicit step count, request a single pass over the data.
     if 'train_steps' not in train_section:
         train_section['single_pass'] = True
         train_section['train_steps'] = None
     train_section.setdefault('sample_buffer_size', -1)

     device_count = utils.count_devices(gpuid)
     use_auto_config = options.get('auto_config', False)
     runner = onmt.Runner(model,
                          run_config,
                          num_devices=device_count,
                          auto_config=use_auto_config)
     runner.train()
     return self._list_model_files(model_dir)