Example #1
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count()-2 if cpu_count()>2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist 
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    # note: the dataset is always built with use_tfidf=False here,
    # even though args.use_tfidf is read above
    docSet = DocDataset(taskname, no_below=no_below, no_above=no_above, rebuild=rebuild, use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname, no_below=no_below, no_above=no_above, rebuild=rebuild, use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = WTM(bow_dim=voc_size, n_topic=n_topic, device=device, dist=dist, taskname=taskname, dropout=0.4)
    model.train(train_data=docSet, batch_size=batch_size, test_data=docSet, num_epochs=num_epochs, log_every=10, beta=1.0)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/WTM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.wae.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    # written with torch.save despite the .pkl extension; reload with torch.load
    torch.save({'txts': txt_lst, 'embeds': embeds}, 'wtm_embeds.pkl')
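
Every example in this collection reads its settings from a module-level `args` object via `global args`, but none of them shows where it comes from. Below is a minimal sketch of an argparse setup that would supply it; the flag names are inferred from the attributes accessed above, while the defaults are illustrative assumptions, not the original CLI:

import argparse

# Hypothetical parser reconstructed from the attributes used in Example #1;
# defaults are placeholders, not the original values.
parser = argparse.ArgumentParser('WTM topic model runner')
parser.add_argument('--taskname', type=str, default='cnews10k')
parser.add_argument('--no_below', type=int, default=5)
parser.add_argument('--no_above', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--n_topic', type=int, default=20)
parser.add_argument('--bkpt_continue', action='store_true')
parser.add_argument('--use_tfidf', action='store_true')
parser.add_argument('--rebuild', action='store_true')
parser.add_argument('--dist', type=str, default='gmm_std')
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--criterion', type=str, default='cross_entropy')
parser.add_argument('--auto_adj', action='store_true')
args = parser.parse_args()

if __name__ == '__main__':
    main()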
Example #2
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj
    show_topics = args.show_topics

    device = torch.device('cpu')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = GSM(bow_dim=voc_size,
                n_topic=n_topic,
                taskname=taskname,
                device=device)
    if bkpt_continue:
        # resume from the first checkpoint found in ./ckpt
        # (os.listdir order is arbitrary, so this assumes a single relevant file)
        path = os.listdir('./ckpt')[0]
        checkpoint = torch.load(os.path.join('./ckpt', path))
        model.vae.load_state_dict(checkpoint)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion=criterion)
    model.evaluate(test_data=docSet)

    if show_topics:
        os.makedirs('./result', exist_ok=True)  # open() below fails if the directory is missing
        with open(f'./result/{taskname}_ep{num_epochs}.txt', 'w') as f:
            for topic in model.show_topic_words():
                print(topic, file=f)

    save_name = f'./ckpt/GSM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
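
Since only the VAE's state_dict is saved above, restoring it later mirrors the bkpt_continue branch. A minimal sketch, assuming the GSM is rebuilt with the same constructor arguments used at training time:

# Hypothetical restore; voc_size, n_topic and taskname must match training.
model = GSM(bow_dim=voc_size, n_topic=n_topic, taskname=taskname, device=device)
state = torch.load(save_name, map_location=device)
model.vae.load_state_dict(state)
model.evaluate(test_data=docSet)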
Example #3
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count()-2 if cpu_count()>2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    emb_dim = args.emb_dim
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname, no_below=no_below, no_above=no_above, rebuild=rebuild, use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname, no_below=no_below, no_above=no_above, rebuild=rebuild, use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = ETM(bow_dim=voc_size, n_topic=n_topic, taskname=taskname, device=device, emb_dim=emb_dim)  #TBD_fc1
    model.train(train_data=docSet, batch_size=batch_size, test_data=docSet, num_epochs=num_epochs, log_every=10, beta=1.0, criterion=criterion)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    # Export topic and word embeddings in word2vec text format, so topics
    # (as pseudo-words 'TP<i>') and words live in one queryable vector space.
    topic_vecs = model.vae.alpha.weight.detach().cpu().numpy()
    word_vecs = model.vae.rho.weight.detach().cpu().numpy()
    print('topic_vecs.shape:', topic_vecs.shape)
    print('word_vecs.shape:', word_vecs.shape)
    vocab = np.array([t[0] for t in sorted(docSet.dictionary.token2id.items(), key=lambda x: x[1])]).reshape(-1, 1)
    topic_ids = np.array([f'TP{i}' for i in range(n_topic)]).reshape(-1, 1)
    word_vecs = np.concatenate([vocab, word_vecs], axis=1)
    topic_vecs = np.concatenate([topic_ids, topic_vecs], axis=1)
    #save_name_tp = f'./ckpt/TpVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    save_name_wd = f'./ckpt/WdVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    n_instances = word_vecs.shape[0] + topic_vecs.shape[0]
    with open(save_name_wd, 'w', encoding='utf-8') as wfp:
        wfp.write(f'{n_instances} {emb_dim}\n')
        wfp.write('\n'.join([' '.join(e) for e in word_vecs] + [' '.join(e) for e in topic_vecs]))
    from gensim.models import KeyedVectors
    w2v = KeyedVectors.load_word2vec_format(save_name_wd, binary=False)
    # save_name starts with './', so save_name.split('.')[0] is the empty string;
    # strip only the extension instead.
    w2v.save(os.path.splitext(save_name)[0] + '.w2v')
    print(w2v.vocab.keys())  # gensim 3.x API
    #w2v.most_similar('你好')
    for i in range(n_topic):
        print(f'Most similar to Topic {i}')
        print(w2v.most_similar(f'TP{i}'))
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_etm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    with open('etm_embeds.pkl', 'wb') as wfp:
        pickle.dump({'txts': txt_lst, 'embeds': embeds}, wfp)
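
The artifacts written above can be reused without retraining; a small sketch of loading them back (the .w2v path is illustrative, and `w2v.vocab` elsewhere in this example assumes the gensim 3.x API):

import pickle
from gensim.models import KeyedVectors

# Query the joint topic/word vector space saved above (path is hypothetical).
w2v = KeyedVectors.load('./ckpt/ETM_example.w2v')
print(w2v.most_similar('TP0', topn=10))  # tokens closest to topic 0

# Reload the per-document topic distributions.
with open('etm_embeds.pkl', 'rb') as fp:
    data = pickle.load(fp)
print(len(data['txts']), 'documents with', len(data['embeds'][0]), 'topic dimensions')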
Example #4
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    model = GMNTM(bow_dim=voc_size,
                  n_topic=n_topic,
                  device=device,
                  taskname=taskname,
                  dropout=0.2)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion='bce_softmax')  # hard-coded; args.criterion is read but not used here
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/GMNTM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vade.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_gmntm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    with open('gmntm_embeds.pkl', 'wb') as wfp:
        pickle.dump({'txts': txt_lst, 'embeds': embeds}, wfp)
Example #5
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=True,
                        lang='en')
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)  # note: the rebuilt dataset drops the TF-IDF weighting and the lang='en' flag used above
    voc_size = docSet.vocabsize

    model = BATM(bow_dim=voc_size,
                 n_topic=n_topic,
                 device=device,
                 taskname=taskname)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                n_critic=10)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/BATM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(
        {
            'generator': model.generator.state_dict(),
            'encoder': model.encoder.state_dict(),
            'discriminator': model.discriminator.state_dict()
        }, save_name)
Example #6
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    model = GMNTM(bow_dim=voc_size,
                  n_topic=n_topic,
                  device=device,
                  taskname=taskname,
                  dropout=0.2)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion='bce_softmax')  # hard-coded; args.criterion is read but not used here
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/GMNTM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vade.state_dict(), save_name)
Example #7
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj
    show_topics = args.show_topics

    device = torch.device('cpu')

    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=True)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)  # note: the rebuilt dataset drops the TF-IDF weighting used above
    voc_size = docSet.vocabsize

    model = BATM(bow_dim=voc_size,
                 n_topic=n_topic,
                 device=device,
                 taskname=taskname)

    if bkpt_continue:
        # resume from the first checkpoint found in ./ckpt
        # (os.listdir order is arbitrary, so this assumes a single relevant file)
        path = os.listdir('./ckpt')[0]
        checkpoint = torch.load(os.path.join('./ckpt', path))
        model.generator.load_state_dict(checkpoint['generator'])
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.discriminator.load_state_dict(checkpoint['discriminator'])

    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                n_critic=10)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/BATM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(
        {
            'generator': model.generator.state_dict(),
            'encoder': model.encoder.state_dict(),
            'discriminator': model.discriminator.state_dict()
        }, save_name)

    if show_topics:
        os.makedirs('./result', exist_ok=True)  # open() below fails if the directory is missing
        with open(f'./result/{taskname}_ep{num_epochs}.txt', 'w') as f:
            for topic in model.show_topic_words():
                print(topic, file=f)
Example #8
def main():
    global args

    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_iters = args.num_iters
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    auto_adj = args.auto_adj

    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    model_name = 'LDA'
    msg = 'bow' if not use_tfidf else 'tfidf'
    run_name = '{}_K{}_{}_{}'.format(model_name, n_topic, taskname, msg)
    if not os.path.exists('logs'):
        os.mkdir('logs')
    if not os.path.exists('ckpt'):
        os.mkdir('ckpt')
    loghandler = [
        logging.FileHandler(filename=f'logs/{run_name}.log', encoding="utf-8")
    ]
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        handlers=loghandler)
    logger = logging.getLogger(__name__)

    if bkpt_continue:
        print('loading model ckpt ...')
        lda_model = gensim.models.ldamodel.LdaModel.load(
            'ckpt/{}.model'.format(run_name))
        # note: the training step below rebuilds lda_model unconditionally,
        # overwriting the checkpoint loaded here

    # Training
    print('Start Training ...')

    if use_tfidf:
        tfidf = TfidfModel(docSet.bows)
        corpus_tfidf = tfidf[docSet.bows]
        #lda_model = LdaMulticore(list(corpus_tfidf),num_topics=n_topic,id2word=docSet.dictionary,alpha='asymmetric',passes=num_iters,workers=n_cpu,minimum_probability=0.0)
        lda_model = LdaModel(list(corpus_tfidf),
                             num_topics=n_topic,
                             id2word=docSet.dictionary,
                             alpha='asymmetric',
                             passes=num_iters)
    else:
        #lda_model = LdaMulticore(list(docSet.bows),num_topics=n_topic,id2word=docSet.dictionary,alpha='asymmetric',passes=num_iters,workers=n_cpu)
        lda_model = LdaModel(list(docSet.bows),
                             num_topics=n_topic,
                             id2word=docSet.dictionary,
                             alpha='asymmetric',
                             passes=num_iters)

    save_name = f'./ckpt/LDA_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    lda_model.save(save_name)

    # Evaluation
    print('Evaluation ...')
    topic_words = get_topic_words(model=lda_model,
                                  n_topic=n_topic,
                                  topn=15,
                                  vocab=docSet.dictionary)

    (cv_score, w2v_score, c_uci_score,
     c_npmi_score), _ = calc_topic_coherence(topic_words,
                                             docs=docSet.docs,
                                             dictionary=docSet.dictionary)

    topic_diversity = calc_topic_diversity(topic_words)

    result_dict = {
        'cv': cv_score,
        'w2v': w2v_score,
        'c_uci': c_uci_score,
        'c_npmi': c_npmi_score
    }
    logger.info('Topics:')

    for idx, words in enumerate(topic_words):
        logger.info(f'##{idx:>3d}:{words}')
        print(f'##{idx:>3d}:{words}')

    for measure, score in result_dict.items():
        logger.info(f'{measure} score: {score}')
        print(f'{measure} score: {score}')

    logger.info(f'topic diversity: {topic_diversity}')
    print(f'topic diversity: {topic_diversity}')
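
A trained LdaModel can also score unseen documents directly through gensim's standard API; a minimal sketch (the example tokens are placeholders, and real input must be tokenized the same way DocDataset tokenized the training corpus):

# Hypothetical inference on a new document.
new_doc_tokens = ['market', 'economy', 'growth']  # illustrative tokens
bow = docSet.dictionary.doc2bow(new_doc_tokens)
for topic_id, prob in lda_model.get_document_topics(bow, minimum_probability=0.01):
    print(f'topic {topic_id}: {prob:.3f}')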
Example #9
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    use_fc1 = args.use_fc1  #TBD_fc1
    emb_dim = args.emb_dim
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = ETM(bow_dim=voc_size,
                n_topic=n_topic,
                taskname=taskname,
                device=device,
                use_fc1=use_fc1,
                emb_dim=emb_dim)  #TBD_fc1
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion=criterion)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    topic_vecs = model.vae.alpha.weight.detach().cpu().numpy()
    word_vecs = model.vae.rho.weight.detach().cpu().numpy()
    print('topic_vecs.shape:', topic_vecs.shape)
    print('word_vecs.shape:', word_vecs.shape)
    vocab = np.array([
        t[0] for t in sorted(list(docSet.dictionary.token2id.items()),
                             key=lambda x: x[1])
    ]).reshape(-1, 1)
    topic_ids = np.array([f'TP{i}' for i in range(n_topic)]).reshape(-1, 1)
    word_vecs = np.concatenate([vocab, word_vecs], axis=1)
    topic_vecs = np.concatenate([topic_ids, topic_vecs], axis=1)
    save_name_tp = f'./ckpt/TpVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    save_name_wd = f'./ckpt/WdVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    torch.save(topic_vecs, save_name_tp)
    torch.save(word_vecs, save_name_wd)
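
Unlike Example #3, which exports a word2vec text file, these .emb files are numpy arrays pickled by torch.save, so they come back via torch.load; a quick sketch, continuing from the variables above:

# These are plain numpy pickles, not word2vec text files.
# (On torch >= 2.6 pass weights_only=False, since they are not tensor state_dicts.)
word_vecs = torch.load(save_name_wd)
print(word_vecs.shape)  # (vocab_size, 1 + emb_dim); column 0 holds the token string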
Example #10
def main():
    global args
    taskname = args.taskname  # dataset name
    no_below = args.no_below  # words in fewer than this many documents are filtered out
    no_above = args.no_above  # words in more than this fraction of documents are filtered out
    num_epochs = args.num_epochs  # number of training epochs
    n_topic = args.n_topic  # number of topics
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue  # whether to resume training from an earlier checkpoint
    use_tfidf = args.use_tfidf  # whether to use TF-IDF as the BOW input
    rebuild = args.rebuild  # whether to rebuild the corpus (off by default)
    batch_size = args.batch_size  # batch size
    criterion = args.criterion  # type of loss
    auto_adj = args.auto_adj  # whether to auto-adjust the frequency cutoff, e.g. drop the top-20 most frequent words
    ckpt = args.ckpt  # checkpoint path

    device = torch.device('cpu')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)  # load the dataset and tokenize it
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    if ckpt:  # resume from a checkpoint
        checkpoint = torch.load(ckpt)
        # assumption: the checkpoint stores the model's constructor arguments
        # under 'param'; the original snippet used `param` without defining it
        param = checkpoint['param']
        param.update({"device": device})
        model = GSM(**param)
        model.train(train_data=docSet,
                    batch_size=batch_size,
                    test_data=docSet,
                    num_epochs=num_epochs,
                    log_every=10,
                    beta=1.0,
                    criterion=criterion,
                    ckpt=checkpoint)
    else:
        # initialize the model and start training
        model = GSM(bow_dim=voc_size,
                    n_topic=n_topic,
                    taskname=taskname,
                    device=device)
        model.train(train_data=docSet,
                    batch_size=batch_size,
                    test_data=docSet,
                    num_epochs=num_epochs,
                    log_every=10,
                    beta=1.0,
                    criterion=criterion)
    model.evaluate(test_data=docSet)  # evaluate the trained model
    # save the model, features, statistics, and other outputs
    save_name = f'./ckpt/GSM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_gsm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    with open('gsm_embeds.pkl', 'wb') as wfp:
        pickle.dump({'txts': txt_lst, 'embeds': embeds}, wfp)
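
With the pickled embeddings on disk, finding documents with similar topic mixtures needs only a few lines of numpy; a follow-up sketch (not part of the original script):

import pickle
import numpy as np

with open('gsm_embeds.pkl', 'rb') as fp:
    data = pickle.load(fp)
txts = data['txts']
embeds = np.asarray(data['embeds'], dtype=float)

# Cosine similarity of every document's topic distribution against the first one.
query = embeds[0]
sims = embeds @ query / (np.linalg.norm(embeds, axis=1) * np.linalg.norm(query) + 1e-12)
for idx in np.argsort(-sims)[:5]:
    print(f'{sims[idx]:.3f}  {txts[idx]}')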