Example 1
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count()-2 if cpu_count()>2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist 
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,no_below=no_below,no_above=no_above,rebuild=rebuild,use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,no_below=no_below,no_above=no_above,rebuild=rebuild,use_tfidf=False)
    
    voc_size = docSet.vocabsize
    print('voc size:',voc_size)
    model = WTM(bow_dim=voc_size,n_topic=n_topic,device=device,dist=dist,taskname=taskname,dropout=0.4)
    model.train(train_data=docSet,batch_size=batch_size,test_data=docSet,num_epochs=num_epochs,log_every=10,beta=1.0)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/WTM_{taskname}_tp{n_topic}_{dist}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.wae.state_dict(),save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    torch.save({'txts':txt_lst,'embeds':embeds},'wtm_embeds.pkl')
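Every main() in these examples reads a module-level args object via "global args". A minimal argparse sketch covering the flags the snippets share follows; the flag names are inferred from the attribute accesses, the defaults are assumptions, and individual runners add extras such as show_topics, emb_dim, ckpt or num_iters.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--taskname', type=str, default='cnews10k')
parser.add_argument('--no_below', type=int, default=5)
parser.add_argument('--no_above', type=float, default=0.3)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--n_topic', type=int, default=20)
parser.add_argument('--bkpt_continue', action='store_true')
parser.add_argument('--use_tfidf', action='store_true')
parser.add_argument('--rebuild', action='store_true')
parser.add_argument('--dist', type=str, default='gmm_std')
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--criterion', type=str, default='cross_entropy')
parser.add_argument('--auto_adj', action='store_true')
args = parser.parse_args()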
Example 2
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj
    show_topics = args.show_topics

    device = torch.device('cpu')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = GSM(bow_dim=voc_size,
                n_topic=n_topic,
                taskname=taskname,
                device=device)
    if bkpt_continue:
        path = os.listdir('./ckpt')[0]
        checkpoint = torch.load(os.path.join('./ckpt', path))
        model.vae.load_state_dict(checkpoint)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion=criterion)
    model.evaluate(test_data=docSet)

    if show_topics:
        with open(f'./result/{taskname}_ep{num_epochs}.txt', 'w') as f:
            for topic in model.show_topic_words():
                print(topic, file=f)

    save_name = f'./ckpt/GSM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
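A note on the bkpt_continue branch above: os.listdir('./ckpt')[0] resumes from whatever entry the filesystem returns first. A sketch of a more deterministic drop-in for that branch body, resuming from the newest checkpoint instead (it assumes the torch import and the model object from the snippet):

import glob
import os

ckpts = sorted(glob.glob('./ckpt/*.ckpt'), key=os.path.getmtime)
if ckpts:
    checkpoint = torch.load(ckpts[-1])  # newest checkpoint by modification time
    model.vae.load_state_dict(checkpoint)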
Example 3
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count()-2 if cpu_count()>2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    emb_dim = args.emb_dim
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,no_below=no_below,no_above=no_above,rebuild=rebuild,use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,no_below=no_below,no_above=no_above,rebuild=rebuild,use_tfidf=False)
    
    voc_size = docSet.vocabsize
    print('voc size:',voc_size)
    model = ETM(bow_dim=voc_size,n_topic=n_topic,taskname=taskname,device=device,emb_dim=emb_dim) #TBD_fc1
    model.train(train_data=docSet,batch_size=batch_size,test_data=docSet,num_epochs=num_epochs,log_every=10,beta=1.0,criterion=criterion)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(),save_name)
    topic_vecs = model.vae.alpha.weight.detach().cpu().numpy()
    word_vecs = model.vae.rho.weight.detach().cpu().numpy()
    print('topic_vecs.shape:',topic_vecs.shape)
    print('word_vecs.shape:',word_vecs.shape)
    vocab = np.array([t[0] for t in sorted(list(docSet.dictionary.token2id.items()),key=lambda x: x[1])]).reshape(-1,1)
    topic_ids = np.array([f'TP{i}' for i in range(n_topic)]).reshape(-1,1)
    word_vecs = np.concatenate([vocab,word_vecs],axis=1)
    topic_vecs = np.concatenate([topic_ids,topic_vecs],axis=1)
    #save_name_tp = f'./ckpt/TpVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    save_name_wd = f'./ckpt/WdVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    n_instances = word_vecs.shape[0]+topic_vecs.shape[0]
    with open(save_name_wd,'w',encoding='utf-8') as wfp:
        wfp.write(f'{n_instances} {emb_dim}\n')
        wfp.write('\n'.join([' '.join(e) for e in word_vecs]+[' '.join(e) for e in topic_vecs]))
    from gensim.models import KeyedVectors
    w2v = KeyedVectors.load_word2vec_format(save_name_wd,binary=False)
    w2v.save(save_name.rsplit('.', 1)[0] + '.w2v')  # rsplit, because save_name starts with './' and split('.')[0] would be empty
    print(w2v.vocab.keys())  # gensim >= 4 renamed this: use w2v.index_to_key
    #w2v.most_similar('你好')
    for i in range(n_topic):
        print(f'Most similar to Topic {i}')
        print(w2v.most_similar(f'TP{i}'))
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_etm.txt','w',encoding='utf-8') as wfp:
        for t,e in zip(txt_lst,embeds):
            wfp.write(f'{e}:{t}\n')
    pickle.dump({'txts':txt_lst,'embeds':embeds},open('etm_embeds.pkl','wb'))
Example 4
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    model = GMNTM(bow_dim=voc_size,
                  n_topic=n_topic,
                  device=device,
                  taskname=taskname,
                  dropout=0.2)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion='bce_softmax')
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/GMNTM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vade.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_gmntm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    pickle.dump({
        'txts': txt_lst,
        'embeds': embeds
    }, open('gmntm_embeds.pkl', 'wb'))
Example 5
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=True,
                        lang='en')
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)
    voc_size = docSet.vocabsize

    model = BATM(bow_dim=voc_size,
                 n_topic=n_topic,
                 device=device,
                 taskname=taskname)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                n_critic=10)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/BATM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(
        {
            'generator': model.generator.state_dict(),
            'encoder': model.encoder.state_dict(),
            'discriminator': model.discriminator.state_dict()
        }, save_name)
Example 6
def main():
    print("loading model")
    model_t = keras.models.load_model('model/model_t.h5', compile=False)
    model_r = keras.models.load_model('model/model_r.h5', compile=False)

    ds = DocDataset()

    x_train_snicor, _, _ = ds.load_train_data()
    x_test_snicer, x_test_boot = ds.load_test_data()

    train = model_t.predict(x_train_snicor)
    test_s = model_t.predict(x_test_snicer)
    test_b = model_t.predict(x_test_boot)

    train = train.reshape((len(x_train_snicor), -1))
    test_s = test_s.reshape((len(x_test_snicer), -1))
    test_b = test_b.reshape((len(x_test_boot), -1))

    # scale to [0, 1]
    ms = MinMaxScaler()
    train = ms.fit_transform(train)
    test_s = ms.transform(test_s)
    test_b = ms.transform(test_b)

    # fit the model (novelty=True lets LOF score unseen data via decision_function)
    clf = LocalOutlierFactor(n_neighbors=5, novelty=True)
    clf.fit(train)

    # anomaly scores (higher = more anomalous)
    Z1 = -clf.decision_function(test_s)
    Z2 = -clf.decision_function(test_b)

    # labels for the ROC curve
    y_true = np.zeros(len(test_s) + len(test_b))
    y_true[len(test_s):] = 1  # 0: normal, 1: anomaly

    # compute FPR and TPR (and thresholds)
    fpr, tpr, _ = metrics.roc_curve(y_true, np.hstack((Z1, Z2)))

    # AUC
    auc = metrics.auc(fpr, tpr)

    # plot the ROC curve
    plt.plot(fpr, tpr, label='DeepOneClassification(AUC = %.2f)' % auc)
    plt.legend()
    plt.title('ROC curve')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.grid(True)
    plt.show()
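To sanity-check the LOF scoring and ROC plumbing independently of the two Keras models, here is a self-contained sketch on synthetic features; the Gaussian data and every constant are illustrative assumptions, not values from the example.

import numpy as np
from sklearn import metrics
from sklearn.neighbors import LocalOutlierFactor

rng = np.random.default_rng(0)
train = rng.normal(0.0, 1.0, size=(200, 8))        # "normal" training features
test_normal = rng.normal(0.0, 1.0, size=(50, 8))   # normal test samples
test_anom = rng.normal(4.0, 1.0, size=(50, 8))     # shifted, hence anomalous

clf = LocalOutlierFactor(n_neighbors=5, novelty=True).fit(train)
scores = -np.hstack([clf.decision_function(test_normal),
                     clf.decision_function(test_anom)])
y_true = np.r_[np.zeros(50), np.ones(50)]          # 0: normal, 1: anomaly
fpr, tpr, _ = metrics.roc_curve(y_true, scores)
print('AUC:', metrics.auc(fpr, tpr))               # should be close to 1.0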
Example 7
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    model = GMNTM(bow_dim=voc_size,
                  n_topic=n_topic,
                  device=device,
                  taskname=taskname,
                  dropout=0.2)
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion='bce_softmax')
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/GMNTM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vade.state_dict(), save_name)
Example 8
def main():
    print("loading model")
    model_r, model_t = DocNetWork()

    # dataset
    ds = DocDataset()
    x_train_snicor, x_train_ref, y_train_ref = ds.load_train_data()

    # store the loss values
    loss_c, loss_d = [], []

    # the target dataset is one-class
    for epoch in range(epoch_num):
        tmp_lc, tmp_ld = [], []
        print("epoch {} training network ...".format(epoch))

        for i in range(len(x_train_snicor) // batchsize):
            range_s = batchsize * i
            range_e = batchsize * (i + 1)
            batch_target = x_train_snicor[range_s:range_e]
            batch_ref = x_train_ref[range_s:range_e]
            batch_y = y_train_ref[range_s:range_e]

            print("training target network ... ")
            tmp_lc.append(model_t.train_on_batch(batch_target, np.zeros((batchsize, feature_out))))
            print("training reference network ... ")
            tmp_ld.append(model_r.train_on_batch(batch_ref, batch_y))

        loss_c.append(np.mean(tmp_lc))
        loss_d.append(np.mean(tmp_ld))

        if (epoch + 1) % 5 == 0:
            print("epoch:", epoch + 1)
            print("Compact loss", loss_c[-1])
            print("Descriptive loss:", loss_d[-1])

    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_t.save(os.path.join(model_path, "model_t.h5"), include_optimizer=False)
    model_r.save(os.path.join(model_path, "model_r.h5"), include_optimizer=False)
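Since the loop only prints the most recent losses, plotting the two stored curves after training makes the compactness/descriptiveness trade-off easier to inspect. A sketch that reuses loss_c, loss_d and model_path from the snippet (matplotlib is an assumed extra import):

import os
import matplotlib.pyplot as plt

plt.plot(loss_c, label='compact loss (target network)')
plt.plot(loss_d, label='descriptive loss (reference network)')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(model_path, 'losses.png'))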
Example 9
def main():
    global args
    train_data_file = args.train_data_file
    test_data_file = args.test_data_file
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    hidden_size = args.hidden_size
    learning_rate = args.learning_rate
    log_every = args.log_every
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    auto_adj = args.auto_adj
    ckpt = args.ckpt

    device = torch.device('cuda')
    trainSet = DocDataset(train_data_file,
                          no_below=no_below,
                          no_above=no_above,
                          rebuild=rebuild,
                          use_tfidf=False)
    testSet = DocDataset(test_data_file,
                         no_below=no_below,
                         no_above=no_above,
                         rebuild=rebuild,
                         use_tfidf=False)
    # if auto_adj:
    #     no_above = docSet.topk_dfs(topk=20)
    #     docSet = DocDataset(taskname,no_below=no_below,no_above=no_above,rebuild=rebuild,use_tfidf=False)

    voc_size = trainSet.vocabsize
    print('train voc size:', voc_size)
    print("train:", type(trainSet), len(trainSet))
    print("test:", type(testSet), len(testSet))
    if ckpt:
        checkpoint = torch.load(ckpt)
        param = checkpoint['param']  # assumption: the checkpoint bundles its GSM hyperparameters under 'param'
        param.update({"device": device})
        model = GSM(**param)
        model.train(train_data=trainSet,
                    test_data=testSet,
                    batch_size=batch_size,
                    learning_rate=learning_rate,
                    num_epochs=num_epochs,
                    log_every=log_every,
                    ckpt=checkpoint)
    else:
        model = GSM(bow_dim=voc_size,
                    n_topic=n_topic,
                    hidden_size=hidden_size,
                    device=device)
        model.train(train_data=trainSet,
                    test_data=testSet,
                    batch_size=batch_size,
                    learning_rate=learning_rate,
                    num_epochs=num_epochs,
                    log_every=log_every)
    #model.evaluate(test_data=docSet)
    save_name = f'./ckpt/GSM_{train_data_file}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=trainSet, num=1000)
    with open('topic_dist_gsm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    pickle.dump({
        'txts': txt_lst,
        'embeds': embeds
    }, open('gsm_embeds.pkl', 'wb'))
Example 10
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    dist = args.dist
    batch_size = args.batch_size
    criterion = args.criterion
    auto_adj = args.auto_adj
    show_topics = args.show_topics

    device = torch.device('cpu')

    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=True)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)
    voc_size = docSet.vocabsize

    model = BATM(bow_dim=voc_size,
                 n_topic=n_topic,
                 device=device,
                 taskname=taskname)

    if bkpt_continue:
        path = os.listdir('./ckpt')[0]
        checkpoint = torch.load(os.path.join('./ckpt', path))
        model.generator.load_state_dict(checkpoint['generator'])
        model.encoder.load_state_dict(checkpoint['encoder'])
        model.discriminator.load_state_dict(checkpoint['discriminator'])

    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                n_critic=10)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/BATM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(
        {
            'generator': model.generator.state_dict(),
            'encoder': model.encoder.state_dict(),
            'discriminator': model.discriminator.state_dict()
        }, save_name)

    if show_topics:
        with open(f'./result/{taskname}_ep{num_epochs}.txt', 'w') as f:
            for topic in model.show_topic_words():
                print(topic, file=f)
Example 11
def main():
    global args

    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_iters = args.num_iters
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    auto_adj = args.auto_adj

    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    model_name = 'LDA'
    msg = 'bow' if not use_tfidf else 'tfidf'
    run_name = '{}_K{}_{}_{}'.format(model_name, n_topic, taskname, msg)
    if not os.path.exists('logs'):
        os.mkdir('logs')
    if not os.path.exists('ckpt'):
        os.mkdir('ckpt')
    loghandler = [
        logging.FileHandler(filename=f'logs/{run_name}.log', encoding="utf-8")
    ]
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(message)s',
                        handlers=loghandler)
    logger = logging.getLogger(__name__)

    if bkpt_continue:
        print('loading model ckpt ...')
        lda_model = gensim.models.ldamodel.LdaModel.load(
            'ckpt/{}.model'.format(run_name))
        # note: the training below always builds a fresh model, overwriting this one

    # Training
    print('Start Training ...')

    if use_tfidf:
        tfidf = TfidfModel(docSet.bows)
        corpus_tfidf = tfidf[docSet.bows]
        #lda_model = LdaMulticore(list(corpus_tfidf),num_topics=n_topic,id2word=docSet.dictionary,alpha='asymmetric',passes=num_iters,workers=n_cpu,minimum_probability=0.0)
        lda_model = LdaModel(list(corpus_tfidf),
                             num_topics=n_topic,
                             id2word=docSet.dictionary,
                             alpha='asymmetric',
                             passes=num_iters)
    else:
        #lda_model = LdaMulticore(list(docSet.bows),num_topics=n_topic,id2word=docSet.dictionary,alpha='asymmetric',passes=num_iters,workers=n_cpu)
        lda_model = LdaModel(list(docSet.bows),
                             num_topics=n_topic,
                             id2word=docSet.dictionary,
                             alpha='asymmetric',
                             passes=num_iters)

    save_name = f'./ckpt/LDA_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    lda_model.save(save_name)

    # Evaluation
    print('Evaluation ...')
    topic_words = get_topic_words(model=lda_model,
                                  n_topic=n_topic,
                                  topn=15,
                                  vocab=docSet.dictionary)

    (cv_score, w2v_score, c_uci_score,
     c_npmi_score), _ = calc_topic_coherence(topic_words,
                                             docs=docSet.docs,
                                             dictionary=docSet.dictionary)

    topic_diversity = calc_topic_diversity(topic_words)

    result_dict = {
        'cv': cv_score,
        'w2v': w2v_score,
        'c_uci': c_uci_score,
        'c_npmi': c_npmi_score
    }
    logger.info('Topics:')

    for idx, words in enumerate(topic_words):
        logger.info(f'##{idx:>3d}:{words}')
        print(f'##{idx:>3d}:{words}')

    for measure, score in result_dict.items():
        logger.info(f'{measure} score: {score}')
        print(f'{measure} score: {score}')

    logger.info(f'topic diversity: {topic_diversity}')
    print(f'topic diversity: {topic_diversity}')
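get_topic_words here is a repo helper, not a gensim API. Below is a minimal sketch of an equivalent built on gensim's LdaModel.show_topic; the actual utility may differ, and vocab is accepted only for signature compatibility.

def get_topic_words(model, n_topic, topn=15, vocab=None):
    # show_topic returns (word, probability) pairs already mapped through the
    # model's id2word, so the dictionary passed as vocab is not needed here.
    return [[w for w, _ in model.show_topic(i, topn=topn)]
            for i in range(n_topic)]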
Example 12
def main():
    global args
    taskname = args.taskname
    no_below = args.no_below
    no_above = args.no_above
    num_epochs = args.num_epochs
    n_topic = args.n_topic
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue
    use_tfidf = args.use_tfidf
    rebuild = args.rebuild
    batch_size = args.batch_size
    criterion = args.criterion
    use_fc1 = args.use_fc1  #TBD_fc1
    emb_dim = args.emb_dim
    auto_adj = args.auto_adj

    device = torch.device('cuda')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)
    model = ETM(bow_dim=voc_size,
                n_topic=n_topic,
                taskname=taskname,
                device=device,
                use_fc1=use_fc1,
                emb_dim=emb_dim)  #TBD_fc1
    model.train(train_data=docSet,
                batch_size=batch_size,
                test_data=docSet,
                num_epochs=num_epochs,
                log_every=10,
                beta=1.0,
                criterion=criterion)
    model.evaluate(test_data=docSet)
    save_name = f'./ckpt/ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    topic_vecs = model.vae.alpha.weight.detach().cpu().numpy()
    word_vecs = model.vae.rho.weight.detach().cpu().numpy()
    print('topic_vecs.shape:', topic_vecs.shape)
    print('word_vecs.shape:', word_vecs.shape)
    vocab = np.array([
        t[0] for t in sorted(list(docSet.dictionary.token2id.items()),
                             key=lambda x: x[1])
    ]).reshape(-1, 1)
    topic_ids = np.array([f'TP{i}' for i in range(n_topic)]).reshape(-1, 1)
    word_vecs = np.concatenate([vocab, word_vecs], axis=1)
    topic_vecs = np.concatenate([topic_ids, topic_vecs], axis=1)
    save_name_tp = f'./ckpt/TpVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    save_name_wd = f'./ckpt/WdVec_ETM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.emb'
    torch.save(topic_vecs, save_name_tp)
    torch.save(word_vecs, save_name_wd)
Example 13
def main():
    global args
    taskname = args.taskname  # dataset name
    no_below = args.no_below  # words with document frequency below this threshold are filtered out
    no_above = args.no_above  # words with document frequency above this ratio are filtered out
    num_epochs = args.num_epochs  # number of training epochs
    n_topic = args.n_topic  # number of topics
    n_cpu = cpu_count() - 2 if cpu_count() > 2 else 2
    bkpt_continue = args.bkpt_continue  # whether to resume training from a previous checkpoint
    use_tfidf = args.use_tfidf  # whether to use tf-idf as the BOW input
    rebuild = args.rebuild  # whether to rebuild the corpus (off by default)
    batch_size = args.batch_size  # batch size
    criterion = args.criterion  # loss type
    auto_adj = args.auto_adj  # whether to auto-adjust the frequency filter, e.g. drop the top-20 DF words
    ckpt = args.ckpt  # checkpoint path

    device = torch.device('cpu')
    docSet = DocDataset(taskname,
                        no_below=no_below,
                        no_above=no_above,
                        rebuild=rebuild,
                        use_tfidf=False)  # load the dataset and tokenize it
    if auto_adj:
        no_above = docSet.topk_dfs(topk=20)
        docSet = DocDataset(taskname,
                            no_below=no_below,
                            no_above=no_above,
                            rebuild=rebuild,
                            use_tfidf=False)

    voc_size = docSet.vocabsize
    print('voc size:', voc_size)

    if ckpt:  # load checkpoint
        checkpoint = torch.load(ckpt)
        param = checkpoint['param']  # assumption: the checkpoint bundles its GSM hyperparameters under 'param'
        param.update({"device": device})
        model = GSM(**param)
        model.train(train_data=docSet,
                    batch_size=batch_size,
                    test_data=docSet,
                    num_epochs=num_epochs,
                    log_every=10,
                    beta=1.0,
                    criterion=criterion,
                    ckpt=checkpoint)
    else:
        # initialize the model and start training
        model = GSM(bow_dim=voc_size,
                    n_topic=n_topic,
                    taskname=taskname,
                    device=device)
        model.train(train_data=docSet,
                    batch_size=batch_size,
                    test_data=docSet,
                    num_epochs=num_epochs,
                    log_every=10,
                    beta=1.0,
                    criterion=criterion)
    model.evaluate(test_data=docSet)  # evaluate with the trained model
    # save the model, features, statistics, etc.
    save_name = f'./ckpt/GSM_{taskname}_tp{n_topic}_{time.strftime("%Y-%m-%d-%H-%M", time.localtime())}.ckpt'
    torch.save(model.vae.state_dict(), save_name)
    txt_lst, embeds = model.get_embed(train_data=docSet, num=1000)
    with open('topic_dist_gsm.txt', 'w', encoding='utf-8') as wfp:
        for t, e in zip(txt_lst, embeds):
            wfp.write(f'{e}:{t}\n')
    pickle.dump({
        'txts': txt_lst,
        'embeds': embeds
    }, open('gsm_embeds.pkl', 'wb'))
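The gsm_embeds.pkl and etm_embeds.pkl dumps above share one structure and were written with pickle.dump, so reading either back is the same few lines (wtm_embeds.pkl from Example 1 was written with torch.save and needs torch.load instead). A sketch using the GSM dump:

import pickle

with open('gsm_embeds.pkl', 'rb') as rfp:
    data = pickle.load(rfp)
txt_lst, embeds = data['txts'], data['embeds']
print(len(txt_lst), 'documents; first topic distribution:', embeds[0])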