def update_vocab(self,
                 output_dir,
                 src_vocab,
                 tgt_vocab,
                 shared_vocab,
                 new_src_vocab,
                 new_tgt_vocab,
                 new_shared_vocab,
                 mode,
                 init,
                 checkpoint_path=None):
    """Rewrites a checkpoint so its embedding/projection rows match new vocabularies.

    Args:
      output_dir: Directory where the updated checkpoint is saved.
      src_vocab: Path to the source vocabulary the checkpoint was trained with.
      tgt_vocab: Path to the target vocabulary the checkpoint was trained with.
      shared_vocab: Path to the shared vocabulary, if the model uses one.
      new_src_vocab: Path to the new source vocabulary, or None.
      new_tgt_vocab: Path to the new target vocabulary, or None.
      new_shared_vocab: Path to the new shared vocabulary, or None.
      mode: Vocabulary update mode (e.g. merge or replace semantics,
        as understood by checkpoint.update_vocab).
      init: Initialization scheme for newly added vocabulary entries.
      checkpoint_path: Specific checkpoint to update; None selects the
        default behavior of checkpoint.update_vocab.
    """
    # Vocabulary surgery is pure tensor bookkeeping — run it on CPU only.
    cpu_config = tf.ConfigProto(device_count={"GPU": 0})
    checkpoint.update_vocab(
        checkpoint_path,
        output_dir,
        src_vocab,
        tgt_vocab,
        shared_vocab,
        new_src_vocab,
        new_tgt_vocab,
        new_shared_vocab,
        mode,
        init,
        session_config=cpu_config)
def train(self,
          config,
          src_file,
          tgt_file,
          src_vocab_info,
          tgt_vocab_info,
          align_file=None,
          model_path=None,
          gpuid=0):
    """Runs one training round and returns the produced model files.

    Args:
      config: The run configuration; options are read from config['options'].
      src_file: Path to the training source file.
      tgt_file: Path to the training target file.
      src_vocab_info: Dict with keys 'changed', 'current', and 'model'
        describing the source vocabulary state.
      tgt_vocab_info: Same structure for the target vocabulary.
      align_file: Optional path to a train-time alignment file.
      model_path: Path to the checkpoint to continue from, or None.
      gpuid: GPU id(s) used to size the number of training devices.

    Returns:
      The list of model files in the training output directory.
    """
    if src_vocab_info['changed'] or tgt_vocab_info['changed']:
        # The checkpoint must be resized to the new vocabularies before
        # training can resume from it.
        model_path = checkpoint.update_vocab(
            model_path,
            os.path.join(self._output_dir, 'new_vocab_checkpoint'),
            src_vocab_info['model'],
            tgt_vocab_info['model'],
            new_src_vocab=(src_vocab_info['current']
                           if src_vocab_info['changed'] else None),
            new_tgt_vocab=(tgt_vocab_info['current']
                           if tgt_vocab_info['changed'] else None),
            mode='replace',
            session_config=tf.ConfigProto(device_count={'GPU': 0}))
    model_dir, model = self._load_model(
        model_type=config['options'].get('model_type'),
        model_file=config['options'].get('model'),
        model_path=model_path)
    # Build the OpenNMT-tf run configuration from user options, then fill
    # in the data files and training defaults expected by this wrapper.
    run_config = copy.deepcopy(config['options'].get('config', {}))
    run_config['model_dir'] = model_dir
    if 'data' not in run_config:
        run_config['data'] = {}
    if 'train' not in run_config:
        run_config['train'] = {}
    run_config['data']['source_words_vocabulary'] = src_vocab_info['current']
    run_config['data']['target_words_vocabulary'] = tgt_vocab_info['current']
    run_config['data']['train_features_file'] = src_file
    run_config['data']['train_labels_file'] = tgt_file
    if align_file is not None and os.path.exists(align_file):
        run_config['data']['train_alignments'] = align_file
        # Guided alignment is only meaningful when an alignment file is
        # provided, so the default loss type is set inside this branch.
        if "params" not in run_config:
            run_config["params"] = {}
        if "guided_alignment_type" not in run_config["params"]:
            run_config["params"]["guided_alignment_type"] = "ce"
    if 'train_steps' not in run_config['train']:
        # No explicit step budget: train for exactly one pass over the data.
        run_config['train']['single_pass'] = True
        run_config['train']['train_steps'] = None
    if 'sample_buffer_size' not in run_config['train']:
        # -1 shuffles over the full dataset.
        run_config['train']['sample_buffer_size'] = -1
    if 'average_last_checkpoints' not in run_config['train']:
        run_config['train']['average_last_checkpoints'] = 0
    runner = onmt.Runner(
        model,
        run_config,
        num_devices=utils.count_devices(gpuid),
        auto_config=config['options'].get('auto_config', False))
    output_dir = runner.train()
    if output_dir != model_dir:
        # Keep the model description alongside the new checkpoint so the
        # output directory is self-contained.
        shutil.copy(os.path.join(model_dir, "model_description.py"), output_dir)
    return self._list_model_files(output_dir)
def main():
    """Command-line entry point: updates the vocabularies of a checkpoint."""
    tf.logging.set_verbosity(tf.logging.INFO)
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Register required path arguments, then the optional ones, in the
    # same order so the --help listing is unchanged.
    for flag, helptext in (
            ("--model_dir",
             "The model directory containing the checkpoints."),
            ("--output_dir",
             "The output directory where the updated checkpoint will be saved."),
            ("--src_vocab",
             "Path to the current source vocabulary."),
            ("--tgt_vocab",
             "Path to the current target vocabulary.")):
        parser.add_argument(flag, required=True, help=helptext)
    for flag, helptext in (
            ("--new_src_vocab", "Path to the new source vocabulary."),
            ("--new_tgt_vocab", "Path to the new target vocabulary.")):
        parser.add_argument(flag, default=None, help=helptext)
    parser.add_argument(
        "--mode",
        default="merge",
        choices=["merge", "replace"],
        help="Vocabulary update mode.")
    args = parser.parse_args()
    # The update is pure checkpoint bookkeeping — run it without GPUs.
    cpu_config = tf.ConfigProto(device_count={"GPU": 0})
    checkpoint.update_vocab(
        args.model_dir,
        args.output_dir,
        args.src_vocab,
        args.tgt_vocab,
        new_src_vocab=args.new_src_vocab,
        new_tgt_vocab=args.new_tgt_vocab,
        mode=args.mode,
        session_config=cpu_config)
def train(self,
          config,
          src_file,
          tgt_file,
          src_vocab_info,
          tgt_vocab_info,
          model_path=None,
          gpuid=0):
    """Trains the model on a parallel corpus and returns the model files.

    Args:
      config: The run configuration; options are read from config['options'].
      src_file: Path to the training source file.
      tgt_file: Path to the training target file.
      src_vocab_info: Dict with keys 'changed', 'current', and 'model'
        describing the source vocabulary state.
      tgt_vocab_info: Same structure for the target vocabulary.
      model_path: Path to the checkpoint to continue from, or None.
      gpuid: GPU id(s) used to size the number of training devices.

    Returns:
      The list of model files in the model directory.
    """
    if src_vocab_info['changed'] or tgt_vocab_info['changed']:
        # Resize the checkpoint to the new vocabularies before resuming.
        model_path = checkpoint.update_vocab(
            model_path,
            os.path.join(self._output_dir, 'new_vocab_checkpoint'),
            src_vocab_info['model'],
            tgt_vocab_info['model'],
            new_src_vocab=(src_vocab_info['current']
                           if src_vocab_info['changed'] else None),
            new_tgt_vocab=(tgt_vocab_info['current']
                           if tgt_vocab_info['changed'] else None),
            mode='replace',
            session_config=tf.ConfigProto(device_count={'GPU': 0}))
    model_dir, model = self._load_model(
        model_type=config['options'].get('model_type'),
        model_file=config['options'].get('model'),
        model_path=model_path)
    run_config = copy.deepcopy(config['options'].get('config', {}))
    run_config['model_dir'] = model_dir
    data_config = run_config.setdefault('data', {})
    train_config = run_config.setdefault('train', {})
    data_config['source_words_vocabulary'] = src_vocab_info['current']
    data_config['target_words_vocabulary'] = tgt_vocab_info['current']
    data_config['train_features_file'] = src_file
    data_config['train_labels_file'] = tgt_file
    if 'train_steps' not in train_config:
        # No explicit step budget: train for exactly one pass over the data.
        train_config['single_pass'] = True
        train_config['train_steps'] = None
    if 'sample_buffer_size' not in train_config:
        # -1 shuffles over the full dataset.
        train_config['sample_buffer_size'] = -1
    runner = onmt.Runner(
        model,
        run_config,
        num_devices=utils.count_devices(gpuid),
        auto_config=config['options'].get('auto_config', False))
    runner.train()
    return self._list_model_files(model_dir)