import sys
import time
from pathlib import Path

import torch

# NOTE: the project-specific helpers used below (Hyperparameters, PAD_TOKEN,
# TranslationEngine, load_vocabularies, load_data, create_model,
# construct_optimizers, model_parameter_count, initialize_model, train) are
# assumed to be imported from elsewhere in the package.


def load_statics(self):
    # Load the vocabularies.
    if self.verbose:
        t0 = time.time()
        print(
            f"Loading vocabularies src={self.hparams.src} tgt={self.hparams.tgt}",
            file=sys.stderr)
    self.vocab_src, self.vocab_tgt = load_vocabularies(self.hparams)

    # Load pre/post-processing models and configure a pipeline.
    self.pipeline = TranslationEngine.make_pipeline(self.hparams)

    if self.verbose:
        print(
            f"Restoring model selected wrt {self.hparams.criterion} "
            f"from {self.model_checkpoint}",
            file=sys.stderr)
    model, _, _, translate_fn = create_model(self.hparams, self.vocab_src,
                                             self.vocab_tgt)
    # Map the checkpoint to CPU when no GPU is available.
    if self.hparams.use_gpu:
        model.load_state_dict(torch.load(self.model_checkpoint))
    else:
        model.load_state_dict(
            torch.load(self.model_checkpoint, map_location='cpu'))
    self.model = model.to(self.device)
    self.translate_fn = translate_fn
    self.model.eval()
    if self.verbose:
        print(f"Done loading in {time.time() - t0:.2f} seconds", file=sys.stderr)
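# Hypothetical usage sketch. The constructor signature and the engine's public
# API beyond load_statics are assumptions; only load_statics and make_pipeline
# appear in this file.
#
#   engine = TranslationEngine(hparams)
#   engine.load_statics()       # loads vocabularies, pipeline, and model
#   print(engine.translate_fn)  # translation function restored by create_model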
def create_vocab():
    # Load and print hyperparameters.
    hparams = Hyperparameters(check_required=True)
    print("\n==== Hyperparameters")
    hparams.print_values()

    # Load the vocabularies and print some statistics.
    vocab_src, vocab_tgt = load_vocabularies(hparams)
    if hparams.share_vocab:
        print("\n==== Vocabulary")
        vocab_src.print_statistics()
    else:
        print("\n==== Source vocabulary")
        vocab_src.print_statistics()
        print("\n==== Target vocabulary")
        vocab_tgt.print_statistics()

    # Create the output directory.
    out_dir = Path(hparams.output_dir)
    if not out_dir.exists():
        out_dir.mkdir()

    # Save the vocabularies and record their common prefix in the hyperparameters.
    print(f"\nSaving vocabularies to {out_dir}...")
    vocab_src.save(out_dir / f"vocab.{hparams.src}")
    vocab_tgt.save(out_dir / f"vocab.{hparams.tgt}")
    hparams.vocab_prefix = out_dir / "vocab"
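# For illustration: with output_dir="out", src="en", tgt="de" (hypothetical
# values), the save calls above would produce
#
#   out/vocab.en
#   out/vocab.de
#
# and set hparams.vocab_prefix to out/vocab so later runs can reload both files.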
def main():
    # Load and print hyperparameters.
    hparams = Hyperparameters()
    print("\n==== Hyperparameters")
    hparams.print_values()

    # Load the vocabularies and print some statistics.
    vocab_src, vocab_tgt = load_vocabularies(hparams)
    if hparams.share_vocab:
        print("\n==== Vocabulary")
        vocab_src.print_statistics()
    else:
        print("\n==== Source vocabulary")
        vocab_src.print_statistics()
        print("\n==== Target vocabulary")
        vocab_tgt.print_statistics()

    # Load the data and print some statistics.
    train_data, val_data, _ = load_data(hparams, vocab_src=vocab_src,
                                        vocab_tgt=vocab_tgt)
    print("\n==== Data")
    print(f"Training data: {len(train_data):,} bilingual sentence pairs")
    print(f"Validation data: {len(val_data):,} bilingual sentence pairs")

    # Create the model and load it onto the GPU if set to do so.
    model, train_fn, validate_fn, _ = create_model(hparams, vocab_src, vocab_tgt)
    optimizers, lr_schedulers = construct_optimizers(
        hparams,
        gen_parameters=model.generative_parameters(),
        inf_z_parameters=model.inference_parameters(),
        lagrangian_parameters=model.lagrangian_parameters())
    device = torch.device("cuda:0") if hparams.use_gpu else torch.device("cpu")
    model = model.to(device)

    # Print information about the model.
    param_count_M = model_parameter_count(model) / 1e6
    print("\n==== Model")
    print("Short summary:")
    print(model)
    print("\nAll parameters:")
    for name, param in model.named_parameters():
        print(f"{name} -- {param.size()}")
    print(f"\nNumber of model parameters: {param_count_M:.2f} M")

    # Initialize the model parameters, or load a checkpoint.
    if hparams.model_checkpoint is None:
        print("\nInitializing parameters...")
        initialize_model(model, vocab_tgt[PAD_TOKEN], hparams.cell_type,
                         hparams.emb_init_scale, verbose=True)
    else:
        print(f"\nRestoring model parameters from {hparams.model_checkpoint}...")
        model.load_state_dict(torch.load(hparams.model_checkpoint))

    # Create the output directories.
    out_dir = Path(hparams.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    if hparams.vocab_prefix is None:
        vocab_src.save(out_dir / f"vocab.{hparams.src}")
        vocab_tgt.save(out_dir / f"vocab.{hparams.tgt}")
        hparams.vocab_prefix = out_dir / "vocab"
    hparams.save(out_dir / "hparams")
    print("\n==== Output")
    print(f"Created output directory at {hparams.output_dir}")

    # Train the model.
    print("\n==== Starting training")
    print(f"Using device: {device}\n")
    train(model, optimizers, lr_schedulers, train_data, val_data,
          vocab_src, vocab_tgt, device, out_dir, train_fn, validate_fn, hparams)
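# Assumed command-line entry point; the original module may instead expose
# main() through a console-script wrapper.
if __name__ == "__main__":
    main()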