Пример #1
0
def _trainable_params(module):
    """Return a generator over the parameters of *module* whose ``requires_grad`` is set."""
    return (p for p in module.parameters() if p.requires_grad)


def _make_optimizer(cfg, module, use_sgd=False, **kwargs):
    """Create an optimizer over the trainable parameters of *module*.

    When *use_sgd* is true, plain SGD with ``cfg.actor_step_size`` is used (the
    AC-RNN setting); otherwise Adam with ``cfg.learning_rate``. Extra keyword
    arguments (e.g. ``weight_decay``) are forwarded to the optimizer constructor.
    """
    if use_sgd:
        return optim.SGD(_trainable_params(module),
                         lr=cfg.actor_step_size, **kwargs)
    return optim.Adam(_trainable_params(module),
                      lr=cfg.learning_rate, **kwargs)


def _checkpoint_modules(model_type):
    """Return the ``(file_suffix, module)`` pairs persisted for *model_type*.

    Reads the module-level globals assigned by ``run_model`` at call time and
    lists them in the order they are saved and loaded: feature, encoder, then
    the predictor (and, for AC-RNN, the critic).
    """
    modules = [('_feature', feature), ('_encoder', encoder)]
    if model_type == 'INDP':
        modules.append(('_predictor', indp))
    elif model_type == 'CRF':
        modules.append(('_predictor', crf))
    elif model_type in ('TF-RNN', 'SS-RNN', 'AC-RNN'):
        modules.append(('_predictor', mldecoder))
        if model_type == 'AC-RNN':
            modules.append(('_critic', rltrain))
    return modules


def _run_training(cfg, path):
    """Train for at most ``cfg.max_epochs`` epochs with early stopping.

    After every epoch the dev set is predicted and scored. Whenever the score
    improves, all modules are checkpointed to ``path + model_type + suffix``.
    Training stops once no improvement has been seen for more than
    ``cfg.early_stopping`` epochs.
    """
    o_file = './temp.predicted_' + cfg.model_type
    best_val_cost = float('inf')
    best_val_epoch = 0
    first_start = time.time()
    epoch = 0
    while epoch < cfg.max_epochs:
        print('')
        print('Model:{} | Epoch:{}'.format(cfg.model_type, epoch))

        if cfg.model_type == 'SS-RNN':
            # Decaying schedule for the sampling probability
            # (inverse sigmoid): k / (k + exp(epoch / k)).
            cfg.sampling_p = float(
                cfg.k) / float(cfg.k + np.exp(float(epoch) / cfg.k))

        start = time.time()
        run_epoch(cfg)
        print('\nValidation:')
        predict(cfg, o_file)
        # evaluate() returns a score where higher is better; turn it into a cost.
        val_cost = 100 - evaluate(cfg, cfg.dev_ref, o_file)
        print('Validation score:{}'.format(100 - val_cost))
        if val_cost < best_val_cost:
            best_val_cost = val_cost
            best_val_epoch = epoch
            for suffix, module in _checkpoint_modules(cfg.model_type):
                torch.save(module.state_dict(),
                           path + cfg.model_type + suffix)

        # Early stopping.
        if epoch - best_val_epoch > cfg.early_stopping:
            break

        print('Epoch training time:{} seconds'.format(time.time() - start))
        epoch += 1

    print('Total training time:{} seconds'.format(time.time() - first_start))


def _run_prediction(cfg, path, o_file):
    """Load the ``cfg.model_type`` checkpoints from *path* and write predictions to *o_file*."""
    cfg.batch_size = 256
    for suffix, module in _checkpoint_modules(cfg.model_type):
        module.load_state_dict(torch.load(path + cfg.model_type + suffix))

    print('')
    print('Model:{} Predicting'.format(cfg.model_type))
    start = time.time()
    predict(cfg, o_file)
    print('Total prediction time:{} seconds'.format(time.time() - start))


def run_model(mode, path, in_file, o_file):
    """Build the model selected by ``cfg.model_type``, then train it or predict with it.

    The constructed modules and their optimizers are published as module-level
    globals (``feature``, ``encoder``, ``indp``/``crf``/``mldecoder``/``rltrain``
    and the matching ``*_opt``), presumably so that ``run_epoch``/``predict``
    can reach them -- confirm against those helpers.

    Args:
        mode: ``'train'`` or ``'test'``.
        path: prefix for checkpoint files. Each file is
            ``path + model_type + suffix`` with suffix ``_feature``,
            ``_encoder``, ``_predictor`` or ``_critic``.
        in_file: raw input stored as ``cfg.test_raw``; used only in test mode.
        o_file: prediction output file in test mode. Ignored in train mode,
            where validation predictions go to ``./temp.predicted_<model_type>``.
    """
    global feature, encoder, indp, crf, mldecoder, rltrain, f_opt, e_opt, i_opt, c_opt, m_opt, r_opt

    cfg = Configuration()

    # General mode has two values: 'train' or 'test'.
    cfg.mode = mode

    # Fix every RNG seed for reproducibility.
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    if hasCuda:
        torch.cuda.manual_seed_all(cfg.seed)

    load_embeddings(cfg)

    if mode == 'test':
        cfg.test_raw = in_file

    model_type = cfg.model_type
    is_actor_critic = model_type == 'AC-RNN'

    # Shared front end: the actor-critic model is optimized with SGD, the rest with Adam.
    feature = Feature(cfg)
    f_opt = _make_optimizer(cfg, feature, use_sgd=is_actor_critic)
    if hasCuda:
        feature.cuda()

    encoder = Encoder(cfg)
    e_opt = _make_optimizer(cfg, encoder, use_sgd=is_actor_critic)
    if hasCuda:
        encoder.cuda()

    # Model-specific predictor head.
    if model_type == 'INDP':
        indp = INDP(cfg)
        i_opt = _make_optimizer(cfg, indp)
        if hasCuda:
            indp.cuda()

    elif model_type == 'CRF':
        crf = CRF(cfg)
        c_opt = _make_optimizer(cfg, crf)
        if hasCuda:
            crf.cuda()

    elif model_type in ('TF-RNN', 'SS-RNN', 'AC-RNN'):
        mldecoder = MLDecoder(cfg)
        m_opt = _make_optimizer(cfg, mldecoder, use_sgd=is_actor_critic)
        if hasCuda:
            mldecoder.cuda()
        # SS-RNN decodes with scheduled sampling; TF-RNN and AC-RNN use the
        # teacher-forced decoder.
        cfg.mldecoder_type = 'SS' if model_type == 'SS-RNN' else 'TF'

        if is_actor_critic:
            rltrain = RLTrain(cfg)
            r_opt = _make_optimizer(cfg, rltrain, weight_decay=0.001)
            if hasCuda:
                rltrain.cuda()
            cfg.rltrain_type = 'AC'
            if mode == 'train':
                # For RL, the network is pre-trained with the teacher-forced
                # ML decoder. Only needed for training: in test mode every one
                # of these modules is overwritten by the AC-RNN checkpoint, so
                # the TF-RNN files are not required there.
                feature.load_state_dict(torch.load(path + 'TF-RNN' + '_feature'))
                encoder.load_state_dict(torch.load(path + 'TF-RNN' + '_encoder'))
                mldecoder.load_state_dict(torch.load(path + 'TF-RNN' + '_predictor'))

    if mode == 'train':
        _run_training(cfg, path)
    elif mode == 'test':
        _run_prediction(cfg, path, o_file)
    return
Пример #2
0
def train(args):
    """Train the encoder / attention-decoder / post-net stack on ``tiny_words``.

    Each loop iteration draws a single batch of ``hp.batch_size`` examples, so
    ``hp.n_epochs`` counts batches (steps), not full passes over the dataset.
    The average loss is printed every 100 steps, and a checkpoint is written
    through ``save_checkpoint`` every 1000 steps.

    Args:
        args: parsed command-line options. Reads ``args.data_size`` (passed as
            ``max_dataset_size``) and ``args.multi_gpus`` (wrap every network
            in ``nn.DataParallel`` across all visible GPUs).
    """
    # Initialize dataset.
    with Timed('Loading dataset'):
        ds = tiny_words(max_text_length=hp.max_text_length,
                        max_audio_length=hp.max_audio_length,
                        max_dataset_size=args.data_size)

    # Initialize model.
    with Timed('Initializing model.'):
        encoder = Encoder(ds.lang.num_chars,
                          hp.embedding_dim,
                          hp.encoder_bank_k,
                          hp.encoder_bank_ck,
                          hp.encoder_proj_dims,
                          hp.encoder_highway_layers,
                          hp.encoder_highway_units,
                          hp.encoder_gru_units,
                          dropout=hp.dropout,
                          use_cuda=hp.use_cuda)

        decoder = AttnDecoder(hp.max_text_length,
                              hp.attn_gru_hidden_size,
                              hp.n_mels,
                              hp.rf,
                              hp.decoder_gru_hidden_size,
                              hp.decoder_gru_layers,
                              dropout=hp.dropout,
                              use_cuda=hp.use_cuda)

        postnet = PostNet(hp.n_mels,
                          1 + hp.n_fft // 2,
                          hp.post_bank_k,
                          hp.post_bank_ck,
                          hp.post_proj_dims,
                          hp.post_highway_layers,
                          hp.post_highway_units,
                          hp.post_gru_units,
                          use_cuda=hp.use_cuda)

        if args.multi_gpus:
            all_devices = list(range(torch.cuda.device_count()))
            encoder = nn.DataParallel(encoder, device_ids=all_devices)
            decoder = nn.DataParallel(decoder, device_ids=all_devices)
            postnet = nn.DataParallel(postnet, device_ids=all_devices)

        if hp.use_cuda:
            encoder.cuda()
            decoder.cuda()
            postnet.cuda()

        # A single optimizer and criterion shared by all three networks.
        all_parameters = (list(encoder.parameters()) +
                          list(decoder.parameters()) +
                          list(postnet.parameters()))
        optimizer = optim.Adam(all_parameters, lr=hp.lr)
        criterion = nn.L1Loss()

        # Logging / checkpoint cadence, in steps.
        print_every = 100
        save_every = 1000

        # Keep track of elapsed time and the running loss.
        start = time.time()
        print_loss_total = 0  # reset every print_every steps

    for epoch in range(1, hp.n_epochs + 1):

        # Get the training data for this step.
        mels, mags, indexed_texts = ds.next_batch(hp.batch_size)

        mels_v = Variable(torch.from_numpy(mels).float())
        mags_v = Variable(torch.from_numpy(mags).float())
        texts_v = Variable(torch.from_numpy(indexed_texts))

        if hp.use_cuda:
            mels_v = mels_v.cuda()
            mags_v = mags_v.cuda()
            texts_v = texts_v.cuda()

        loss = train_batch(mels_v,
                           mags_v,
                           texts_v,
                           encoder,
                           decoder,
                           postnet,
                           optimizer,
                           criterion,
                           multi_gpus=args.multi_gpus)

        print_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            # float() so the progress fraction is not truncated to 0 by
            # integer division under Python 2; it is passed to time_since and
            # also shown as a percentage.
            progress = float(epoch) / hp.n_epochs
            print_summary = '%s (%d %d%%) %.4f' % \
                (time_since(start, progress),
                 epoch, progress * 100, print_loss_avg)
            print(print_summary)

        if epoch % save_every == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'postnet': postnet.state_dict(),
                'optimizer': optimizer.state_dict(),
            })
Пример #3
0
def _strip_data_parallel_prefix(state_dict):
    """Return *state_dict* with the ``'module.'`` key prefix removed, if present.

    ``nn.DataParallel`` registers the wrapped network as a submodule named
    ``module``, so state dicts saved from a wrapped model have every key
    prefixed with ``'module.'``. Such dicts do not load into an unwrapped
    model. The prefix is stripped only when *every* key carries it; otherwise
    the dict is returned unchanged.
    """
    from collections import OrderedDict

    prefix = 'module.'
    if not state_dict or not all(k.startswith(prefix) for k in state_dict):
        return state_dict
    return OrderedDict((k[len(prefix):], v) for k, v in state_dict.items())


def inference(checkpoint_file, text, data_size=None, out_path="demo.wav"):
    """Synthesize *text* with a trained checkpoint and write the audio to a WAV file.

    Args:
        checkpoint_file: path to a checkpoint dict holding ``'encoder'``,
            ``'decoder'`` and ``'postnet'`` state dicts. It may have been saved
            on GPU and/or from ``nn.DataParallel``-wrapped networks.
        text: input sentence to synthesize.
        data_size: ``max_dataset_size`` used to rebuild the dataset, whose
            ``lang`` supplies the character indexes. Presumably this must match
            the value used for training -- confirm. Defaults to the module-level
            ``args.data_size``, which the original code relied on.
        out_path: destination WAV file (default ``"demo.wav"``).
    """
    if data_size is None:
        # Backward-compatible fallback to the global the original code used.
        data_size = args.data_size

    ds = tiny_words(max_text_length=hp.max_text_length,
                    max_audio_length=hp.max_audio_length,
                    max_dataset_size=data_size)

    print(ds.texts)

    # Prepare input: index the characters, append end-of-text, pad, add batch dim.
    indexes = indexes_from_text(ds.lang, text)
    indexes.append(EOT_token)
    padded_indexes = pad_indexes(indexes, hp.max_text_length, PAD_token)
    texts_v = Variable(torch.from_numpy(padded_indexes))
    texts_v = texts_v.unsqueeze(0)

    if hp.use_cuda:
        texts_v = texts_v.cuda()

    encoder = Encoder(ds.lang.num_chars,
                      hp.embedding_dim,
                      hp.encoder_bank_k,
                      hp.encoder_bank_ck,
                      hp.encoder_proj_dims,
                      hp.encoder_highway_layers,
                      hp.encoder_highway_units,
                      hp.encoder_gru_units,
                      dropout=hp.dropout,
                      use_cuda=hp.use_cuda)

    decoder = AttnDecoder(hp.max_text_length,
                          hp.attn_gru_hidden_size,
                          hp.n_mels,
                          hp.rf,
                          hp.decoder_gru_hidden_size,
                          hp.decoder_gru_layers,
                          dropout=hp.dropout,
                          use_cuda=hp.use_cuda)

    postnet = PostNet(hp.n_mels,
                      1 + hp.n_fft // 2,
                      hp.post_bank_k,
                      hp.post_bank_ck,
                      hp.post_proj_dims,
                      hp.post_highway_layers,
                      hp.post_highway_units,
                      hp.post_gru_units,
                      use_cuda=hp.use_cuda)

    encoder.eval()
    decoder.eval()
    postnet.eval()

    if hp.use_cuda:
        encoder.cuda()
        decoder.cuda()
        postnet.cuda()

    # Load the model. Mapping storages to CPU lets GPU-saved checkpoints load on
    # CPU-only machines; load_state_dict then copies them onto each model's device.
    checkpoint = torch.load(checkpoint_file,
                            map_location=lambda storage, loc: storage)
    encoder.load_state_dict(_strip_data_parallel_prefix(checkpoint['encoder']))
    decoder.load_state_dict(_strip_data_parallel_prefix(checkpoint['decoder']))
    postnet.load_state_dict(_strip_data_parallel_prefix(checkpoint['postnet']))

    encoder_out = encoder(texts_v)

    # Decoding starts from an all-zero "GO" frame.
    GO_frame = np.zeros((1, hp.n_mels))
    decoder_in = Variable(torch.from_numpy(GO_frame).float())
    if hp.use_cuda:
        decoder_in = decoder_in.cuda()
    h, hs = decoder.init_hiddens(1)

    decoder_outs = []
    for _ in range(int(hp.max_audio_length / hp.rf)):
        decoder_out, h, hs, _attn = decoder(decoder_in, h, hs, encoder_out)
        decoder_outs.append(decoder_out)
        # Autoregressive: feed the last predicted frame back as the next input.
        decoder_in = decoder_out[:, -1, :].contiguous()

    # Concatenate the per-step outputs along the time axis: (batch_size, T, n_mels).
    decoder_outs = torch.cat(decoder_outs, 1)

    # Post-net output for the single batch item, as a numpy array.
    post_out = postnet(decoder_outs)
    s = post_out[0].cpu().data.numpy()

    print("Recontructing wav...")
    # Clamp negative values to 0 before raising to hp.power.
    s = np.where(s < 0, 0, s)
    wav = spectrogram2wav(s**hp.power)
    write(out_path, hp.sr, wav)