# Training setup for AdversarialVAE: builds the model, the training data
# loader and the discriminator/adversary optimizers.
import os
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from linguistic_style_transfer_pytorch.config import GeneralConfig, ModelConfig
from linguistic_style_transfer_pytorch.data_loader import TextDataset
from linguistic_style_transfer_pytorch.model import AdversarialVAE

# Move the model to the GPU only when PyTorch can see one.
use_cuda = torch.cuda.is_available()

if __name__ == "__main__":
    mconfig = ModelConfig()
    gconfig = GeneralConfig()
    # Pre-computed embedding matrix loaded from disk and handed to the model
    # as `weight` (presumably the word-embedding init — confirm in AdversarialVAE).
    weights = torch.tensor(np.load(gconfig.word_embedding_path), dtype=torch.float)
    model = AdversarialVAE(weight=weights)
    if use_cuda:
        model = model.cuda()

    # =============== Define dataloader ================ #
    train_dataset = TextDataset(mode='train')
    train_dataloader = DataLoader(train_dataset, batch_size=mconfig.batch_size)
    # The model splits its parameters into three groups so each group can be
    # driven by its own optimizer.
    (content_discriminator_params,
     style_discriminator_params,
     vae_and_classifier_params) = model.get_params()

    # ============== Define optimizers ================ #
    # content discriminator/adversary optimizer
    content_disc_opt = torch.optim.RMSprop(
        content_discriminator_params, lr=mconfig.content_adversary_lr)
    # style discriminator/adversary optimizer
    style_disc_opt = torch.optim.RMSprop(
        style_discriminator_params, lr=mconfig.style_adversary_lr)
# Inference setup: load a trained AdversarialVAE checkpoint, the average
# style embeddings and the vocabulary mappings.
import argparse
import json
import os
import pickle

import numpy as np
import torch

from linguistic_style_transfer_pytorch.config import GeneralConfig
from linguistic_style_transfer_pytorch.model import AdversarialVAE

gconfig = GeneralConfig()

# load word embeddings
weights = torch.tensor(np.load(gconfig.word_embedding_path), dtype=torch.float)
# load checkpoint; map_location='cpu' lets a checkpoint saved from a GPU
# training run load on a CPU-only machine (this script keeps the model on CPU).
model_checkpoint = torch.load('checkpoints/model_epoch_20.pt', map_location='cpu')
# Load model. The keyword is `weight`, matching how the training scripts
# construct AdversarialVAE.
model = AdversarialVAE(weight=weights)
model.load_state_dict(model_checkpoint)
model.eval()
# Load average style embeddings (path comes from the GeneralConfig instance;
# the original referenced an undefined name `config`).
with open(gconfig.avg_style_emb_path, 'rb') as f:
    avg_style_embeddings = pickle.load(f)
# set avg_style_emb attribute of the model
model.avg_style_emb = avg_style_embeddings
# load word2index
with open(gconfig.w2i_file_path, encoding='utf-8') as f:
    word2index = json.load(f)
# load index2word
with open(gconfig.i2w_file_path, encoding='utf-8') as f:
    index2word = json.load(f)
# Sentiment label -> class index.
label2index = {'neg': 0, 'pos': 1}
# Read input sentence
# Training setup (device-aware variant): pick a device, build the model, the
# training data loader and the discriminator/adversary optimizers.
import numpy as np
import torch
from torch.utils.data import DataLoader

from linguistic_style_transfer_pytorch.config import GeneralConfig, ModelConfig
from linguistic_style_transfer_pytorch.data_loader import TextDataset
from linguistic_style_transfer_pytorch.model import AdversarialVAE

# Use the first GPU when PyTorch can see one, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
print('using backend(', device, ')')

if __name__ == "__main__":
    mconfig = ModelConfig()
    gconfig = GeneralConfig()
    # Embedding matrix loaded from disk, created directly on the chosen device.
    weights = torch.tensor(np.load(gconfig.word_embedding_path),
                           device=device, dtype=torch.float)
    model = AdversarialVAE(inference=False, weight=weights, device=device)
    # Put the model on the same explicit device as `weights` (no-op on CPU).
    model = model.to(device)

    # =============== Define dataloader ================ #
    train_dataset = TextDataset(mode='train')
    # Pinned host memory only speeds up host->GPU copies; requesting it on a
    # CPU-only machine just produces a warning, so tie it to use_cuda.
    train_dataloader = DataLoader(train_dataset, batch_size=mconfig.batch_size,
                                  drop_last=True, pin_memory=use_cuda)
    # The model splits its parameters into three groups so each group can be
    # driven by its own optimizer.
    (content_discriminator_params,
     style_discriminator_params,
     vae_and_classifier_params) = model.get_params()

    # ============== Define optimizers ================ #
    # content discriminator/adversary optimizer
    content_disc_opt = torch.optim.RMSprop(
        content_discriminator_params, lr=mconfig.content_adversary_lr)