Example #1
0
    def load_statics(self):
        """Load the static inference assets for this engine.

        Loads the source/target vocabularies, builds the pre/post-processing
        pipeline, restores the model from ``self.model_checkpoint`` and moves
        it to ``self.device`` in eval mode.

        Side effects: sets ``self.vocab_src``, ``self.vocab_tgt``,
        ``self.pipeline``, ``self.model`` and ``self.translate_fn``.
        """
        # Loading vocabulary
        if self.verbose:
            t0 = time.time()
            print(
                f"Loading vocabularies src={self.hparams.src} tgt={self.hparams.tgt}",
                file=sys.stderr)
        self.vocab_src, self.vocab_tgt = load_vocabularies(self.hparams)

        # Load pre/post processing models and configure a pipeline
        self.pipeline = TranslationEngine.make_pipeline(self.hparams)

        if self.verbose:
            print(
                f"Restoring model selected wrt {self.hparams.criterion} from {self.model_checkpoint}",
                file=sys.stderr)

        model, _, _, translate_fn = create_model(self.hparams, self.vocab_src,
                                                 self.vocab_tgt)

        # Remap checkpoint tensors to CPU when no GPU is requested; with a
        # GPU, keep the storage locations recorded in the checkpoint.
        map_location = None if self.hparams.use_gpu else 'cpu'
        model.load_state_dict(
            torch.load(self.model_checkpoint, map_location=map_location))

        self.model = model.to(self.device)
        self.translate_fn = translate_fn
        # Inference only: disable dropout / switch norm layers to eval mode.
        self.model.eval()
        if self.verbose:
            print("Done loading in %.2f seconds" % (time.time() - t0),
                  file=sys.stderr)
Example #2
0
def create_vocab():
    """Build the source/target vocabularies and save them to disk.

    Loads hyperparameters (all required options must be present), constructs
    the vocabularies, prints their statistics, writes ``vocab.<src>`` and
    ``vocab.<tgt>`` into ``hparams.output_dir``, and records the vocabulary
    prefix back onto ``hparams``.
    """

    # Load and print hyperparameters.
    hparams = Hyperparameters(check_required=True)
    print("\n==== Hyperparameters")
    hparams.print_values()

    # Load the data and print some statistics.
    vocab_src, vocab_tgt = load_vocabularies(hparams)
    if hparams.share_vocab:
        # With a shared vocabulary both sides use the same statistics.
        print("\n==== Vocabulary")
        vocab_src.print_statistics()
    else:
        print("\n==== Source vocabulary")
        vocab_src.print_statistics()
        print("\n==== Target vocabulary")
        vocab_tgt.print_statistics()

    # Create the output directory. parents=True handles nested output paths,
    # and exist_ok=True avoids racing on a separate exists() check.
    out_dir = Path(hparams.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving vocabularies to {out_dir}...")
    vocab_src.save(out_dir / f"vocab.{hparams.src}")
    vocab_tgt.save(out_dir / f"vocab.{hparams.tgt}")
    hparams.vocab_prefix = out_dir / "vocab"
Example #3
0
def main():
    """Train a translation model end to end.

    Loads hyperparameters, vocabularies and data, builds the model and its
    optimizers, initializes parameters (or restores them from a checkpoint),
    saves the vocab/hparams artifacts into the output directory, and starts
    training.
    """

    # Load and print hyperparameters.
    hparams = Hyperparameters()
    print("\n==== Hyperparameters")
    hparams.print_values()

    # Load the data and print some statistics.
    vocab_src, vocab_tgt = load_vocabularies(hparams)
    if hparams.share_vocab:
        print("\n==== Vocabulary")
        vocab_src.print_statistics()
    else:
        print("\n==== Source vocabulary")
        vocab_src.print_statistics()
        print("\n==== Target vocabulary")
        vocab_tgt.print_statistics()
    train_data, val_data, _ = load_data(hparams,
                                        vocab_src=vocab_src,
                                        vocab_tgt=vocab_tgt)
    print("\n==== Data")
    print(f"Training data: {len(train_data):,} bilingual sentence pairs")
    print(f"Validation data: {len(val_data):,} bilingual sentence pairs")

    # Create the language model and load it onto the GPU if set to do so.
    model, train_fn, validate_fn, _ = create_model(hparams, vocab_src,
                                                   vocab_tgt)
    optimizers, lr_schedulers = construct_optimizers(
        hparams,
        gen_parameters=model.generative_parameters(),
        inf_z_parameters=model.inference_parameters(),
        lagrangian_parameters=model.lagrangian_parameters())
    device = torch.device("cuda:0") if hparams.use_gpu else torch.device("cpu")
    model = model.to(device)

    # Print information about the model.
    param_count_M = model_parameter_count(model) / 1e6
    print("\n==== Model")
    print("Short summary:")
    print(model)
    print("\nAll parameters:")
    for name, param in model.named_parameters():
        print(f"{name} -- {param.size()}")
    print(f"\nNumber of model parameters: {param_count_M:.2f} M")

    # Initialize the model parameters, or load a checkpoint.
    if hparams.model_checkpoint is None:
        print("\nInitializing parameters...")
        initialize_model(model,
                         vocab_tgt[PAD_TOKEN],
                         hparams.cell_type,
                         hparams.emb_init_scale,
                         verbose=True)
    else:
        print(
            f"\nRestoring model parameters from {hparams.model_checkpoint}...")
        # Remap checkpoint tensors onto the selected device so a GPU-saved
        # checkpoint can be restored on a CPU-only run (and vice versa).
        model.load_state_dict(
            torch.load(hparams.model_checkpoint, map_location=device))

    # Create the output directories.
    out_dir = Path(hparams.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Persist the vocabularies next to the model unless a prefix was given.
    if hparams.vocab_prefix is None:
        vocab_src.save(out_dir / f"vocab.{hparams.src}")
        vocab_tgt.save(out_dir / f"vocab.{hparams.tgt}")
        hparams.vocab_prefix = out_dir / "vocab"
    hparams.save(out_dir / "hparams")
    print("\n==== Output")
    print(f"Created output directory at {hparams.output_dir}")

    # Train the model.
    print("\n==== Starting training")
    print(f"Using device: {device}\n")
    train(model, optimizers, lr_schedulers, train_data, val_data, vocab_src,
          vocab_tgt, device, out_dir, train_fn, validate_fn, hparams)