示例#1
0
def load_seq2seq(filename, map_location=None):
    """
    Loads an encoder/decoder pair from a checkpoint file. (This function requires that the model was saved with save_model utils function)

    The checkpoint is expected to be a dict with the keys
    ``encoder_init_args``, ``decoder_init_args`` (keyword arguments used to
    rebuild the modules), ``encoder_state`` and ``decoder_state`` (their
    state dicts).

    :param filename: Filename to be used.
    :param map_location: Forwarded unchanged to ``torch.load`` to remap the
        stored tensors (e.g. ``'cpu'``). ``None`` keeps the default behavior.
    :return: Tuple ``(encoder, decoder)`` with the saved weights loaded.
    """
    # Imported lazily, as in the original code, so that importing this module
    # does not require the seq2seq package.
    from seq2seq import Encoder, Decoder

    # Forward the caller's map_location; passing None is identical to
    # omitting the argument, so no branching is needed.
    model_dict = torch.load(filename, map_location=map_location)

    encoder = Encoder(**model_dict["encoder_init_args"])
    decoder = Decoder(**model_dict["decoder_init_args"])
    encoder.load_state_dict(model_dict["encoder_state"])
    decoder.load_state_dict(model_dict["decoder_state"])
    return encoder, decoder
示例#2
0
def load_model(path='files/model_De', device=device):
    """Restore a trained encoder/decoder pair and its vocabularies.

    The checkpoint at ``path`` must contain the keys ``in_lang_class``,
    ``out_lang_class``, ``hidden_size``, ``encoder_state_dict`` and
    ``decoder_state_dict``.

    :param path: Checkpoint file written by the training script.
    :param device: Device the tensors are mapped to and the modules moved to.
    :return: Tuple ``(encoder, decoder, in_lang, out_lang)``.
    """
    # NOTE: the checkpoint stores pickled language objects, so only load
    # files from a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # The vocabularies come from the checkpoint itself; rebuilding them from
    # the raw corpus first (as the code used to do) was discarded work.
    in_lang = checkpoint['in_lang_class']
    out_lang = checkpoint['out_lang_class']
    hidden_size = checkpoint['hidden_size']

    encoder = Encoder(in_lang.n_words, hidden_size).to(device)
    decoder = Decoder(hidden_size, out_lang.n_words, dropout_p=0.1).to(device)

    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    return encoder, decoder, in_lang, out_lang
def train(article,
          title,
          word2idx,
          target2idx,
          source_lengths,
          target_lengths,
          args,
          val_article=None,
          val_title=None,
          val_source_lengths=None,
          val_target_lengths=None):
    """Train an Encoder/Decoder pair on (article, title) pairs.

    Runs 5 epochs of Adam (lr=0.001) over consecutive slices of ``article`` /
    ``title`` of size ``args.batch``. Progress is printed for every batch.
    Every ``args.show_decode`` batches the weights are written to the files
    ``encoder_model`` / ``decoder_model`` in the working directory and one
    random validation example is passed to ``decode.beam_search``. If a file
    named ``decoder_model`` already exists, both state dicts are loaded
    before training starts.

    Args:
        article: Training sources; sliced per batch and handed to
            ``utils.batch_index``.
        title: Training targets, sliced in parallel with ``article``.
        word2idx: Source vocabulary; its length sizes the encoder.
        target2idx: Target vocabulary; its length sizes the decoder.
        source_lengths: Per-example source lengths, sliced in parallel with
            ``article``.
        target_lengths: Per-example target lengths.
        args: Object providing ``batch``, ``use_cuda`` and ``show_decode``.
        val_article, val_title, val_source_lengths, val_target_lengths:
            Validation data. Replaced by the result of ``utils.sampling``
            when ``./temp/x.pkl`` does not exist.

    Returns:
        None.
    """

    # When the marker file is absent, the validation arguments are overwritten
    # with a sample sized at 5% of the training data and everything is handed
    # to utils.save_everything (presumably persisted to ./temp — TODO confirm).
    if not os.path.exists('./temp/x.pkl'):
        size_of_val = int(len(article) * 0.05)
        val_article, val_title, val_source_lengths, val_target_lengths = \
            utils.sampling(article, title, source_lengths, target_lengths, size_of_val)

        utils.save_everything(article, title, source_lengths, target_lengths,
                              val_article, val_title, val_source_lengths,
                              val_target_lengths, word2idx)

    # NOTE(review): if ./temp/x.pkl exists and the caller left val_article at
    # its default of None, len(val_article) raises TypeError here.
    # size_of_val is not read again below; val_size is used instead.
    size_of_val = len(val_article)
    batch_size = args.batch
    train_size = len(article)
    val_size = len(val_article)
    max_a = max(source_lengths)
    max_t = max(target_lengths)
    print("source vocab size:", len(word2idx))
    print("target vocab size:", len(target2idx))
    print("max a:{}, max t:{}".format(max_a, max_t))
    print("train_size:", train_size)
    print("val size:", val_size)
    print("batch_size:", batch_size)
    print("-" * 30)
    # Coverage stays off for the first epoch and is switched on afterwards.
    use_coverage = False

    encoder = Encoder(len(word2idx))
    # NOTE(review): the meaning of the literal 50 depends on Decoder's
    # signature, which is not visible here.
    decoder = Decoder(len(target2idx), 50)
    # Resume from a previous run; only the decoder file is checked, but both
    # files are loaded.
    if os.path.exists('decoder_model'):
        encoder.load_state_dict(torch.load('encoder_model'))
        decoder.load_state_dict(torch.load('decoder_model'))

    # A single optimizer updates the parameters of both modules.
    optimizer = torch.optim.Adam(list(encoder.parameters()) +
                                 list(decoder.parameters()),
                                 lr=0.001)
    n_epoch = 5
    print("Making word index and extend vocab")
    #article, article_tar, title, ext_vocab_all, ext_count = indexing_word(article, title, word2idx, target2idx)
    #article = to_tensor(article)
    #article_extend = to_tensor(article_extend)
    #title = to_tensor(title)
    print("preprocess done")

    if args.use_cuda:
        encoder.cuda()
        decoder.cuda()

    print("start training")
    for epoch in range(n_epoch):
        # NOTE(review): total_loss is accumulated below but never read.
        total_loss = 0
        # Number of full batches; a trailing partial batch is not trained on.
        batch_n = int(train_size / batch_size)
        if epoch > 0:
            use_coverage = True
        for b in range(batch_n):
            # initialization
            batch_x = article[b * batch_size:(b + 1) * batch_size]
            batch_y = title[b * batch_size:(b + 1) * batch_size]
            #batch_x_ext = article_extend[b*batch_size: (b+1)*batch_size]
            # Indexing is done per batch rather than once up front.
            batch_x, batch_x_ext, batch_y, extend_vocab, extend_lengths = \
                utils.batch_index(batch_x, batch_y, word2idx, target2idx)

            if args.use_cuda:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
                batch_x_ext = batch_x_ext.cuda()
            x_lengths = source_lengths[b * batch_size:(b + 1) * batch_size]
            y_lengths = target_lengths[b * batch_size:(b + 1) * batch_size]

            # work around to deal with length
            # The pack/unpack round trip presumably trims batch_x_ext to the
            # longest length in x_lengths — TODO confirm. With its default
            # arguments pack_padded_sequence expects x_lengths sorted in
            # decreasing order.
            pack = pack_padded_sequence(batch_x_ext,
                                        x_lengths,
                                        batch_first=True)
            batch_x_ext_var, _ = pad_packed_sequence(pack, batch_first=True)
            current_loss = train_on_batch(encoder, decoder, optimizer, batch_x,
                                          batch_y, x_lengths, y_lengths,
                                          word2idx, target2idx,
                                          batch_x_ext_var, extend_lengths,
                                          use_coverage)

            # Rebind the batch names to CPU tensors; they are reassigned at the
            # start of the next iteration.
            batch_x = batch_x.cpu()
            batch_y = batch_y.cpu()
            batch_x_ext = batch_x_ext.cpu()

            print('epoch:{}/{}, batch:{}/{}, loss:{}'.format(
                epoch + 1, n_epoch, b + 1, batch_n, current_loss))
            # Periodic checkpoint plus a qualitative check on validation data.
            if (b + 1) % args.show_decode == 0:
                torch.save(encoder.state_dict(), 'encoder_model')
                torch.save(decoder.state_dict(), 'decoder_model')
                # The whole validation set is re-indexed on every checkpoint.
                # NOTE(review): these tensors are not moved to the GPU before
                # being passed to beam_search, even when args.use_cuda is set.
                batch_x_val, batch_x_ext_val, batch_y_val, extend_vocab, extend_lengths = \
                    utils.batch_index(val_article, val_title, word2idx, target2idx)
                # Decode a single randomly chosen validation example.
                for i in range(1):
                    idx = np.random.randint(0, val_size)
                    decode.beam_search(encoder, decoder,
                                       batch_x_val[idx].unsqueeze(0),
                                       batch_y_val[idx].unsqueeze(0), word2idx,
                                       target2idx, batch_x_ext_val[idx],
                                       extend_lengths[idx], extend_vocab[idx])

                batch_x_val = batch_x_val.cpu()
                batch_y_val = batch_y_val.cpu()
                batch_x_ext_val = batch_x_ext_val.cpu()

            total_loss += current_loss
            print('-' * 30)

    print()
    print("training finished")
示例#4
0
    en_embedding_dim = parameters_dict['en_embedding_dim']
    de_embedding_dim = parameters_dict['de_embedding_dim']
    hidden_dim = parameters_dict['hidden_dim']
    num_layers = parameters_dict['num_layers']
    bidirectional = parameters_dict['bidirectional']
    use_lstm = parameters_dict['use_lstm']
    use_cuda = False
    batch_size = 1
    dropout_p = 0.0

    encoder = Encoder(en_embedding_dim, hidden_dim, en_vocab.n_items,
                      num_layers, dropout_p, bidirectional, use_lstm, use_cuda)
    decoder = Decoder(de_embedding_dim, hidden_dim, de_vocab.n_items,
                      num_layers, dropout_p, bidirectional, use_lstm, use_cuda)

    encoder.load_state_dict(torch.load(encoder_model_file, map_location='cpu'))
    decoder.load_state_dict(torch.load(decoder_model_file, map_location='cpu'))

    encoder.eval()
    decoder.eval()

    f_en_test = open('input.txt', 'r', encoding='utf-8')
    f_de_pred = open('output.txt', 'w', encoding='utf-8')

    while True:
        en_sent = f_en_test.readline()

        if not en_sent: break

        sent = en_sent.strip()
        en_seq, en_seq_len = trf.trans_input(sent)
示例#5
0
class Text2song(object):
    """Predict a key sequence and a duration sequence for each lyric line.

    Two independent encoder/decoder pairs are restored from files in the
    current working directory: the "key" models and the "dur" (duration)
    models. Both pairs are built with the hyper-parameters stored in
    ``parameters_dict.pkl``, loaded onto the CPU and put in eval mode.

    Attributes set by ``__init__``:
        en_vocab_key, de_vocab_key: Vocabularies of the key models.
        en_vocab_dur, de_vocab_dur: Vocabularies of the duration models.
        trf_key, trf_dur: ``Transfrom`` objects built from the ``en_*``
            vocabularies; they turn a lyric line into model input.
        parameters_dict: The unpickled hyper-parameter dictionary.
        encoder_key, decoder_key, encoder_dur, decoder_dur: The models.
        use_cuda_key, use_cuda_dur: Always ``False``.
    """

    def __init__(self):
        def load_pickle(file):
            # NOTE: unpickling can execute arbitrary code; these files must
            # come from a trusted source.
            with open(file, 'rb') as fd:
                return pickle.load(fd)

        torch.manual_seed(1)
        torch.set_num_threads(4)

        self.en_vocab_key = load_pickle('./en_vocab.pkl')
        self.de_vocab_key = load_pickle('./de_vocab.pkl')

        self.en_vocab_dur = load_pickle('./en_vocab_dur.pkl')
        self.de_vocab_dur = load_pickle('./de_vocab_dur.pkl')

        self.trf_key = Transfrom(self.en_vocab_key)
        self.trf_dur = Transfrom(self.en_vocab_dur)

        self.parameters_dict = load_pickle('./parameters_dict.pkl')

        en_embedding_dim = self.parameters_dict['en_embedding_dim']
        de_embedding_dim = self.parameters_dict['de_embedding_dim']
        hidden_dim = self.parameters_dict['hidden_dim']
        num_layers = self.parameters_dict['num_layers']
        bidirectional = self.parameters_dict['bidirectional']
        use_lstm = self.parameters_dict['use_lstm']
        self.use_cuda_dur = self.use_cuda_key = False
        # Inference only, so dropout is disabled.
        dropout_p = 0.0

        self.encoder_key = Encoder(en_embedding_dim, hidden_dim,
                                   self.en_vocab_key.n_items, num_layers,
                                   dropout_p, bidirectional, use_lstm,
                                   self.use_cuda_key)
        self.decoder_key = Decoder(de_embedding_dim, hidden_dim,
                                   self.de_vocab_key.n_items, num_layers,
                                   dropout_p, bidirectional, use_lstm,
                                   self.use_cuda_key)
        self.encoder_dur = Encoder(en_embedding_dim, hidden_dim,
                                   self.en_vocab_dur.n_items, num_layers,
                                   dropout_p, bidirectional, use_lstm,
                                   self.use_cuda_dur)
        self.decoder_dur = Decoder(de_embedding_dim, hidden_dim,
                                   self.de_vocab_dur.n_items, num_layers,
                                   dropout_p, bidirectional, use_lstm,
                                   self.use_cuda_dur)

        # Weights are always mapped to the CPU, matching use_cuda_* = False.
        self.encoder_key.load_state_dict(
            torch.load('./encoder.10.pt', map_location='cpu'))
        self.decoder_key.load_state_dict(
            torch.load('./decoder.10.pt', map_location='cpu'))
        self.encoder_dur.load_state_dict(
            torch.load('./encoder_dur.10.pt', map_location='cpu'))
        self.decoder_dur.load_state_dict(
            torch.load('./decoder_dur.10.pt', map_location='cpu'))

        self.encoder_key.eval()
        self.decoder_key.eval()
        self.encoder_dur.eval()
        self.decoder_dur.eval()

    @staticmethod
    def _stop_before_eos(li, length):
        """Cut ``li`` at the first ``'_EOS_'`` and pad it to ``length``.

        Padding repeats the last remaining item. A list that is already
        longer than ``length`` is returned unshortened. If nothing remains
        after the cut there is no item to repeat, so an empty list is
        returned (the previous code raised ``IndexError`` in that case).
        The input list is not modified.
        """
        if '_EOS_' in li:
            li = li[:li.index('_EOS_')]
        else:
            li = list(li)
        if li:
            # A negative repeat count yields an empty list, i.e. no padding.
            li.extend([li[-1]] * (length - len(li)))
        return li

    @staticmethod
    def _decode_step(decoder, token_index, encoder_output, decoder_state,
                     use_cuda):
        """Feed one token index to ``decoder``; return ``(outputs, state)``."""
        decoder_inputs = torch.LongTensor([token_index])
        if use_cuda:
            decoder_inputs = decoder_inputs.cuda()
        return decoder(decoder_inputs, encoder_output, decoder_state)

    @classmethod
    def _start_decoding(cls, trf, sent, encoder, decoder, de_vocab, use_cuda):
        """Encode ``sent`` and run the first decoder step from ``'_START_'``.

        Returns ``(encoder_output, decoder_outputs, decoder_state)``.
        """
        en_seq, en_seq_len = trf.trans_input(sent)
        encoder_output, encoder_state = encoder(torch.LongTensor(en_seq),
                                                en_seq_len)

        # initial decoder hidden
        decoder_state = decoder.init_state(encoder_state)

        decoder_outputs, decoder_state = cls._decode_step(
            decoder, de_vocab.item2index['_START_'], encoder_output,
            decoder_state, use_cuda)
        return encoder_output, decoder_outputs, decoder_state

    def get_song(self, lyric):
        """Greedily decode a key and a duration sequence for each lyric line.

        :param lyric: Text with one lyric line per line; words are separated
            by whitespace.
        :return: A list with one dict per line, holding ``'lyrics'`` (the
            stripped line), ``'key'`` and ``'duration'`` (lists of predicted
            items, cut before ``'_EOS_'`` and padded to the word count).
        """
        pred_list = []

        # Pure inference: no autograd graph is needed while decoding.
        with torch.no_grad():
            for en_sent in io.StringIO(lyric):
                sent = en_sent.strip()
                # The word count bounds the number of decoding steps and is
                # the minimum length of the returned sequences.
                length = len(sent.split())

                encoder_output_key, decoder_outputs_key, decoder_state_key = \
                    self._start_decoding(self.trf_key, sent, self.encoder_key,
                                         self.decoder_key, self.de_vocab_key,
                                         self.use_cuda_key)
                encoder_output_dur, decoder_outputs_dur, decoder_state_dur = \
                    self._start_decoding(self.trf_dur, sent, self.encoder_dur,
                                         self.decoder_dur, self.de_vocab_dur,
                                         self.use_cuda_dur)

                pred_sent_key = []
                pred_sent_dur = []
                pred_char_key = pred_char_dur = ''

                # Greedy search: both decoders advance in lock-step and stop
                # as soon as either one emits '_EOS_'.
                while pred_char_key != '_EOS_' and pred_char_dur != '_EOS_':
                    _, v_idx_key = decoder_outputs_key.data.topk(1)
                    pred_char_key = self.de_vocab_key.index2item[v_idx_key.item()]
                    pred_sent_key.append(pred_char_key)

                    _, v_idx_dur = decoder_outputs_dur.data.topk(1)
                    pred_char_dur = self.de_vocab_dur.index2item[v_idx_dur.item()]
                    pred_sent_dur.append(pred_char_dur)

                    # One step beyond the word count is allowed, so the final
                    # prediction can be '_EOS_'.
                    if len(pred_sent_dur) > length:
                        break

                    decoder_outputs_dur, decoder_state_dur = self._decode_step(
                        self.decoder_dur, v_idx_dur.item(), encoder_output_dur,
                        decoder_state_dur, self.use_cuda_dur)
                    decoder_outputs_key, decoder_state_key = self._decode_step(
                        self.decoder_key, v_idx_key.item(), encoder_output_key,
                        decoder_state_key, self.use_cuda_key)

                pred_list.append({
                    'lyrics': sent,
                    'key': self._stop_before_eos(pred_sent_key, length),
                    'duration': self._stop_before_eos(pred_sent_dur, length)
                })

        return pred_list
示例#6
0
    parameters_dict['bidirectional'] = bidirectional
    parameters_dict['use_lstm'] = use_lstm

    with open('parameters_dict.pkl', 'wb') as fd:
        pickle.dump(parameters_dict, fd)
    
    batch_total = sum([1 for _ in pl.gen_pairs(batch_size)])
    ones_matrix = autograd.Variable(torch.ones(1, de_vocab.n_items))
    
    encoder = Encoder(en_embedding_dim, hidden_dim, en_vocab.n_items, num_layers, dropout_p, bidirectional, use_lstm, use_cuda)
    decoder = Decoder(de_embedding_dim, hidden_dim, de_vocab.n_items, num_layers, dropout_p, bidirectional, use_lstm, use_cuda)

    
    encoder_model_file = 'encoder_rev.7.pt'
    decoder_model_file = 'decoder_rev.7.pt'
    encoder.load_state_dict(torch.load(encoder_model_file))
    decoder.load_state_dict(torch.load(decoder_model_file))

    '''
    #Load Pre-trained Embedding
    model_file = 'bi_gru.100.100.2.pt'
    if model_file != '' : model.load_state_dict(torch.load(model_file))
    else: model.load_pre_train_emb('cityu_training.char.emb.npy', 'cityu_training.char.dict', vocab)
    '''
    
    loss_function = nn.NLLLoss(reduction = 'sum', ignore_index = de_vocab.item2index['_PAD_'])
    en_optimizer = optim.Adam(encoder.parameters(), lr = 1e-3, weight_decay = 0)
    de_optimizer = optim.Adam(decoder.parameters(), lr = 1e-3, weight_decay = 0)
    
    if use_cuda:
        encoder = encoder.cuda()