Example #1
def get_data(df):
    """
    The function converts a DataFrame with a client's MCC codes and spending
    into a client profile. The profile includes information about behaviour
    patterns, gender, and age.

    :param df: pandas.DataFrame with two columns: the first with MCC codes, the second with the costs for each MCC.
    :return: python dict
    """
    with open('../backend/data/topic2name.pkl', 'rb') as topic2name_file:
        topic2name = pickle.load(topic2name_file)
    with open('../backend/tmp/1.vw', 'w') as f:
        f.write('0 |@default')
        for i, j in zip(df.values[:, 0], df.values[:, 1]):
            f.write(' {}:{}'.format(int(i), int(j)))
        f.write('\n')

    batch_vectorizer = artm.BatchVectorizer(target_folder='../backend/tmp',
                                            data_path='../backend/tmp/1.vw',
                                            data_format='vowpal_wabbit')

    model = artm.load_artm_model('../backend/data/reg_plsa_nogrp_mlt')
    profile = model.transform(batch_vectorizer)
    top_topics = profile.sort_values(by=profile.columns[0], ascending=False).index[:5].values
    top_topics = [topic2name[i] for i in top_topics]
    top_topics_score = profile.sort_values(by=profile.columns[0], ascending=False).values[:5].reshape(5)
    top_topics = dict(zip(top_topics, top_topics_score))
    gender = model.get_phi(class_ids=['@gender']).loc[['M', 'F']].values.dot(profile.values).reshape(2)
    age = model.get_phi(class_ids=['@age']).loc[['teen', 'young', 'midage', 'elderly']].values.dot(
        profile.values).reshape(4)

    filelist = glob('../backend/tmp/*')
    for f in filelist:
        os.remove(f)
    return {'top_topics': top_topics, 'gender': list(gender), 'age': list(age)}
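
A minimal usage sketch for get_data above (the client data here is made up, and the ../backend data and tmp folders are assumed to exist):

import pandas as pd

# Hypothetical client transactions: MCC codes and the amount spent on each.
sample = pd.DataFrame({'mcc': [5411, 5812, 4111],
                       'cost': [1200, 450, 90]})
profile = get_data(sample)
print(profile['top_topics'])  # five best-matching behaviour topics with scores
print(profile['gender'])      # two scores, ordered as ['M', 'F']
print(profile['age'])         # four scores: teen, young, midage, elderly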
Example #2
def main() -> None:
    """
    Predict topics for a given issue in the input
    :return: None
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-artm",
                        required=True,
                        help="Path to directory with BigARTM model")
    args = parser.parse_args()
    model_artm = artm.load_artm_model(args.model_artm)

    with open('topic_issue_model.pickle', 'rb') as issue_pickle_file:
        topic_issue_model: TopicIssueModel = pickle.load(issue_pickle_file)
    predict_topics(topic_issue_model, model_artm)
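
The entry point above is driven from the command line; a hypothetical invocation (the script file name is an assumption, only the --model-artm flag is defined by the parser):

# python predict_topics_cli.py --model-artm ./models/bigartm_dump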
Example #3
    def load(path, experiment=None):
        """
        Loads the model.

        Parameters
        ----------
        path : str
            path to the model's folder
        experiment : Experiment

        Returns
        -------
        TopicModel

        """
        from ..experiment import Experiment, START

        if "model" in os.listdir(f"{path}"):
            model = artm.load_artm_model(f"{path}/model")
        else:
            model = None
            print("There is no dumped model. You should train it again.")
        with open(f"{path}/params.json", "r") as params_file:
            params = json.load(params_file)  # TODO: add property with such name?
        topic_model = TopicModel(model, **params)

        crunch_resolve = experiment or topic_model.model_id == START
        if crunch_resolve:
            topic_model.experiment = experiment
        elif params["experiment_id"] is not None:
            experiment_path = path[:path.rfind(topic_model.model_id)]
            if params["experiment_id"] in experiment_path.split('/'):
                topic_model.experiment = Experiment.load(experiment_path)

        custom_scores = {}
        for score_path in glob.glob(os.path.join(path, '*.p')):
            score_name = os.path.basename(score_path).split('.')[0]
            with open(score_path, 'rb') as score_file:
                custom_scores[score_name] = pickle.load(score_file)

        topic_model.custom_scores = custom_scores
        return topic_model
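
A possible call site for the loader above, assuming load is exposed as a static method of TopicNet's TopicModel (the path is a placeholder for a dumped model folder):

# Hypothetical usage: restore a saved TopicModel and inspect its custom scores.
topic_model = TopicModel.load('experiments/exp_1/model_x')
print(sorted(topic_model.custom_scores))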
Example #4
def get_data(df):
    """
    The function converts a DataFrame with a client's MCC codes and spending
    into a client profile. The profile includes information about behaviour
    patterns, gender, and age.

    :param df: pandas.DataFrame with two columns: the first with MCC codes, the second with the costs for each MCC.
    :return: python dict
    """
    with open('../backend/data/topic2name.pkl', 'rb') as topic2name_file:
        topic2name = pickle.load(topic2name_file)
    with open('../backend/tmp/1.vw', 'w') as f:
        f.write('0 |@default')
        for i, j in zip(df.values[:, 0], df.values[:, 1]):
            f.write(' {}:{}'.format(int(i), int(j)))
        f.write('\n')

    batch_vectorizer = artm.BatchVectorizer(target_folder='../backend/tmp',
                                            data_path='../backend/tmp/1.vw',
                                            data_format='vowpal_wabbit')

    model = artm.load_artm_model('../backend/data/reg_plsa_nogrp_mlt')
    profile = model.transform(batch_vectorizer)
    top_topics = profile.sort_values(by=profile.columns[0],
                                     ascending=False).index[:5].values
    top_topics = [topic2name[i] for i in top_topics]
    top_topics_score = profile.sort_values(
        by=profile.columns[0], ascending=False).values[:5].reshape(5)
    top_topics = dict(zip(top_topics, top_topics_score))
    gender = model.get_phi(class_ids=['@gender']).loc[['M', 'F']].values.dot(
        profile.values).reshape(2)
    age = model.get_phi(class_ids=['@age']).loc[[
        'teen', 'young', 'midage', 'elderly'
    ]].values.dot(profile.values).reshape(4)

    filelist = glob('../backend/tmp/*')
    for f in filelist:
        os.remove(f)
    return {'top_topics': top_topics, 'gender': list(gender), 'age': list(age)}
Example #5
DATA_DIR = r"../Files"
ORGANIZATION_FILE = os.path.join(DATA_DIR, "organization.json")

N_DOCUMENTS = 10000
BASE_DIR = "../Files/TopicModeling/{}_documents".format(N_DOCUMENTS)
SAVE_DIR = os.path.join(BASE_DIR, "models/artm")
THETA_FILE = os.path.join(BASE_DIR, "theta.pkl")
PHI_FILE = os.path.join(BASE_DIR, "phi.pkl")
ID_TO_ROW_FILE_NAME = os.path.join(
    "../Files/DataPreprocessing",
    "{name}_{n_docs}documents_{n_feats}features".format(name="id_to_row",
                                                        n_docs=N_DOCUMENTS,
                                                        n_feats=1000))

model_artm = artm.load_artm_model(SAVE_DIR)

organization = Organization.Organization.LoadFromJson(ORGANIZATION_FILE)
employees = list(organization.GetAllEmployees())

theta = pd.read_pickle(THETA_FILE).values
with open(ID_TO_ROW_FILE_NAME, "rb") as id_to_row_file:
    id_to_row = pickle.load(id_to_row_file)

employee_to_vectors = {}
for employee in employees:
    employee_to_vectors[employee] = []
    for document_id in employee.Documents:
        if document_id in id_to_row:
            employee_to_vectors[employee].append(theta[:,
                                                       id_to_row[document_id]])
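
A plausible next step, not part of the original snippet: collapse each employee's per-document topic vectors into a single profile by averaging them (numpy assumed to be importable):

import numpy as np

# Average each employee's document vectors into one topic profile;
# employees with no mapped documents are skipped.
employee_to_profile = {
    employee: np.mean(vectors, axis=0)
    for employee, vectors in employee_to_vectors.items()
    if vectors
}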
Example #6
def test_func():
    num_topics = 5
    tolerance = 0.01
    batches_folder = tempfile.mkdtemp()

    try:
        with open(os.path.join(batches_folder, 'temp.vw.txt'), 'w') as fout:
            fout.write('title_0 aaa:6 bbb:3 ccc:2 |@time_class time_1\n')
            fout.write('title_1 aaa:2 bbb:9 ccc:3\n')
            fout.write('title_2 aaa:1 bbb:2 ccc:7 |@time_class time_2\n')
            fout.write('title_3 aaa:7 bbb:4 ccc:5 |@time_class time_2\n')

        batch_vectorizer = artm.BatchVectorizer(data_path=os.path.join(batches_folder, 'temp.vw.txt'),
                                                data_format='vowpal_wabbit',
                                                target_folder=batches_folder)
        # configure model 1
        model = artm.ARTM(num_topics=num_topics,
                          dictionary=batch_vectorizer.dictionary,
                          num_document_passes=1)

        reg = artm.NetPlsaPhiRegularizer(name='net_plsa', tau=1.0, class_id='@time_class',
                                         vertex_names=['time_1', 'time_2'], vertex_weights=[1.0, 2.0],
                                         edge_weights={0: {1: 3.0}, 1: {0: 2.0}})
        model.regularizers.add(reg)

        # configure model 2
        model_2 = artm.ARTM(num_topics=num_topics,
                            dictionary=batch_vectorizer.dictionary,
                            num_document_passes=1)

        model_2.regularizers.add(artm.NetPlsaPhiRegularizer(name='net_plsa', tau=1.0))
        model_2.regularizers['net_plsa'].class_id = '@time_class'
        model_2.regularizers['net_plsa'].vertex_names = ['time_1', 'time_2']
        model_2.regularizers['net_plsa'].vertex_weights = [1.0, 2.0]
        model_2.regularizers['net_plsa'].edge_weights = {0: {1: 3.0}, 1: {0: 2.0}}

        model.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=2)
        model_2.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=2)

        phi = model.get_phi()
        phi_2 = model_2.get_phi()
        assert phi.equals(phi_2)

        model.dump_artm_model(os.path.join(batches_folder, 'target'))
        model_3 = artm.load_artm_model(os.path.join(batches_folder, 'target'))

        model.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=1)
        model_3.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=1)

        phi = model.get_phi()
        phi_3 = model_3.get_phi()
        assert phi.equals(phi_3)

        def _f(w):
            return ('@default_class', w)

        def _t(w):
            return ('@time_class', w)

        real_topics = pd.DataFrame(columns=['topic_0', 'topic_1', 'topic_2', 'topic_3', 'topic_4'],
                                   index=[_f('ccc'), _f('bbb'), _f('aaa'), _t('time_1'), _t('time_2')],
                                   data=[[0.098, 0.892, 0.099, 0.389, 0.184],
                                         [0.145, 0.004, 0.618, 0.334, 0.684],
                                         [0.757, 0.104, 0.283, 0.277, 0.132],
                                         [0.06,  0.0,   0.092, 0.0,   0.0  ],
                                         [0.94,  1.0,   0.908, 1.0,   1.0  ]])

        assert (phi - real_topics).abs().values.max() < tolerance 
    finally:
        shutil.rmtree(batches_folder)
Example #7
def test_func():
    data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
    batches_folder = tempfile.mkdtemp()
    dump_folder = tempfile.mkdtemp()

    try:
        batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
                                                data_format='bow_uci',
                                                collection_name='kos',
                                                target_folder=batches_folder)

        model_1 = artm.ARTM(num_processors=7,
                            cache_theta=True,
                            num_document_passes=5,
                            reuse_theta=True,
                            seed=10,
                            num_topics=15,
                            class_ids={'@default_class': 1.0},
                            theta_name='THETA',
                            dictionary=batch_vectorizer.dictionary)

        model_2 = artm.ARTM(num_processors=7,
                            cache_theta=False,
                            num_document_passes=5,
                            reuse_theta=False,
                            seed=10,
                            num_topics=15,
                            class_ids={'@default_class': 1.0},
                            dictionary=batch_vectorizer.dictionary)

        for model in [model_1, model_2]:
            model.scores.add(
                artm.PerplexityScore(name='perp',
                                     dictionary=batch_vectorizer.dictionary))
            model.scores.add(artm.SparsityThetaScore(name='sp_theta', eps=0.1))
            model.scores.add(artm.TopTokensScore(name='top_tok',
                                                 num_tokens=10))
            model.scores.add(
                artm.SparsityPhiScore(name='sp_nwt',
                                      model_name=model.model_nwt))
            model.scores.add(
                artm.TopicKernelScore(name='kernel',
                                      topic_names=model.topic_names[0:5],
                                      probability_mass_threshold=0.4))

            topic_pairs = {}
            for topic_name_1 in model.topic_names:
                for topic_name_2 in model.topic_names:
                    if topic_name_1 not in topic_pairs:
                        topic_pairs[topic_name_1] = {}
                    topic_pairs[topic_name_1][
                        topic_name_2] = numpy.random.randint(0, 3)

            model.regularizers.add(
                artm.DecorrelatorPhiRegularizer(name='decor',
                                                tau=100000.0,
                                                topic_pairs=topic_pairs))
            model.regularizers.add(
                artm.SmoothSparsePhiRegularizer(
                    name='smsp_phi',
                    tau=-0.5,
                    gamma=0.3,
                    dictionary=batch_vectorizer.dictionary))
            model.regularizers.add(
                artm.SmoothSparseThetaRegularizer(name='smsp_theta',
                                                  tau=0.1,
                                                  doc_topic_coef=[2.0] *
                                                  model.num_topics))
            model.regularizers.add(
                artm.SmoothPtdwRegularizer(name='sm_ptdw', tau=0.1))

            # learn first model and dump it on disc
            model.fit_offline(batch_vectorizer, num_collection_passes=10)
            model.fit_online(batch_vectorizer, update_every=1)

            model.dump_artm_model(os.path.join(dump_folder, 'target'))

            params = {}
            with open(os.path.join(dump_folder, 'target', 'parameters.json'),
                      'r') as fin:
                params = json.load(fin)
            _assert_json_params(params)

            # create second model from the dump and check the results are equal
            model_new = artm.load_artm_model(
                os.path.join(dump_folder, 'target'))

            _assert_params_equality(model, model_new)
            _assert_scores_equality(model, model_new)
            _assert_regularizers_equality(model, model_new)
            _assert_score_values_equality(model, model_new)
            _assert_matrices_equality(model, model_new)

            # continue learning of both models
            model.fit_offline(batch_vectorizer, num_collection_passes=3)
            model.fit_online(batch_vectorizer, update_every=1)

            model_new.fit_offline(batch_vectorizer, num_collection_passes=3)
            model_new.fit_online(batch_vectorizer, update_every=1)

            # check new results are also equal
            _assert_params_equality(model, model_new)
            _assert_scores_equality(model, model_new)
            _assert_regularizers_equality(model, model_new)
            _assert_score_values_equality(model, model_new)
            _assert_matrices_equality(model, model_new)

            shutil.rmtree(os.path.join(dump_folder, 'target'))
    finally:
        shutil.rmtree(batches_folder)
        shutil.rmtree(dump_folder)
Example #8
def test_func():
    num_topics = 5
    tolerance = 0.01
    batches_folder = tempfile.mkdtemp()

    try:
        with open(os.path.join(batches_folder, 'temp.vw.txt'), 'w') as fout:
            fout.write('title_0 aaa:6 bbb:3 ccc:2 |@time_class time_1\n')
            fout.write('title_1 aaa:2 bbb:9 ccc:3\n')
            fout.write('title_2 aaa:1 bbb:2 ccc:7 |@time_class time_2\n')
            fout.write('title_3 aaa:7 bbb:4 ccc:5 |@time_class time_2\n')

        batch_vectorizer = artm.BatchVectorizer(data_path=os.path.join(
            batches_folder, 'temp.vw.txt'),
                                                data_format='vowpal_wabbit',
                                                target_folder=batches_folder)
        # configure model 1
        model = artm.ARTM(num_topics=num_topics,
                          dictionary=batch_vectorizer.dictionary,
                          num_document_passes=1)

        reg = artm.NetPlsaPhiRegularizer(name='net_plsa',
                                         tau=1.0,
                                         class_id='@time_class',
                                         vertex_names=['time_1', 'time_2'],
                                         vertex_weights=[1.0, 2.0],
                                         edge_weights={
                                             0: {
                                                 1: 3.0
                                             },
                                             1: {
                                                 0: 2.0
                                             }
                                         })
        model.regularizers.add(reg)

        # configure model 2
        model_2 = artm.ARTM(num_topics=num_topics,
                            dictionary=batch_vectorizer.dictionary,
                            num_document_passes=1)

        model_2.regularizers.add(
            artm.NetPlsaPhiRegularizer(name='net_plsa', tau=1.0))
        model_2.regularizers['net_plsa'].class_id = '@time_class'
        model_2.regularizers['net_plsa'].vertex_names = ['time_1', 'time_2']
        model_2.regularizers['net_plsa'].vertex_weights = [1.0, 2.0]
        model_2.regularizers['net_plsa'].edge_weights = {
            0: {
                1: 3.0
            },
            1: {
                0: 2.0
            }
        }

        model.fit_offline(batch_vectorizer=batch_vectorizer,
                          num_collection_passes=2)
        model_2.fit_offline(batch_vectorizer=batch_vectorizer,
                            num_collection_passes=2)

        phi = model.get_phi()
        phi_2 = model_2.get_phi()
        assert phi.equals(phi_2)

        model.dump_artm_model(os.path.join(batches_folder, 'target'))
        model_3 = artm.load_artm_model(os.path.join(batches_folder, 'target'))

        model.fit_offline(batch_vectorizer=batch_vectorizer,
                          num_collection_passes=1)
        model_3.fit_offline(batch_vectorizer=batch_vectorizer,
                            num_collection_passes=1)

        phi = model.get_phi()
        phi_3 = model_3.get_phi()
        assert phi.equals(phi_3)

        real_topics = pd.DataFrame(
            data={
                'topic_0':
                dict(ccc=0.098, bbb=0.145, aaa=0.757, time_1=0.06,
                     time_2=0.94),
                'topic_1':
                dict(ccc=0.892, bbb=0.004, aaa=0.104, time_1=0.0, time_2=1.0),
                'topic_2':
                dict(ccc=0.099,
                     bbb=0.618,
                     aaa=0.283,
                     time_1=0.092,
                     time_2=0.908),
                'topic_3':
                dict(ccc=0.389, bbb=0.334, aaa=0.277, time_1=0.0, time_2=1.0),
                'topic_4':
                dict(ccc=0.184, bbb=0.684, aaa=0.132, time_1=0.0, time_2=1.0),
            })

        assert (phi - real_topics).abs().values.max() < tolerance
    finally:
        shutil.rmtree(batches_folder)
Example #9
def test_func():
    data_path = os.environ.get('BIGARTM_UNITTEST_DATA')
    batches_folder = tempfile.mkdtemp()
    dump_folder = tempfile.mkdtemp()

    try:
        batch_vectorizer = artm.BatchVectorizer(data_path=data_path,
                                                data_format='bow_uci',
                                                collection_name='kos',
                                                target_folder=batches_folder)

        model_1 = artm.ARTM(num_processors=7,
                            cache_theta=True,
                            num_document_passes=5,
                            reuse_theta=True,
                            seed=10,
                            num_topics=15,
                            class_ids={'@default_class': 1.0},
                            theta_name='THETA',
                            dictionary=batch_vectorizer.dictionary)

        model_2 = artm.ARTM(num_processors=7,
                            cache_theta=False,
                            num_document_passes=5,
                            reuse_theta=False,
                            seed=10,
                            num_topics=15,
                            class_ids={'@default_class': 1.0},
                            dictionary=batch_vectorizer.dictionary)

        for model in [model_1, model_2]:
            model.scores.add(artm.PerplexityScore(name='perp', dictionary=batch_vectorizer.dictionary))
            model.scores.add(artm.SparsityThetaScore(name='sp_theta', eps=0.1))
            model.scores.add(artm.TopTokensScore(name='top_tok', num_tokens=10))
            model.scores.add(artm.SparsityPhiScore(name='sp_nwt', model_name=model.model_nwt))
            model.scores.add(artm.TopicKernelScore(name='kernel', topic_names=model.topic_names[0: 5],
                                                   probability_mass_threshold=0.4))

            topic_pairs = {}
            for topic_name_1 in model.topic_names:
                for topic_name_2 in model.topic_names:
                    if topic_name_1 not in topic_pairs:
                        topic_pairs[topic_name_1] = {}
                    topic_pairs[topic_name_1][topic_name_2] = numpy.random.randint(0, 3)

            model.regularizers.add(artm.DecorrelatorPhiRegularizer(name='decor', tau=100000.0,
                                                                   topic_pairs=topic_pairs))
            model.regularizers.add(artm.SmoothSparsePhiRegularizer(name='smsp_phi', tau=-0.5, gamma=0.3,
                                                                   dictionary=batch_vectorizer.dictionary))
            model.regularizers.add(artm.SmoothSparseThetaRegularizer(name='smsp_theta', tau=0.1,
                                                                     doc_topic_coef=[2.0] * model.num_topics))
            model.regularizers.add(artm.SmoothPtdwRegularizer(name='sm_ptdw', tau=0.1))

            # learn first model and dump it on disc
            model.fit_offline(batch_vectorizer, num_collection_passes=10)
            model.fit_online(batch_vectorizer, update_every=1)

            model.dump_artm_model(os.path.join(dump_folder, 'target'))

            params = {}
            with open(os.path.join(dump_folder, 'target', 'parameters.json'), 'r') as fin:
                params = json.load(fin)
            _assert_json_params(params)

            # create second model from the dump and check the results are equal
            model_new = artm.load_artm_model(os.path.join(dump_folder, 'target'))

            _assert_params_equality(model, model_new)
            _assert_scores_equality(model, model_new)
            _assert_regularizers_equality(model, model_new)
            _assert_score_values_equality(model, model_new)
            _assert_matrices_equality(model, model_new)
         
            # continue learning of both models
            model.fit_offline(batch_vectorizer, num_collection_passes=3)
            model.fit_online(batch_vectorizer, update_every=1)

            model_new.fit_offline(batch_vectorizer, num_collection_passes=3)
            model_new.fit_online(batch_vectorizer, update_every=1)

            # check new results are also equal
            _assert_params_equality(model, model_new)
            _assert_scores_equality(model, model_new)
            _assert_regularizers_equality(model, model_new)
            _assert_score_values_equality(model, model_new)
            _assert_matrices_equality(model, model_new)

            shutil.rmtree(os.path.join(dump_folder, 'target'))
    finally:
        shutil.rmtree(batches_folder)
        shutil.rmtree(dump_folder)
Example #10
path = "./configuration/tm_config.json"

with open(path, 'r') as f:
    config_input = json.load(f)

MIN_SCORE = config_input["inference"]["min_similarity_score"]
SIM_NUM = config_input["similarity_number"]
THREADS = config_input["train"]["threads"]
NUM_TOPICS = config_input["train"]["num_topics"]

CLEAR_LOG_MODE = config_input["inference"]["clear_log_dir"]
lc = artm.messages.ConfigureLoggingArgs()
lc.log_dir = "./logs/artm_logs"
lib = artm.wrapper.LibArtm(logging_config=lc)

# Change any other logging parameters at runtime (except logging folder)
lc.minloglevel = 3  # 0 = INFO, 1 = WARNING, 2 = ERROR, 3 = FATAL
lib.ArtmConfigureLogging(lc)
try:
    model = artm.load_artm_model("./opt/tm_model/tm_model_sources")
except Exception:
    model = None
    logger.warning(
        "Model is not initialized. This is normal in retrain mode; "
        "otherwise, pay attention!"
    )
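
A hypothetical downstream check, not part of the original snippet: callers can branch on the None fallback to decide between inference and retraining.

if model is None:
    logger.warning("No model on disk; run the training pipeline before inference.")
else:
    logger.info("Model loaded and ready for inference.")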
Example #11
File: model.py Project: pacifikus/research
    def load_model(self, path='model_fitted'):
        self.model = artm.load_artm_model(path)
Example #12
    def load(path, experiment=None):
        """
        Loads the model.

        Parameters
        ----------
        path : str
            path to the model's folder
        experiment : Experiment

        Returns
        -------
        TopicModel

        """

        if "model" in os.listdir(f"{path}"):
            model = artm.load_artm_model(f"{path}/model")
        else:
            model = None
            print("There is no dumped model. You should train it again.")

        with open(f"{path}/params.json", "r", encoding='utf-8') as params_f:
            params = json.load(params_f)

        topic_model = TopicModel(model, **params)
        topic_model.experiment = experiment

        custom_scores = {}

        for score_path in glob.glob(os.path.join(path, '*.p')):
            score_name = os.path.basename(score_path).split('.')[0]
            with open(score_path, 'rb') as score_f:
                custom_scores[score_name] = dill.load(score_f)

        topic_model.custom_scores = custom_scores

        custom_regularizers = {}

        for regularizer_path in glob.glob(os.path.join(path, '*.rd')):
            regularizer_name = os.path.basename(regularizer_path).split('.')[0]
            with open(regularizer_path, 'rb') as reg_f:
                custom_regularizers[regularizer_name] = dill.load(reg_f)

        for regularizer_path in glob.glob(os.path.join(path, '*.rp')):
            regularizer_name = os.path.basename(regularizer_path).split('.')[0]
            with open(regularizer_path, 'rb') as reg_f:
                custom_regularizers[regularizer_name] = pickle.load(reg_f)

        topic_model.custom_regularizers = custom_regularizers

        all_agents = glob.glob(os.path.join(path, 'callback*.pkl'))
        topic_model.callbacks = [None for _ in enumerate(all_agents)]

        for agent_path in all_agents:
            filename = os.path.basename(agent_path).split('.')[0]
            original_index = int(filename.partition("_")[2])
            with open(agent_path, 'rb') as agent_f:
                topic_model.callbacks[original_index] = dill.load(agent_f)

        topic_model._reset_score_caches()

        return topic_model
Example #13
    def load(path, experiment=None):
        """
        Loads the model.

        Parameters
        ----------
        path : str
            path to the model's folder
        experiment : Experiment

        Returns
        -------
        TopicModel

        """
        if "model" in os.listdir(f"{path}"):
            model = artm.load_artm_model(f"{path}/model")
        else:
            model = None
            print("There is no dumped model. You should train it again.")

        with open(os.path.join(path, 'params.json'), 'r', encoding='utf-8') as params_file:
            params = json.load(params_file)

        topic_model = TopicModel(model, **params)
        topic_model.experiment = experiment

        custom_scores = {}

        for score_path in glob.glob(os.path.join(path, '*.p')):
            # TODO: file '..p' is not included, so score with name '.' will be lost
            #  Need to validate score name?
            score_file_name = os.path.basename(score_path)
            score_name = os.path.splitext(score_file_name)[0]

            with open(score_path, 'rb') as score_file:
                custom_scores[score_name] = dill.load(score_file)

        topic_model.custom_scores = custom_scores

        custom_regularizers = {}

        for reg_file_extension, loader in zip(['.rd', '.rp'], [dill, pickle]):
            for regularizer_path in glob.glob(os.path.join(path, f'*{reg_file_extension}')):
                regularizer_file_name = os.path.basename(regularizer_path)
                regularizer_name = os.path.splitext(regularizer_file_name)[0]

                with open(regularizer_path, 'rb') as reg_file:
                    custom_regularizers[regularizer_name] = loader.load(reg_file)

        topic_model.custom_regularizers = custom_regularizers

        all_agents = glob.glob(os.path.join(path, 'callback*.pkl'))
        topic_model.callbacks = [None for _ in enumerate(all_agents)]

        for agent_path in all_agents:
            file_name = os.path.basename(agent_path).split('.')[0]
            original_index = int(file_name.partition("_")[2])

            with open(agent_path, 'rb') as agent_file:
                topic_model.callbacks[original_index] = dill.load(agent_file)

        topic_model._reset_score_caches()

        return topic_model
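
Finally, a minimal dump/load round-trip distilled from the tests above; the toy corpus and temporary folder are placeholders, but the API calls mirror Examples #6-#9.

import os
import shutil
import tempfile

import artm

folder = tempfile.mkdtemp()
try:
    # Tiny Vowpal Wabbit corpus, just enough to fit a model.
    with open(os.path.join(folder, 'docs.vw.txt'), 'w') as fout:
        fout.write('doc_0 aaa:3 bbb:2\n')
        fout.write('doc_1 aaa:1 ccc:4\n')

    bv = artm.BatchVectorizer(data_path=os.path.join(folder, 'docs.vw.txt'),
                              data_format='vowpal_wabbit',
                              target_folder=folder)
    model = artm.ARTM(num_topics=2,
                      dictionary=bv.dictionary,
                      num_document_passes=1)
    model.fit_offline(batch_vectorizer=bv, num_collection_passes=2)

    # Dump to disk and restore an equivalent model.
    model.dump_artm_model(os.path.join(folder, 'dump'))
    restored = artm.load_artm_model(os.path.join(folder, 'dump'))
    assert model.get_phi().equals(restored.get_phi())
finally:
    shutil.rmtree(folder)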