Example 1
def train_tvae(seq_file, train_list, val_list,
               decoder_file, qy_file, qz_file,
               epochs, batch_size,
               preproc_file, output_path,
               num_samples_y, num_samples_z,
               px_form, qy_form, qz_form,
               min_kl, **kwargs):

    set_float_cpu(float_keras())
    
    sr_args = SR.filter_args(**kwargs)
    sr_val_args = SR.filter_val_args(**kwargs)
    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)
    
    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)
    else:
        preproc = None

    sr = SR(seq_file, train_list, batch_size=batch_size,
            preproc=preproc, **sr_args)
    max_length = sr.max_batch_seq_length
    gen_val = None
    if val_list is not None:
        sr_val = SR(seq_file, val_list, batch_size=batch_size,
                    preproc=preproc,
                    shuffle_seqs=False,
                    seq_split_mode='sequential', seq_split_overlap=0,
                    reset_rng=True,
                    **sr_val_args)
        max_length = max(max_length, sr_val.max_batch_seq_length)
        gen_val = data_generator(sr_val, max_length)

    gen_train = data_generator(sr, max_length)

    t1 = time.time()
    decoder = load_model_arch(decoder_file)
    qy = load_model_arch(qy_file)

    # Build the VAE: use only the q(y) encoder if no q(z) architecture is given.
    if qz_file is None:
        vae = TVAEY(qy, decoder, px_cond_form=px_form,
                    qy_form=qy_form, min_kl=min_kl)
        vae.build(num_samples=num_samples_y,
                  max_seq_length=max_length)
    else:
        qz = load_model_arch(qz_file)
        vae = TVAEYZ(qy, qz, decoder, px_cond_form=px_form,
                     qy_form=qy_form, qz_form=qz_form, min_kl=min_kl)
        vae.build(num_samples_y=num_samples_y, num_samples_z=num_samples_z,
                  max_seq_length=max_length)
    logging.info('Model build elapsed time: %.2f', time.time() - t1)
    
    cb = KCF.create_callbacks(vae, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)

    h = vae.fit_generator(gen_train, x_val=gen_val,
                          steps_per_epoch=sr.num_batches,
                          validation_steps=(sr_val.num_batches
                                            if val_list is not None else None),
                          optimizer=opt, epochs=epochs,
                          callbacks=cb, max_q_size=10)

    # if vae.x_chol is not None:
    #     x_chol = np.array(K.eval(vae.x_chol))
    #     logging.info(x_chol[:4,:4])

    logging.info('Train elapsed time: %.2f' % (time.time() - t1))
    
    vae.save(output_path + '/model')
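
Both examples call a data_generator helper that is not part of the listing (Example 2 passes the GMM as an extra argument). The sketch below shows one plausible form for the two-argument version used in Example 1; the sr.read() method, its assumed return value (a list of 2-D frame-by-feature arrays), and the zero-padding logic are illustrative assumptions, not the original implementation.

import numpy as np

def data_generator(sr, max_length):
    # Yield (input, target) pairs forever, zero-padding every sequence in the
    # batch to max_length frames so that batches have a fixed tensor shape.
    while True:
        x = sr.read()  # assumed: list of arrays, each (num_frames, feat_dim)
        x_pad = np.zeros((len(x), max_length, x[0].shape[-1]), dtype='float32')
        for i, x_i in enumerate(x):
            x_pad[i, :x_i.shape[0]] = x_i
        # Autoencoder-style training: the model reconstructs its own input.
        yield x_pad, x_pad
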
Example 2
def train_tvae(seq_file, train_list, val_list, gmm_file, decoder_file, qy_file,
               qz_file, init_path, epochs, batch_size, preproc_file,
               output_path, num_samples_y, num_samples_z, px_form, qy_form,
               qz_form, min_kl, **kwargs):

    set_float_cpu(float_keras())

    sr_args = SR.filter_args(**kwargs)
    sr_val_args = SR.filter_val_args(**kwargs)
    opt_args = KOF.filter_args(**kwargs)
    cb_args = KCF.filter_args(**kwargs)

    if preproc_file is not None:
        preproc = TransformList.load(preproc_file)
    else:
        preproc = None

    gmm = DiagGMM.load_from_kaldi(gmm_file)

    sr = SR(seq_file,
            train_list,
            batch_size=batch_size,
            preproc=preproc,
            **sr_args)
    max_length = sr.max_batch_seq_length
    gen_val = None
    if val_list is not None:
        sr_val = SR(seq_file,
                    val_list,
                    batch_size=batch_size,
                    preproc=preproc,
                    shuffle_seqs=False,
                    seq_split_mode='sequential',
                    seq_split_overlap=0,
                    reset_rng=True,
                    **sr_val_args)
        max_length = max(max_length, sr_val.max_batch_seq_length)
        gen_val = data_generator(sr_val, gmm, max_length)

    gen_train = data_generator(sr, gmm, max_length)

    t1 = time.time()

    if init_path is None:
        decoder = load_model_arch(decoder_file)
        qy = load_model_arch(qy_file)

        # if qz_file is None:
        #     vae = TVAEY(qy, decoder, px_cond_form=px_form,
        #                 qy_form=qy_form, min_kl=min_kl)
        #     vae.build(num_samples=num_samples_y,
        #               max_seq_length = max_length)
        # else:
        qz = load_model_arch(qz_file)
        vae = TVAEYZ(qy,
                     qz,
                     decoder,
                     px_cond_form=px_form,
                     qy_form=qy_form,
                     qz_form=qz_form,
                     min_kl=min_kl)
    else:
        vae = TVAEYZ.load(init_path)

    vae.build(num_samples_y=num_samples_y,
              num_samples_z=num_samples_z,
              max_seq_length=max_length)
    logging.info('Model build elapsed time: %.2f', time.time() - t1)

    cb = KCF.create_callbacks(vae, output_path, **cb_args)
    opt = KOF.create_optimizer(**opt_args)

    h = vae.fit_generator(gen_train,
                          x_val=gen_val,
                          steps_per_epoch=sr.num_batches,
                          validation_steps=(sr_val.num_batches
                                            if val_list is not None else None),
                          optimizer=opt,
                          epochs=epochs,
                          callbacks=cb,
                          max_queue_size=10)

    # if vae.x_chol is not None:
    #     x_chol = np.array(K.eval(vae.x_chol))
    #     logging.info(x_chol[:4,:4])

    logging.info('Train elapsed time: %.2f' % (time.time() - t1))

    vae.save(output_path + '/model')
    # Diagnostics on the validation latents (these assume val_list was given):
    # eigenvalues of the sample covariance of the y posterior estimates.
    sr_val.reset()
    y_val, sy_val, z_val, srz_val = vae.encoder_net.predict_generator(
        gen_val, steps=400)

    from scipy import linalg as la
    yy = y_val - np.mean(y_val, axis=0)
    cy = np.dot(yy.T, yy) / yy.shape[0]
    l, v = la.eigh(cy)
    np.savetxt(output_path + '/l1.txt', l)

    # Same diagnostic for the qy network alone, for comparison with the full encoder.
    sr_val.reset()
    y_val2, sy_val2 = vae.qy_net.predict_generator(gen_val, steps=400)
    yy = y_val2 - np.mean(y_val2, axis=0)
    cy = np.dot(yy.T, yy) / yy.shape[0]
    l, v = la.eigh(cy)
    np.savetxt(output_path + '/l2.txt', l)

    # Log the element-wise difference between the two y estimates as a sanity check.
    logging.info(y_val - y_val2)
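
The covariance-eigenvalue diagnostic at the end of Example 2 is written out twice, once for each encoder output. A small helper along the following lines (hypothetical, not part of the original script) captures the repeated computation:

import numpy as np
from scipy import linalg as la

def save_cov_eigenvalues(y, out_file):
    # Save the eigenvalues of the sample covariance of y (num_samples x dim),
    # in ascending order, as a text file.
    yy = y - np.mean(y, axis=0)
    cy = np.dot(yy.T, yy) / yy.shape[0]
    l, _ = la.eigh(cy)
    np.savetxt(out_file, l)

# Usage mirroring the original diagnostics:
# save_cov_eigenvalues(y_val, output_path + '/l1.txt')
# save_cov_eigenvalues(y_val2, output_path + '/l2.txt')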