# Example 1
def model_fn(model_dir):
    """Load the trained PyTorch ``RNNModel`` from the ``model_dir`` directory.

    Files read from ``model_dir``:

    * ``model_info.pth`` -- a dict of constructor arguments with keys
      ``vocab_size``, ``embedding_dim``, ``hidden_dim``, ``n_layers`` and
      ``drop_rate``;
    * ``model.pth`` -- the model's ``state_dict``;
    * ``char_dict.pkl`` / ``int_dict.pkl`` -- pickled lookup tables that are
      attached to the model as ``char2int`` / ``int2char``.

    Args:
        model_dir: Directory containing the saved model artifacts.

    Returns:
        The reconstructed model, moved to CUDA when it is available (CPU
        otherwise) and switched to evaluation mode.
    """
    print("Loading model.")

    # First, load the parameters used to create the model. Map any tensors to
    # the CPU so artifacts written on a GPU host also load on a CPU-only host.
    model_info_path = os.path.join(model_dir, 'model_info.pth')
    with open(model_info_path, 'rb') as f:
        model_info = torch.load(f, map_location='cpu')

    print("model_info: {}".format(model_info))

    # Determine the device and construct the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RNNModel(model_info['vocab_size'], model_info['embedding_dim'],
                     model_info['hidden_dim'], model_info['n_layers'],
                     model_info['drop_rate'])

    # Load the stored weights onto the CPU; the fully assembled model is
    # moved to `device` at the end.
    model_path = os.path.join(model_dir, 'model.pth')
    with open(model_path, 'rb') as f:
        model.load_state_dict(torch.load(f, map_location='cpu'))

    # Load the saved character <-> integer lookup tables.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only point model_dir at trusted artifacts.
    char_dict_path = os.path.join(model_dir, 'char_dict.pkl')
    with open(char_dict_path, 'rb') as f:
        model.char2int = pickle.load(f)

    int_dict_path = os.path.join(model_dir, 'int_dict.pkl')
    with open(int_dict_path, 'rb') as f:
        model.int2char = pickle.load(f)

    model.to(device).eval()

    print("Done loading model.")
    return model
# Example 2
    # NOTE(review): this is the interior of a training entry point whose
    # header -- and the definitions of `args`, `device`,
    # `_get_train_data_loader` and `train_main` -- lie outside this snippet.
    # `args` is presumably an argparse Namespace; confirm.

    # Seed torch's random number generator for reproducibility.
    torch.manual_seed(args.seed)

    # Load the training data.
    # `_get_train_data_loader` returns a loader plus a batch count; presumably
    # num_batches is the number of batches the loader yields -- TODO confirm.
    train_loader, num_batches = _get_train_data_loader(args.batch_size,
                                                       args.max_len,
                                                       args.data_dir)
    # Calculate the validation batches: the fraction args.val_frac of
    # num_batches, truncated toward zero by int().
    val_batches = int(num_batches * args.val_frac)
    # Build the model.
    # Instantiate the model with hyperparameters
    # NOTE(review): args.vocab_size is passed as BOTH the first and second
    # positional arguments. The model_fn snippet above constructs
    # RNNModel(vocab_size, embedding_dim, hidden_dim, n_layers, drop_rate),
    # so the second slot appears to be embedding_dim -- yet args.embedding_dim
    # is never passed here (only printed below), and no drop_rate is given.
    # If the saved model_info records args.embedding_dim, model_fn may rebuild
    # a model whose shapes do not match this state_dict. Confirm intent.
    model = RNNModel(args.vocab_size, args.vocab_size, args.hidden_dim,
                     args.n_layers).to(device)

    # Load the dictionaries and attach them to the model as char2int /
    # int2char.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only use a trusted args.data_dir.
    with open(os.path.join(args.data_dir, "char_dict.pkl"), "rb") as f:
        model.char2int = pickle.load(f)

    with open(os.path.join(args.data_dir, "int_dict.pkl"), "rb") as f:
        model.int2char = pickle.load(f)

    # NOTE(review): this reports args.embedding_dim, but RNNModel above
    # received args.vocab_size in that position, so the logged value may not
    # describe the model that was actually built.
    print("Model loaded with embedding_dim {}, hidden_dim {}, vocab_size {}.".
          format(args.embedding_dim, args.hidden_dim, args.vocab_size))

    # Train the model.
    # Define Loss, Optimizer: cross-entropy loss, and Adam over all model
    # parameters with learning rate args.lr.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # Hand everything to the training loop. args.clip_norm is presumably a
    # gradient-clipping max norm -- confirm in train_main.
    train_main(model, optimizer, criterion, train_loader, num_batches,
               val_batches, args.batch_size, args.max_len, args.epochs,
               args.clip_norm, device)