def load_seq2seq(filename, map_location=None):
    """
    Loads an (encoder, decoder) pair from a filename.

    (This function requires that the model was saved with save_model utils function)

    :param filename: Filename to be used.
    :param map_location: Forwarded to ``torch.load`` to remap tensor storage
        (e.g. ``'cpu'`` or a callable); ``None`` keeps the locations recorded
        in the file.
    :return: ``(encoder, decoder)`` rebuilt from the saved init args, with
        their saved state dicts loaded.
    """
    from seq2seq import Encoder, Decoder

    # The original passed map_location=None even when a location was given,
    # so the argument was silently ignored. torch.load(f, map_location=None)
    # is identical to torch.load(f), so one call covers both cases.
    model_dict = torch.load(filename, map_location=map_location)

    encoder = Encoder(**model_dict["encoder_init_args"])
    decoder = Decoder(**model_dict["decoder_init_args"])
    encoder.load_state_dict(model_dict["encoder_state"])
    decoder.load_state_dict(model_dict["decoder_state"])
    return encoder, decoder
def load_model(path='files/model_De', device=device):
    """Rebuild the Encoder/Decoder pair and language objects from a checkpoint.

    :param path: checkpoint read with ``torch.load``; it must contain the keys
        ``in_lang_class``, ``out_lang_class``, ``hidden_size``,
        ``encoder_state_dict`` and ``decoder_state_dict``.
    :param device: ``map_location`` for loading and target device for both
        modules. The default is the module-level ``device`` as bound when this
        function was defined.
    :return: ``(encoder, decoder, in_lang, out_lang)``.
    """
    ckpt = torch.load(path, map_location=device)

    # NOTE(review): every value returned by prepareData is discarded; the
    # language objects come from the checkpoint below. The call (and its
    # 3-way unpacking) is kept only so any side effects are preserved.
    # Consider removing it.
    _unused_in, _unused_out, _unused_pairs = prepareData('En', 'De')

    source_lang = ckpt['in_lang_class']
    target_lang = ckpt['out_lang_class']
    hidden = ckpt['hidden_size']

    encoder = Encoder(source_lang.n_words, hidden).to(device)
    decoder = Decoder(hidden, target_lang.n_words, dropout_p=0.1).to(device)
    for module, state_key in ((encoder, 'encoder_state_dict'),
                              (decoder, 'decoder_state_dict')):
        module.load_state_dict(ckpt[state_key])

    return encoder, decoder, source_lang, target_lang
def train(article, title, word2idx, target2idx, source_lengths, target_lengths, args, val_article=None, val_title=None, val_source_lengths=None, val_target_lengths=None):
    """Train an Encoder/Decoder pair to generate ``title`` sequences from ``article`` sequences.

    On the first run (when './temp/x.pkl' does not exist) 5% of the data is
    drawn as a validation set with ``utils.sampling``, and everything is
    persisted with ``utils.save_everything``. Otherwise the ``val_*``
    arguments are used as-is and must not be ``None`` (``len(val_article)``
    is taken unconditionally).

    Training runs for 5 epochs with Adam (lr=0.001). ``use_coverage`` is
    passed as False during the first epoch and True afterwards. Every
    ``args.show_decode`` batches, both models are saved to 'encoder_model' /
    'decoder_model' and one random validation example is beam-search decoded.
    If 'decoder_model' exists at start-up, both models are resumed from those
    files.

    :param args: object exposing ``batch`` (batch size), ``use_cuda`` and
        ``show_decode`` (checkpoint/decode interval in batches).
    """
    # NOTE(review): './temp/x.pkl' is presumably one of the files written by
    # utils.save_everything - confirm.
    if not os.path.exists('./temp/x.pkl'):
        size_of_val = int(len(article) * 0.05)
        val_article, val_title, val_source_lengths, val_target_lengths = \
            utils.sampling(article, title, source_lengths, target_lengths, size_of_val)
        utils.save_everything(article, title, source_lengths, target_lengths, val_article, val_title, val_source_lengths, val_target_lengths, word2idx)
    size_of_val = len(val_article)
    batch_size = args.batch
    train_size = len(article)
    val_size = len(val_article)
    max_a = max(source_lengths)
    max_t = max(target_lengths)
    print("source vocab size:", len(word2idx))
    print("target vocab size:", len(target2idx))
    print("max a:{}, max t:{}".format(max_a, max_t))
    print("train_size:", train_size)
    print("val size:", val_size)
    print("batch_size:", batch_size)
    print("-" * 30)
    # Coverage is disabled for epoch 1 and enabled from epoch 2 onward (see loop).
    use_coverage = False
    encoder = Encoder(len(word2idx))
    decoder = Decoder(len(target2idx), 50)
    # Resume from earlier checkpoints; only 'decoder_model' is checked for existence.
    # NOTE(review): no map_location - checkpoints saved from GPU will not load on a CPU-only host.
    if os.path.exists('decoder_model'):
        encoder.load_state_dict(torch.load('encoder_model'))
        decoder.load_state_dict(torch.load('decoder_model'))
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)
    n_epoch = 5
    print("Making word index and extend vocab")
    #article, article_tar, title, ext_vocab_all, ext_count = indexing_word(article, title, word2idx, target2idx)
    #article = to_tensor(article)
    #article_extend = to_tensor(article_extend)
    #title = to_tensor(title)
    print("preprocess done")
    if args.use_cuda:
        encoder.cuda()
        decoder.cuda()
    print("start training")
    for epoch in range(n_epoch):
        # NOTE(review): total_loss is accumulated below but never reported.
        total_loss = 0
        # Truncating division: a trailing partial batch is never trained on.
        batch_n = int(train_size / batch_size)
        if epoch > 0:
            use_coverage = True
        for b in range(batch_n):
            # initialization
            batch_x = article[b * batch_size:(b + 1) * batch_size]
            batch_y = title[b * batch_size:(b + 1) * batch_size]
            #batch_x_ext = article_extend[b*batch_size: (b+1)*batch_size]
            batch_x, batch_x_ext, batch_y, extend_vocab, extend_lengths = \
                utils.batch_index(batch_x, batch_y, word2idx, target2idx)
            if args.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                batch_x_ext = batch_x_ext.cuda()
            x_lengths = source_lengths[b * batch_size:(b + 1) * batch_size]
            y_lengths = target_lengths[b * batch_size:(b + 1) * batch_size]
            # work around to deal with length: the pack/pad round-trip re-pads
            # batch_x_ext to the longest entry of x_lengths.
            # NOTE(review): pack_padded_sequence defaults to enforce_sorted=True,
            # so x_lengths presumably must be in descending order - confirm the data is pre-sorted.
            pack = pack_padded_sequence(batch_x_ext, x_lengths, batch_first=True)
            batch_x_ext_var, _ = pad_packed_sequence(pack, batch_first=True)
            current_loss = train_on_batch(encoder, decoder, optimizer, batch_x, batch_y, x_lengths, y_lengths, word2idx, target2idx, batch_x_ext_var, extend_lengths, use_coverage)
            # These reassignments only rebind locals that are overwritten next iteration.
            batch_x = batch_x.cpu()
            batch_y = batch_y.cpu()
            batch_x_ext = batch_x_ext.cpu()
            print('epoch:{}/{}, batch:{}/{}, loss:{}'.format(
                epoch + 1, n_epoch, b + 1, batch_n, current_loss))
            if (b + 1) % args.show_decode == 0:
                # Periodic checkpoint plus a qualitative sample from the validation set.
                torch.save(encoder.state_dict(), 'encoder_model')
                torch.save(decoder.state_dict(), 'decoder_model')
                # NOTE(review): the whole validation set is re-indexed every time, and the
                # tensors are never moved to GPU even when the models are - confirm
                # decode.beam_search handles device placement.
                batch_x_val, batch_x_ext_val, batch_y_val, extend_vocab, extend_lengths = \
                    utils.batch_index(val_article, val_title, word2idx, target2idx)
                for i in range(1):
                    idx = np.random.randint(0, val_size)
                    decode.beam_search(encoder, decoder, batch_x_val[idx].unsqueeze(0), batch_y_val[idx].unsqueeze(0), word2idx, target2idx, batch_x_ext_val[idx], extend_lengths[idx], extend_vocab[idx])
                batch_x_val = batch_x_val.cpu()
                batch_y_val = batch_y_val.cpu()
                batch_x_ext_val = batch_x_ext_val.cpu()
            total_loss += current_loss
        print('-' * 30)
        print()
    print("training finished")
# Inference-script fragment: rebuild the Encoder/Decoder from saved hyper-parameters
# and weights (forced onto CPU), then read 'input.txt' line by line.
# NOTE(review): parameters_dict, en_embedding_dim, en_vocab, de_vocab,
# encoder_model_file, decoder_model_file and trf are defined outside this fragment.
de_embedding_dim = parameters_dict['de_embedding_dim']
hidden_dim = parameters_dict['hidden_dim']
num_layers = parameters_dict['num_layers']
bidirectional = parameters_dict['bidirectional']
use_lstm = parameters_dict['use_lstm']
use_cuda = False
# NOTE(review): batch_size is not used in the visible part of this fragment.
batch_size = 1
# Dropout disabled for inference.
dropout_p = 0.0
encoder = Encoder(en_embedding_dim, hidden_dim, en_vocab.n_items, num_layers, dropout_p, bidirectional, use_lstm, use_cuda)
decoder = Decoder(de_embedding_dim, hidden_dim, de_vocab.n_items, num_layers, dropout_p, bidirectional, use_lstm, use_cuda)
# map_location='cpu' lets GPU-saved checkpoints load on a CPU-only host.
encoder.load_state_dict(torch.load(encoder_model_file, map_location='cpu'))
decoder.load_state_dict(torch.load(decoder_model_file, map_location='cpu'))
encoder.eval()
decoder.eval()
# NOTE(review): opened without `with`; confirm both files are closed after the loop.
f_en_test = open('input.txt', 'r', encoding='utf-8')
f_de_pred = open('output.txt', 'w', encoding='utf-8')
while True:
    en_sent = f_en_test.readline()
    # readline() returns '' only at EOF (blank lines are '\n').
    if not en_sent:
        break
    sent = en_sent.strip()
    en_seq, en_seq_len = trf.trans_input(sent)
class Text2song(object):
    """Predict a musical key and a duration for every word of a lyric.

    Two independent seq2seq models are loaded from files in the working
    directory: one maps lyric words to keys, the other to durations. Both run
    on CPU (``use_cuda_key`` / ``use_cuda_dur`` are fixed to ``False``).

    Attributes set by ``__init__``:
        trf_key, trf_dur: ``Transfrom`` input transformers built from the
            encoder-side vocabularies.
        encoder_key, decoder_key, encoder_dur, decoder_dur: the models, in
            eval mode.
        en_vocab_key, en_vocab_dur, de_vocab_key, de_vocab_dur: vocabularies.
        use_cuda_key, use_cuda_dur: device flags.
        parameters_dict: hyper-parameters shared by all four models.
    """

    def __init__(self):
        def Load_Vocab(file):
            # NOTE: pickle can execute arbitrary code on load - only point this
            # at trusted, locally produced model files.
            with open(file, 'rb') as fd:
                _vocab = pickle.load(fd)
            return _vocab

        def Load_Parameters(file):
            with open(file, 'rb') as fd:
                parameters_dict = pickle.load(fd)
            return parameters_dict

        torch.manual_seed(1)
        torch.set_num_threads(4)

        en_vocab_dur_file = './en_vocab_dur.pkl'
        de_vocab_dur_file = './de_vocab_dur.pkl'
        encoder_dur_model_file = './encoder_dur.10.pt'
        decoder_dur_model_file = './decoder_dur.10.pt'
        en_vocab_key_file = './en_vocab.pkl'
        de_vocab_key_file = './de_vocab.pkl'
        encoder_key_model_file = './encoder.10.pt'
        decoder_key_model_file = './decoder.10.pt'
        hyper_parameters_file = './parameters_dict.pkl'

        self.en_vocab_key = Load_Vocab(en_vocab_key_file)
        self.de_vocab_key = Load_Vocab(de_vocab_key_file)
        self.en_vocab_dur = Load_Vocab(en_vocab_dur_file)
        self.de_vocab_dur = Load_Vocab(de_vocab_dur_file)
        self.trf_key = Transfrom(self.en_vocab_key)
        self.trf_dur = Transfrom(self.en_vocab_dur)
        self.parameters_dict = Load_Parameters(hyper_parameters_file)

        en_embedding_dim = self.parameters_dict['en_embedding_dim']
        de_embedding_dim = self.parameters_dict['de_embedding_dim']
        hidden_dim = self.parameters_dict['hidden_dim']
        num_layers = self.parameters_dict['num_layers']
        bidirectional = self.parameters_dict['bidirectional']
        use_lstm = self.parameters_dict['use_lstm']
        self.use_cuda_dur = self.use_cuda_key = False
        # Dropout disabled for inference.
        dropout_p = 0.0

        self.encoder_key = Encoder(en_embedding_dim, hidden_dim, self.en_vocab_key.n_items,
                                   num_layers, dropout_p, bidirectional, use_lstm, self.use_cuda_key)
        self.decoder_key = Decoder(de_embedding_dim, hidden_dim, self.de_vocab_key.n_items,
                                   num_layers, dropout_p, bidirectional, use_lstm, self.use_cuda_key)
        self.encoder_dur = Encoder(en_embedding_dim, hidden_dim, self.en_vocab_dur.n_items,
                                   num_layers, dropout_p, bidirectional, use_lstm, self.use_cuda_dur)
        self.decoder_dur = Decoder(de_embedding_dim, hidden_dim, self.de_vocab_dur.n_items,
                                   num_layers, dropout_p, bidirectional, use_lstm, self.use_cuda_dur)

        # map_location='cpu' lets GPU-saved checkpoints load on a CPU-only host.
        self.encoder_key.load_state_dict(
            torch.load(encoder_key_model_file, map_location='cpu'))
        self.decoder_key.load_state_dict(
            torch.load(decoder_key_model_file, map_location='cpu'))
        self.encoder_dur.load_state_dict(
            torch.load(encoder_dur_model_file, map_location='cpu'))
        self.decoder_dur.load_state_dict(
            torch.load(decoder_dur_model_file, map_location='cpu'))

        self.encoder_key.eval()
        self.decoder_key.eval()
        self.encoder_dur.eval()
        self.decoder_dur.eval()

    def get_song(self, lyric):
        """Greedy-decode a key and a duration sequence for each line of ``lyric``.

        :param lyric: text with one lyric line per newline-separated line.
        :return: list with one dict per input line:
            ``{'lyrics': <stripped line>, 'key': [...], 'duration': [...]}``.
            ``key`` and ``duration`` are cut at the first ``'_EOS_'`` and
            trimmed/padded (by repeating the last token) to one entry per word
            of the line. A list is empty if its model emitted ``'_EOS_'`` as
            its very first token, since there is nothing to pad with.
        """
        def stop_before_eos(li, length):
            """Cut ``li`` at the first '_EOS_', then trim/pad it to ``length`` items."""
            if '_EOS_' in li:
                li = li[:li.index('_EOS_')]
            li = li[:length]
            # Pad by repeating the last prediction. The original indexed li[-1]
            # unconditionally and raised IndexError when li was empty.
            if li:
                li = li + [li[-1]] * (length - len(li))
            return li

        def start_decoding(trf, sent, encoder, decoder, de_vocab, use_cuda):
            """Encode ``sent`` and run the first decoder step from '_START_'.

            Returns ``(encoder_output, decoder_outputs, decoder_state)``.
            """
            en_seq, en_seq_len = trf.trans_input(sent)
            encoder_output, encoder_state = encoder(torch.LongTensor(en_seq), en_seq_len)
            decoder_state = decoder.init_state(encoder_state)
            decoder_inputs = torch.LongTensor([de_vocab.item2index['_START_']])
            if use_cuda:
                decoder_inputs = decoder_inputs.cuda()
            decoder_outputs, decoder_state = decoder(decoder_inputs, encoder_output, decoder_state)
            return encoder_output, decoder_outputs, decoder_state

        def decode_step(decoder, v_idx, encoder_output, decoder_state, use_cuda):
            """Feed the previously predicted index ``v_idx`` back into ``decoder``."""
            decoder_inputs = torch.LongTensor([v_idx])
            if use_cuda:
                decoder_inputs = decoder_inputs.cuda()
            return decoder(decoder_inputs, encoder_output, decoder_state)

        pred_list = []
        # Inference only: no_grad avoids building an autograd graph for every step.
        with torch.no_grad():
            for en_sent in io.StringIO(lyric):
                sent = en_sent.strip()
                length = len(sent.split())

                enc_out_key, dec_out_key, state_key = start_decoding(
                    self.trf_key, sent, self.encoder_key, self.decoder_key,
                    self.de_vocab_key, self.use_cuda_key)
                enc_out_dur, dec_out_dur, state_dur = start_decoding(
                    self.trf_dur, sent, self.encoder_dur, self.decoder_dur,
                    self.de_vocab_dur, self.use_cuda_dur)

                pred_sent_key, pred_sent_dur = [], []
                pred_char_key = pred_char_dur = ''
                # Greedy search, both decoders in lock-step. Stops when either
                # model emits '_EOS_' or one token per word has been produced.
                # The original used '>' and decoded one token too many.
                while pred_char_key != '_EOS_' and pred_char_dur != '_EOS_':
                    v_idx_key = dec_out_key.topk(1)[1].item()
                    pred_char_key = self.de_vocab_key.index2item[v_idx_key]
                    pred_sent_key.append(pred_char_key)

                    v_idx_dur = dec_out_dur.topk(1)[1].item()
                    pred_char_dur = self.de_vocab_dur.index2item[v_idx_dur]
                    pred_sent_dur.append(pred_char_dur)

                    if len(pred_sent_dur) >= length:
                        break

                    dec_out_dur, state_dur = decode_step(
                        self.decoder_dur, v_idx_dur, enc_out_dur, state_dur, self.use_cuda_dur)
                    dec_out_key, state_key = decode_step(
                        self.decoder_key, v_idx_key, enc_out_key, state_key, self.use_cuda_key)

                pred_list.append({
                    'lyrics': sent,
                    'key': stop_before_eos(pred_sent_key, length),
                    'duration': stop_before_eos(pred_sent_dur, length),
                })
        return pred_list
# Training-setup fragment: persist the hyper-parameters, rebuild the models from
# the 'rev.7' checkpoints, and create the loss and optimizers.
# NOTE(review): parameters_dict, pl, batch_size, de_vocab, en_vocab and the
# model hyper-parameters are defined outside this fragment.
parameters_dict['use_lstm'] = use_lstm
with open('parameters_dict.pkl', 'wb') as fd:
    pickle.dump(parameters_dict, fd)

# Number of batches per epoch; gen_pairs has to be consumed to be counted.
batch_total = sum(1 for _ in pl.gen_pairs(batch_size))
# autograd.Variable is a deprecated no-op wrapper (the file already relies on
# post-0.4 APIs such as NLLLoss(reduction=...)), so a plain tensor is equivalent.
ones_matrix = torch.ones(1, de_vocab.n_items)

encoder = Encoder(en_embedding_dim, hidden_dim, en_vocab.n_items, num_layers, dropout_p,
                  bidirectional, use_lstm, use_cuda)
decoder = Decoder(de_embedding_dim, hidden_dim, de_vocab.n_items, num_layers, dropout_p,
                  bidirectional, use_lstm, use_cuda)

encoder_model_file = 'encoder_rev.7.pt'
decoder_model_file = 'decoder_rev.7.pt'
# Load onto CPU so GPU-saved checkpoints also work on CPU-only hosts.
# load_state_dict copies into the modules' own parameters, so final device
# placement is decided by the .cuda() calls below.
encoder.load_state_dict(torch.load(encoder_model_file, map_location='cpu'))
decoder.load_state_dict(torch.load(decoder_model_file, map_location='cpu'))
# Alternative (disabled): initialise from pre-trained character embeddings via
# model.load_pre_train_emb('cityu_training.char.emb.npy', 'cityu_training.char.dict', vocab).

loss_function = nn.NLLLoss(reduction='sum', ignore_index=de_vocab.item2index['_PAD_'])

# Move the models to the GPU *before* building their optimizers, as the
# torch.optim documentation recommends.
if use_cuda:
    encoder = encoder.cuda()
    decoder = decoder.cuda()

en_optimizer = optim.Adam(encoder.parameters(), lr=1e-3, weight_decay=0)
de_optimizer = optim.Adam(decoder.parameters(), lr=1e-3, weight_decay=0)