def model_fn(model_dir):
    """Load the PyTorch model from the `model_dir` directory.

    Reads the constructor hyperparameters from ``model_info.pth``, builds an
    ``RNNModel`` from them, loads the trained weights from ``model.pth``, and
    attaches the lookup tables unpickled from ``char_dict.pkl`` and
    ``int_dict.pkl`` as ``model.char2int`` and ``model.int2char``.

    Args:
        model_dir: Directory containing the saved model artifacts
            (``model_info.pth``, ``model.pth``, ``char_dict.pkl``,
            ``int_dict.pkl``).

    Returns:
        The loaded ``RNNModel``, moved to CUDA when available (CPU otherwise)
        and switched to evaluation mode.
    """
    print("Loading model.")

    # First, load the parameters used to create the model.
    model_info_path = os.path.join(model_dir, 'model_info.pth')
    with open(model_info_path, 'rb') as f:
        model_info = torch.load(f)

    print("model_info: {}".format(model_info))

    # Determine the device and construct the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): these hyperparameters must match how the model was built
    # at training time, otherwise load_state_dict below fails on shape
    # mismatches — confirm against what the training script saves.
    model = RNNModel(model_info['vocab_size'], model_info['embedding_dim'],
                     model_info['hidden_dim'], model_info['n_layers'],
                     model_info['drop_rate'])

    # Load the stored model parameters. Tensors are deserialized onto the CPU
    # first (so weights saved from a GPU still load on a CPU-only host); the
    # whole model is moved to `device` afterwards.
    model_path = os.path.join(model_dir, 'model.pth')
    with open(model_path, 'rb') as f:
        model.load_state_dict(torch.load(f, map_location='cpu'))

    # Load the saved lookup tables (presumably char -> int and int -> char,
    # per the attribute names).
    # NOTE: pickle can execute arbitrary code; only load trusted artifacts.
    char_dict_path = os.path.join(model_dir, 'char_dict.pkl')
    with open(char_dict_path, 'rb') as f:
        model.char2int = pickle.load(f)

    int_dict_path = os.path.join(model_dir, 'int_dict.pkl')
    with open(int_dict_path, 'rb') as f:
        model.int2char = pickle.load(f)

    model.to(device).eval()

    print("Done loading model.")
    return model
train_loader, num_batches = _get_train_data_loader(args.batch_size, args.max_len, args.data_dir) # Calculate the validation batches val_batches = int(num_batches * args.val_frac) # Build the model. # Instantiate the model with hyperparameters model = RNNModel(args.vocab_size, args.vocab_size, args.hidden_dim, args.n_layers).to(device) # Load the dictionaries with open(os.path.join(args.data_dir, "char_dict.pkl"), "rb") as f: model.char2int = pickle.load(f) with open(os.path.join(args.data_dir, "int_dict.pkl"), "rb") as f: model.int2char = pickle.load(f) print("Model loaded with embedding_dim {}, hidden_dim {}, vocab_size {}.". format(args.embedding_dim, args.hidden_dim, args.vocab_size)) # Train the model. # Define Loss, Optimizer criterion = nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) train_main(model, optimizer, criterion, train_loader, num_batches, val_batches, args.batch_size, args.max_len, args.epochs, args.clip_norm, device) # Save the parameters used to construct the model model_info_path = os.path.join(args.model_dir, 'model_info.pth')