Example #1
File: train.py Project: hallamlab/reMap
def __train(arg):
    # Setup the number of operations to employ
    steps = 1
    # Whether to display parameters at every operation
    display_params = True

    ##########################################################################################################
    ######################                        PREPROCESSING                         ######################
    ##########################################################################################################

    if arg.define_bags:
        print("\n{0})- Construct bags_labels centroids...".format(steps))
        steps = steps + 1

        # load a hin file
        hin = load_data(file_name=arg.hin_name,
                        load_path=arg.mdpath,
                        tag="heterogeneous information network")
        node2idx_path2vec = dict(
            (node[0], node[1]["mapped_idx"]) for node in hin.nodes(data=True))
        # map pathway indices in the vocabulary to path2vec pathway indices
        vocab = load_data(file_name=arg.vocab_name,
                          load_path=arg.dspath,
                          tag="vocabulary")
        idxvocab = np.array(
            [idx for idx, v in vocab.items() if v in node2idx_path2vec])
        del hin

        # define pathways 2 bags_labels
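        # phi holds the per-bag distribution over pathways; keep each bag's
        # top_k most probable pathways, restrict them to the vocabulary, and
        # binarize the result into a bag-by-pathway indicator matrix.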
        phi = np.load(file=os.path.join(arg.mdpath, arg.bag_phi_name))
        phi = phi[phi.files[0]]
        bags_labels = np.argsort(-phi)
        bags_labels = bags_labels[:, :arg.top_k]
        labels_distr_idx = np.array(
            [[pathway for pathway in bag if pathway in idxvocab]
             for bag in bags_labels])
        bags_labels = preprocessing.MultiLabelBinarizer().fit_transform(
            labels_distr_idx)
        labels_distr_idx = [[
            list(idxvocab).index(label_idx) for label_idx in bag_idx
        ] for bag_idx in labels_distr_idx]

        # get trimmed phi distributions
        phi = -np.sort(-phi)
        phi = phi[:, :arg.top_k]

        # calculate correlation
        sigma = np.load(file=os.path.join(arg.mdpath, arg.bag_sigma_name))
        sigma = sigma[sigma.files[0]]
        sigma[sigma < 0] = EPSILON
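        # Convert the covariance matrix sigma into a correlation matrix
        # rho = C^-1 * sigma * C^-1 with C = diag(sqrt(diag(sigma))),
        # then min-max scale rho into [0, 1].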
        C = np.diag(np.sqrt(np.diag(sigma)))
        C_inv = np.linalg.inv(C)
        rho = np.dot(np.dot(C_inv, sigma), C_inv)
        min_rho = np.min(rho)
        max_rho = np.max(rho)
        rho = rho - min_rho
        rho = rho / (max_rho - min_rho)

        # extracting pathway features
        path2vec_features = np.load(
            file=os.path.join(arg.mdpath, arg.features_name))
        path2vec_features = path2vec_features[path2vec_features.files[0]]
        pathways_idx = np.array([
            node2idx_path2vec[v] for idx, v in vocab.items()
            if v in node2idx_path2vec
        ])
        features = path2vec_features[pathways_idx, :]
        features = features / np.linalg.norm(features, axis=1)[:, np.newaxis]

        # get centroids of bags_labels
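        # Each bag centroid is the mean of its member pathways' row-normalized
        # feature vectors, scaled by arg.alpha.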
        C = np.dot(bags_labels, features) / \
            np.sum(bags_labels, axis=1)[:, np.newaxis]
        C = arg.alpha * C

        # save files
        np.savez(os.path.join(arg.dspath, arg.file_name + "_exp_phi_trim.npz"),
                 phi)
        np.savez(os.path.join(arg.dspath, arg.file_name + "_rho.npz"), rho)
        np.savez(os.path.join(arg.dspath, arg.file_name + "_features.npz"),
                 features)
        np.savez(os.path.join(arg.dspath, arg.file_name + "_bag_centroid.npz"),
                 C)
        save_data(data=bags_labels,
                  file_name=arg.file_name + "_bag_pathway.pkl",
                  save_path=arg.dspath,
                  tag="bags_labels with associated pathways",
                  mode="wb")
        save_data(data=idxvocab,
                  file_name=arg.file_name + "_idxvocab.pkl",
                  save_path=arg.dspath,
                  tag="pathway ids to pathway features ids",
                  mode="wb")
        save_data(data=labels_distr_idx,
                  file_name=arg.file_name + "_labels_distr_idx.pkl",
                  save_path=arg.dspath,
                  tag="bags labels batch_idx with associated pathways",
                  mode="wb")
        print("\t>> Done...")

    if arg.recover_max_bags:
        print("\n{0})- Recover maximum expected bags_labels...".format(steps))
        steps = steps + 1

        # load files
        features = np.load(file=os.path.join(arg.dspath, arg.file_name +
                                             "_features.npz"))
        features = features[features.files[0]]
        C = np.load(file=os.path.join(arg.dspath, arg.file_name +
                                      "_bag_centroid.npz"))
        C = C[C.files[0]]
        bags_labels = load_data(file_name=arg.file_name + "_bag_pathway.pkl",
                                load_path=arg.dspath,
                                tag="bags_labels with associated pathways")
        idxvocab = load_data(file_name=arg.file_name + "_idxvocab.pkl",
                             load_path=arg.dspath,
                             tag="pathway ids to pathway features ids")
        y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag="y")
        y_Bag = np.zeros((y.shape[0], C.shape[0]), dtype=int)

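        # For every sample, mask the pathway features by the sample's own
        # pathways, rebuild its bag-level features, and mark bags whose cosine
        # distance to the global centroids exceeds arg.v_cos.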
        for s_idx, sample in enumerate(y):
            desc = "\t>> Recovering maximum number of bags_labels: {0:.2f}%...".format(
                ((s_idx + 1) / y.shape[0]) * 100)
            if (s_idx + 1) != y.shape[0]:
                print(desc, end="\r")
            if (s_idx + 1) == y.shape[0]:
                print(desc)
            pathways = np.zeros(len(idxvocab), dtype=int)
            for ptwy_idx in sample.rows[0]:
                if ptwy_idx in idxvocab:
                    pathways[list(idxvocab).index(ptwy_idx)] = 1
            pathways = np.diag(pathways)
            sample_features = pathways @ features
            sample_bag_features = np.dot(bags_labels, sample_features) / np.sum(
                bags_labels, axis=1)[:, np.newaxis]
            sample_bag_features = arg.alpha * sample_bag_features
            np.nan_to_num(sample_bag_features, copy=False)
            cos = cosine_distances(C, sample_bag_features) / 2
            cos = np.diag(cos)
            B_idx = np.argwhere(cos > arg.v_cos)
            B_idx = B_idx.reshape((B_idx.shape[0], ))
            y_Bag[s_idx, B_idx] = 1

        # save dataset with maximum bags_labels
        save_data(data=lil_matrix(y_Bag),
                  file_name=arg.file_name + "_B.pkl",
                  save_path=arg.dspath,
                  mode="wb",
                  tag="bags to labels data")
        print("\t>> Done...")

    ##########################################################################################################
    ######################                            TRAIN                             ######################
    ##########################################################################################################

    if arg.train:
        print("\n{0})- Training {1} dataset using reMap model...".format(
            steps, arg.y_name))
        steps = steps + 1

        # load files
        print("\t>> Loading files...")
        y_Bag = load_data(file_name=arg.yB_name, load_path=arg.dspath, tag="B")

        # set randomly bags
        if arg.random_allocation:
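            # Bags assigned to every sample are re-sampled from a
            # Bernoulli(arg.theta_bern); absent bags are then recoded as -1.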
            num_samples = y_Bag.shape[0]
            y_Bag = y_Bag.toarray()
            for bag_idx in np.arange(y_Bag.shape[1]):
                if np.sum(y_Bag[:, bag_idx]) == num_samples:
                    y_Bag[:, bag_idx] = np.random.binomial(
                        1, arg.theta_bern, num_samples)
            y_Bag[y_Bag == 0] = -1
            y_Bag = lil_matrix(y_Bag)
            # save dataset with maximum bags_labels
            save_data(data=lil_matrix(y_Bag),
                      file_name=arg.model_name + "_B.pkl",
                      save_path=arg.dspath,
                      mode="wb",
                      tag="bags to labels data")
        else:
            features = np.load(
                file=os.path.join(arg.dspath, arg.features_name))
            features = features[features.files[0]]
            C = np.load(file=os.path.join(arg.dspath, arg.bag_centroid_name))
            C = C[C.files[0]]
            rho = np.load(file=os.path.join(arg.dspath, arg.rho_name))
            rho = rho[rho.files[0]]
            bags_labels = load_data(file_name=arg.bags_labels,
                                    load_path=arg.dspath,
                                    tag="bags_labels with associated pathways")
            X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
            y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag="y")
            model = reMap(alpha=arg.alpha,
                          binarize_input_feature=arg.binarize_input_feature,
                          fit_intercept=arg.fit_intercept,
                          decision_threshold=arg.decision_threshold,
                          learning_type=arg.learning_type,
                          lr=arg.lr,
                          lr0=arg.lr0,
                          forgetting_rate=arg.forgetting_rate,
                          delay_factor=arg.delay_factor,
                          max_sampling=arg.max_sampling,
                          subsample_input_size=arg.ssample_input_size,
                          subsample_labels_size=arg.ssample_label_size,
                          cost_subsample_size=arg.calc_subsample_size,
                          min_bags=arg.min_bags,
                          max_bags=arg.max_bags,
                          score_strategy=arg.score_strategy,
                          loss_threshold=arg.loss_threshold,
                          early_stop=arg.early_stop,
                          pi=arg.pi,
                          calc_bag_cost=arg.calc_bag_cost,
                          calc_label_cost=arg.calc_label_cost,
                          calc_total_cost=arg.calc_total_cost,
                          varomega=arg.varomega,
                          varrho=arg.varrho,
                          min_negatives_ratio=arg.min_negatives_ratio,
                          lambdas=arg.lambdas,
                          label_bag_sim=arg.label_bag_sim,
                          label_closeness_sim=arg.label_closeness_sim,
                          corr_bag_sim=arg.corr_bag_sim,
                          corr_label_sim=arg.corr_label_sim,
                          corr_input_sim=arg.corr_input_sim,
                          batch=arg.batch,
                          num_epochs=arg.num_epochs,
                          num_jobs=arg.num_jobs,
                          display_interval=arg.display_interval,
                          shuffle=arg.shuffle,
                          random_state=arg.random_state,
                          log_path=arg.logpath)
            model.fit(X=X,
                      y=y,
                      y_Bag=y_Bag,
                      bags_labels=bags_labels,
                      bags_correlation=rho,
                      label_features=features,
                      centroids=C,
                      model_name=arg.model_name,
                      model_path=arg.mdpath,
                      result_path=arg.rspath,
                      snapshot_history=arg.snapshot_history,
                      display_params=display_params)

    ##########################################################################################################
    ######################                           TRANSFORM                          ######################
    ##########################################################################################################

    if arg.transform:
        print("\n{0})- Predicting dataset using a pre-trained reMap model...".
              format(steps))

        # load files
        print("\t>> Loading files...")
        features = np.load(file=os.path.join(arg.dspath, arg.features_name))
        features = features[features.files[0]]
        C = np.load(file=os.path.join(arg.dspath, arg.bag_centroid_name))
        C = C[C.files[0]]
        rho = np.load(file=os.path.join(arg.dspath, arg.rho_name))
        rho = rho[rho.files[0]]
        bags_labels = load_data(file_name=arg.bags_labels,
                                load_path=arg.dspath,
                                tag="bags_labels with associated pathways")

        # load data
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
        y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag="y")
        model = load_data(file_name=arg.model_name + ".pkl",
                          load_path=arg.mdpath,
                          tag="reMap model")

        print("\t>> Predict bags...")
        y_Bag = model.transform(X=X,
                                y=y,
                                bags_labels=bags_labels,
                                bags_correlation=rho,
                                label_features=features,
                                centroids=C,
                                subsample_labels_size=arg.ssample_label_size,
                                max_sampling=arg.max_sampling,
                                snapshot_history=arg.snapshot_history,
                                decision_threshold=arg.decision_threshold,
                                batch_size=arg.batch,
                                num_jobs=arg.num_jobs,
                                file_name=arg.file_name,
                                result_path=arg.rspath)
        # save dataset with maximum bags_labels
        save_data(data=lil_matrix(y_Bag),
                  file_name=arg.file_name + "_B.pkl",
                  save_path=arg.dspath,
                  mode="wb",
                  tag="bags to labels data")
Example #2
def __train(arg):
    '''
    Preprocess graphs, generate random walks, and train the pathway2vec model
    '''

    # Setup the number of operations to employ
    steps = 1

    ##########################################################################################################
    #########################       PREPROCESS GRAPHS AND GENERATE RANDOM WALKS       #########################
    ##########################################################################################################

    # Whether to display parameters at every operation
    display_params = True

    if arg.preprocess_dataset:
        print('\n{0})- Preprocess graph used for training and evaluating...'.
              format(steps))
        steps = steps + 1
        model = MetaPathGraph(
            first_graph_is_directed=arg.first_graph_is_directed,
            first_graph_is_connected=arg.first_graph_is_connected,
            second_graph_is_directed=arg.second_graph_is_directed,
            second_graph_is_connected=arg.second_graph_is_connected,
            third_graph_is_directed=arg.third_graph_is_directed,
            third_graph_is_connected=arg.third_graph_is_connected,
            weighted_within_layers=arg.weighted_within_layers,
            remove_isolates=arg.remove_isolates,
            q=arg.q,
            num_walks=arg.num_walks,
            walk_length=arg.walk_length,
            learning_rate=arg.learning_rate,
            num_jobs=arg.num_jobs,
            display_interval=arg.display_interval)
        print('\t>> Loading graphs and incidence matrices...')
        first_graph = load_data(file_name=arg.first_graph_name,
                                load_path=arg.ospath,
                                tag='first graph')
        second_graph = load_data(file_name=arg.second_graph_name,
                                 load_path=arg.ospath,
                                 tag='second graph')
        first_mapping_file = load_data(
            file_name=arg.first_mapping_file_name,
            load_path=arg.ospath,
            tag='incidence matrix of first to second graphs')
        third_graph = None
        second_mapping_file = None
        if arg.include_third_graph:
            third_graph = load_data(file_name=arg.third_graph_name,
                                    load_path=arg.ospath,
                                    tag='third graph')
            second_mapping_file = load_data(
                file_name=arg.second_mapping_file_name,
                load_path=arg.ospath,
                tag='incidence matrix of second to third graphs')
        model.parse_graph_to_hin(first_graph=first_graph,
                                 second_graph=second_graph,
                                 third_graph=third_graph,
                                 first_mapping_file=first_mapping_file,
                                 second_mapping_file=second_mapping_file,
                                 hin_file=arg.hin_file,
                                 ospath=arg.ospath,
                                 display_params=display_params)
        display_params = False

    if arg.extract_instance:
        print('\n{0})- Extract walks...'.format(steps))
        steps = steps + 1
        print('\t>> Loading the heterogeneous information network...')
        hin = load_data(file_name=arg.hin_file,
                        load_path=arg.ospath,
                        tag='heterogeneous information network')
        model = MetaPathGraph(
            first_graph_is_directed=arg.first_graph_is_directed,
            first_graph_is_connected=arg.first_graph_is_connected,
            second_graph_is_directed=arg.second_graph_is_directed,
            second_graph_is_connected=arg.second_graph_is_connected,
            third_graph_is_directed=arg.third_graph_is_directed,
            third_graph_is_connected=arg.third_graph_is_connected,
            weighted_within_layers=arg.weighted_within_layers,
            remove_isolates=arg.remove_isolates,
            q=arg.q,
            num_walks=arg.num_walks,
            walk_length=arg.walk_length,
            learning_rate=arg.learning_rate,
            num_jobs=arg.num_jobs,
            display_interval=arg.display_interval)
        model.generate_walks(constraint_type=arg.constraint_type,
                             just_type=arg.just_type,
                             just_memory_size=arg.just_memory_size,
                             use_metapath_scheme=arg.use_metapath_scheme,
                             metapath_scheme=arg.metapath_scheme,
                             burn_in_phase=arg.burn_in_phase,
                             burn_in_input_size=arg.burn_in_input_size,
                             hin=hin,
                             save_file_name=arg.file_name,
                             ospath=arg.ospath,
                             dspath=arg.dspath,
                             display_params=display_params)

    ##########################################################################################################
    ######################                  PATHWAY2VEC MODEL                  ###############################
    ##########################################################################################################

    # Whether to display parameters at every operation
    display_params = True

    if arg.train:
        print('\n{0})- Training dataset using pathway2vec model...'.format(
            steps))
        print(
            '\t>> Loading the heterogeneous information network and training samples...'
        )
        hin = load_data(file_name=arg.hin_file,
                        load_path=arg.ospath,
                        tag='heterogeneous information network')
        X = load_data(file_name=arg.file_name,
                      load_path=arg.dspath,
                      mode='r',
                      tag='dataset')
        X = [sample.strip().split('\t') for sample in X]
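        # index2type maps every node index back to its node type by reversing
        # hin.type2index.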
        index2type = dict((val, key)
                          for key, list_val in hin.type2index.items()
                          for val in list_val)
        model = path2vec(
            node_size=hin.number_of_nodes(),
            window_size=arg.window_size,
            num_skips=arg.num_skips,
            num_negative_samples=arg.negative_samples,
            embedding_dimension=arg.embedding_dim,
            use_truncated_normal_weight=arg.use_truncated_normal_weight,
            use_truncated_normal_emb=arg.use_truncated_normal_emb,
            constraint_type=arg.constraint_type,
            learning_rate=arg.learning_rate,
            num_models=arg.max_keep_model,
            subsample_size=arg.subsample_size,
            batch=arg.batch,
            num_epochs=arg.num_epochs,
            max_inner_iter=arg.max_inner_iter,
            num_jobs=arg.num_jobs,
            shuffle=arg.shuffle,
            display_interval=arg.display_interval,
            random_state=arg.random_state)
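        # Per-node sampling probabilities, taken from each node's 'weight'
        # attribute in the HIN.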
        node_probability = dict((node[1]['mapped_idx'], node[1]['weight'].data)
                                for node in hin.nodes(data=True))
        model.fit(X=X,
                  node_probability=node_probability,
                  index2type=index2type,
                  type2index=hin.type2index,
                  type2prob=hin.type2prob,
                  fit_by_word2vec=arg.fit_by_word2vec,
                  model_name=arg.model_name,
                  model_path=arg.mdpath,
                  result_path=arg.rspath,
                  display_params=display_params)
Example #3
File: train.py Project: hallamlab/leADS
def __train(arg):
    # Setup the number of operations to employ
    steps = 1
    # Whether to display parameters at every operation
    display_params = True

    ##########################################################################################################
    ######################                  PREPROCESSING DATASET                       ######################
    ##########################################################################################################

    if arg.preprocess_dataset:
        print('\n{0})- Preprocess dataset...'.format(steps))
        steps = steps + 1
        print('\t>> Loading files...')
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="instances")
        X = X[:, :arg.cutting_point]

        # load a biocyc file
        data_object = load_data(file_name=arg.object_name, load_path=arg.ospath, tag='the biocyc object')
        ec_dict = data_object["ec_id"]
        pathway_dict = data_object["pathway_id"]
        del data_object

        pathway_dict = dict((idx, id) for id, idx in pathway_dict.items())
        ec_dict = dict((idx, id) for id, idx in ec_dict.items())
        labels_components = load_data(file_name=arg.pathway2ec_name, load_path=arg.ospath, tag='M')
        print('\t>> Loading label to component mapping file object...')
        pathway2ec_idx = load_data(file_name=arg.pathway2ec_idx_name, load_path=arg.ospath, print_tag=False)
        pathway2ec_idx = list(pathway2ec_idx)
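        # Re-key ec_dict so that EC ids follow the column order given by the
        # pathway2ec index mapping.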
        tmp = list(ec_dict.keys())
        ec_dict = dict((idx, ec_dict[tmp.index(ec)]) for idx, ec in enumerate(pathway2ec_idx))

        # load path2vec features
        path2vec_features = np.load(file=os.path.join(arg.ospath, arg.features_name))

        # load a hin file
        hin = load_data(file_name=arg.hin_name, load_path=arg.ospath, tag='heterogeneous information network')
        # get pathway2ec mapping
        node2idx_pathway2ec = [node[0] for node in hin.nodes(data=True)]
        del hin

        __build_features(X=X, pathwat_dict=pathway_dict, ec_dict=ec_dict, labels_components=labels_components,
                         node2idx_pathway2ec=node2idx_pathway2ec,
                         path2vec_features=path2vec_features, file_name=arg.file_name, dspath=arg.dspath,
                         batch_size=arg.batch, num_jobs=arg.num_jobs)

    ##########################################################################################################
    ######################                            TRAIN                             ######################
    ##########################################################################################################

    if arg.train:
        print(
            '\n{0})- Training {1} dataset using leADS model...'.format(steps, arg.X_name))
        steps = steps + 1

        # load files
        print('\t>> Loading files...')
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
        y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag="y")
        y_Bags = None
        bags_labels = None
        label_features = None
        centroids = None

        if not arg.train_labels:
            y_Bags = load_data(file_name=arg.yB_name, load_path=arg.dspath, tag="B")
            bags_labels = load_data(file_name=arg.bags_labels, load_path=arg.ospath,
                                    tag="bags_labels with associated pathways")
            label_features = load_data(file_name=arg.features_name, load_path=arg.ospath, tag="features")
            centroids = np.load(file=os.path.join(arg.ospath, arg.centroids))
            centroids = centroids[centroids.files[0]]

        A = None
        if arg.fuse_weight:
            A = load_item_features(file_name=os.path.join(arg.ospath, arg.similarity_name), use_components=False)
        if arg.train_selected_sample:
            if os.path.exists(os.path.join(arg.rspath, arg.samples_ids)):
                sample_ids = load_data(file_name=arg.samples_ids, load_path=arg.rspath, tag="selected samples")
                sample_ids = np.array(sample_ids)
                X = X[sample_ids, :]
                y = y[sample_ids, :]
                if not arg.train_labels:
                    y_Bags = y_Bags[sample_ids, :]
            else:
                print('\t\tNo sample ids file was provided...')

        model = leADS(alpha=arg.alpha, binarize_input_feature=arg.binarize_input_feature,
                      normalize_input_feature=arg.normalize_input_feature,
                      use_external_features=arg.use_external_features,
                      cutting_point=arg.cutting_point, fit_intercept=arg.fit_intercept,
                      decision_threshold=arg.decision_threshold, subsample_input_size=arg.ssample_input_size,
                      subsample_labels_size=arg.ssample_label_size, calc_ads=arg.calc_ads,
                      acquisition_type=arg.acquisition_type, top_k=arg.top_k, ads_percent=arg.ads_percent,
                      advanced_subsampling=arg.advanced_subsampling, tol_labels_iter=arg.tol_labels_iter,
                      cost_subsample_size=arg.calc_subsample_size, calc_label_cost=arg.calc_label_cost,
                      calc_bag_cost=arg.calc_bag_cost, calc_total_cost=arg.calc_total_cost,
                      label_uncertainty_type=arg.label_uncertainty_type, label_bag_sim=arg.label_bag_sim,
                      label_closeness_sim=arg.label_closeness_sim, corr_bag_sim=arg.corr_bag_sim,
                      corr_label_sim=arg.corr_label_sim, corr_input_sim=arg.corr_input_sim, penalty=arg.penalty,
                      alpha_elastic=arg.alpha_elastic, l1_ratio=arg.l1_ratio, sigma=arg.sigma,
                      fuse_weight=arg.fuse_weight, lambdas=arg.lambdas, loss_threshold=arg.loss_threshold,
                      early_stop=arg.early_stop, learning_type=arg.learning_type, lr=arg.lr, lr0=arg.lr0,
                      delay_factor=arg.delay_factor, forgetting_rate=arg.forgetting_rate, num_models=arg.num_models,
                      batch=arg.batch, max_inner_iter=arg.max_inner_iter, num_epochs=arg.num_epochs,
                      num_jobs=arg.num_jobs, display_interval=arg.display_interval, shuffle=arg.shuffle,
                      random_state=arg.random_state, log_path=arg.logpath)
        model.fit(X=X, y=y, y_Bag=y_Bags, bags_labels=bags_labels, label_features=label_features, centroids=centroids,
                  A=A, model_name=arg.model_name, model_path=arg.mdpath, result_path=arg.rspath,
                  display_params=display_params)

    ##########################################################################################################
    ######################                           EVALUATE                           ######################
    ##########################################################################################################

    if arg.evaluate:
        print('\n{0})- Evaluating leADS model...'.format(steps))
        steps = steps + 1

        # load files
        print('\t>> Loading files...')
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
        bags_labels = None
        label_features = None
        centroids = None
        if not arg.pred_bags:
            y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag="y")
        if arg.pred_bags:
            y_Bags = load_data(file_name=arg.yB_name, load_path=arg.dspath, tag="B")

        # load model
        model = load_data(file_name=arg.model_name + '.pkl', load_path=arg.mdpath, tag='leADS')

        if model.learn_bags:
            bags_labels = load_data(file_name=arg.bags_labels, load_path=arg.dspath,
                                    tag="bags_labels with associated pathways")
        if model.label_uncertainty_type == "dependent":
            label_features = load_data(file_name=arg.features_name, load_path=arg.dspath, tag="features")
            centroids = np.load(file=os.path.join(arg.dspath, arg.centroids))
            centroids = centroids[centroids.files[0]]

        # labels prediction score
        y_pred_Bags, y_pred = model.predict(X=X, bags_labels=bags_labels, label_features=label_features,
                                            centroids=centroids,
                                            estimate_prob=arg.estimate_prob, pred_bags=arg.pred_bags,
                                            pred_labels=arg.pred_labels,
                                            build_up=arg.build_up, pref_rank=arg.pref_rank, top_k_rank=arg.top_k_rank,
                                            subsample_labels_size=arg.ssample_label_size, soft_voting=arg.soft_voting,
                                            apply_t_criterion=arg.apply_tcriterion, adaptive_beta=arg.adaptive_beta,
                                            decision_threshold=arg.decision_threshold, batch_size=arg.batch,
                                            num_jobs=arg.num_jobs)

        file_name = arg.file_name + '_scores.txt'
        if arg.pred_bags:
            score(y_true=y_Bags.toarray(), y_pred=y_pred_Bags.toarray(), item_lst=['biocyc_bags'],
                  six_db=False, top_k=arg.top_k, mode='a', file_name=file_name, save_path=arg.rspath)
        if arg.pred_labels:
            if arg.dsname == 'golden':
                score(y_true=y.toarray(), y_pred=y_pred.toarray(), item_lst=[arg.dsname], six_db=True,
                      top_k=arg.top_k, mode='a', file_name=file_name, save_path=arg.rspath)
            else:
                score(y_true=y.toarray(), y_pred=y_pred.toarray(), item_lst=[arg.dsname], six_db=False,
                      top_k=arg.top_k, mode='a', file_name=file_name, save_path=arg.rspath)

    ##########################################################################################################
    ######################                            PREDICT                           ######################
    ##########################################################################################################

    if arg.predict:
        print('\n{0})- Predicting dataset using a pre-trained leADS model...'.format(steps))
        if arg.pathway_report or arg.extract_pf:
            print('\t>> Loading biocyc object...')
            # load a biocyc file
            data_object = load_data(file_name=arg.object_name, load_path=arg.ospath, tag='the biocyc object',
                                    print_tag=False)
            pathway_dict = data_object["pathway_id"]
            pathway_common_names = dict((pidx, data_object['processed_kb']['metacyc'][5][pid][0][1])
                                        for pid, pidx in pathway_dict.items()
                                        if pid in data_object['processed_kb']['metacyc'][5])
            ec_dict = data_object['ec_id']
            del data_object
            pathway_dict = dict((idx, id) for id, idx in pathway_dict.items())
            ec_dict = dict((idx, id) for id, idx in ec_dict.items())
            labels_components = load_data(file_name=arg.pathway2ec_name, load_path=arg.ospath, tag='M')
            print('\t>> Loading label to component mapping file object...')
            pathway2ec_idx = load_data(file_name=arg.pathway2ec_idx_name, load_path=arg.ospath, print_tag=False)
            pathway2ec_idx = list(pathway2ec_idx)
            tmp = list(ec_dict.keys())
            ec_dict = dict((idx, ec_dict[tmp.index(ec)]) for idx, ec in enumerate(pathway2ec_idx))
            if arg.extract_pf:
                X, sample_ids = parse_files(ec_dict=ec_dict, ds_folder=arg.dsfolder, dspath=arg.dspath,
                                            rspath=arg.rspath, num_jobs=arg.num_jobs)
                print('\t>> Storing X and sample_ids...')
                save_data(data=X, file_name=arg.file_name + '_X.pkl', save_path=arg.dspath,
                          tag='the pf dataset (X)', mode='w+b', print_tag=False)
                save_data(data=sample_ids, file_name=arg.file_name + '_ids.pkl', save_path=arg.dspath,
                          tag='samples ids', mode='w+b', print_tag=False)
                print('\t>> Loading heterogeneous information network file...')
                hin = load_data(file_name=arg.hin_name, load_path=arg.ospath,
                                tag='heterogeneous information network',
                                print_tag=False)
                # get pathway2ec mapping
                node2idx_pathway2ec = [node[0] for node in hin.nodes(data=True)]
                del hin
                print('\t>> Loading path2vec_features file...')
                path2vec_features = np.load(file=os.path.join(arg.ospath, arg.features_name))
                __build_features(X=X, pathwat_dict=pathway_dict, ec_dict=ec_dict,
                                 labels_components=labels_components,
                                 node2idx_pathway2ec=node2idx_pathway2ec,
                                 path2vec_features=path2vec_features,
                                 file_name=arg.file_name, dspath=arg.dspath,
                                 batch_size=arg.batch, num_jobs=arg.num_jobs)

        # load files
        print('\t>> Loading necessary files...')
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
        tmp = lil_matrix.copy(X)
        bags_labels = None
        label_features = None
        centroids = None

        # load model
        model = load_data(file_name=arg.model_name + '.pkl', load_path=arg.mdpath, tag='leADS')

        if model.learn_bags:
            bags_labels = load_data(file_name=arg.bags_labels, load_path=arg.ospath,
                                    tag="bags_labels with associated pathways")
        if model.label_uncertainty_type == "dependent":
            label_features = load_data(file_name=arg.features_name, load_path=arg.ospath, tag="features")
            centroids = np.load(file=os.path.join(arg.ospath, arg.centroids))
            centroids = centroids[centroids.files[0]]

        # predict
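        # Two inference passes over the same inputs: the first returns hard
        # bag/label assignments, the second (estimate_prob=True) returns
        # probability scores used when synthesizing the pathway report.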
        y_pred_Bags, y_pred = model.predict(X=X, bags_labels=bags_labels, label_features=label_features,
                                            centroids=centroids,
                                            estimate_prob=False, pred_bags=arg.pred_bags, pred_labels=arg.pred_labels,
                                            build_up=arg.build_up, pref_rank=arg.pref_rank, top_k_rank=arg.top_k_rank,
                                            subsample_labels_size=arg.ssample_label_size, soft_voting=arg.soft_voting,
                                            apply_t_criterion=arg.apply_tcriterion, adaptive_beta=arg.adaptive_beta,
                                            decision_threshold=arg.decision_threshold, batch_size=arg.batch,
                                            num_jobs=arg.num_jobs)
        # labels prediction score
        y_pred_Bags_score, y_pred_score = model.predict(X=X, bags_labels=bags_labels, label_features=label_features,
                                                        centroids=centroids, estimate_prob=True,
                                                        pred_bags=arg.pred_bags,
                                                        pred_labels=arg.pred_labels, build_up=arg.build_up,
                                                        pref_rank=arg.pref_rank, top_k_rank=arg.top_k_rank,
                                                        subsample_labels_size=arg.ssample_label_size,
                                                        soft_voting=arg.soft_voting,
                                                        apply_t_criterion=arg.apply_tcriterion,
                                                        adaptive_beta=arg.adaptive_beta,
                                                        decision_threshold=arg.decision_threshold,
                                                        batch_size=arg.batch, num_jobs=arg.num_jobs)
        if arg.pathway_report:
            print('\t>> Synthesizing pathway reports...')
            X = tmp
            sample_ids = np.arange(X.shape[0])
            if arg.extract_pf:
                sample_ids = load_data(file_name=arg.file_name + "_ids.pkl", load_path=arg.dspath, tag="samples ids")
            else:
                if arg.samples_ids is not None:
                    if arg.samples_ids in os.listdir(arg.dspath):
                        sample_ids = load_data(file_name=arg.samples_ids, load_path=arg.dspath, tag="samples ids")
            synthesize_report(X=X[:, :arg.cutting_point], sample_ids=sample_ids, y_pred=y_pred, y_dict_ids=pathway_dict,
                              y_common_name=pathway_common_names, component_dict=ec_dict,
                              labels_components=labels_components, y_pred_score=y_pred_score, batch_size=arg.batch,
                              num_jobs=arg.num_jobs, rspath=arg.rspath, dspath=arg.dspath, file_name=arg.file_name)
        else:
            print('\t>> Storing predictions (label index) to: {0:s}'.format(arg.file_name + '_y_leads.pkl'))
            save_data(data=y_pred, file_name=arg.file_name + "_y_leads.pkl", save_path=arg.dspath,
                      mode="wb", print_tag=False)
            if arg.pred_bags:
                print('\t>> Storing predictions (bag index) to: {0:s}'.format(
                    arg.file_name + '_yBags_leads.pkl'))
                save_data(data=y_pred_Bags, file_name=arg.file_name + "_yBags_leads.pkl", save_path=arg.dspath,
                          mode="wb", print_tag=False)
Example #4
def __train(arg):
    # Setup the number of operations to employ
    steps = 1
    # Whether to display parameters at every operation
    display_params = True

    if arg.preprocess_dataset:
        print('\n{0})- Preprocessing dataset...'.format(steps))
        steps = steps + 1

        print('\t>> Loading files...')
        # load a biocyc file
        data_object = load_data(file_name=arg.object_name, load_path=arg.ospath, tag='the biocyc object')
        # extract pathway ids
        pathway_dict = data_object["pathway_id"]
        ec_dict = data_object["ec_id"]
        del data_object

        # load a hin file
        hin = load_data(file_name=arg.hin_name, load_path=arg.ospath,
                        tag='heterogeneous information network')
        # get path2vec mapping
        node2idx_path2vec = dict((node[0], node[1]['mapped_idx'])
                                 for node in hin.nodes(data=True))
        # get pathway2ec mapping
        node2idx_pathway2ec = [node[0] for node in hin.nodes(data=True)]
        Adj = nx.adj_matrix(G=hin)
        del hin

        # load pathway2ec mapping matrix
        pathway2ec_idx = load_data(file_name=arg.pathway2ec_idx_name, load_path=arg.ospath)
        path2vec_features = np.load(file=os.path.join(arg.mdpath, arg.features_name))

        # extracting pathway and ec features
        labels_components = load_data(file_name=arg.pathway2ec_name, load_path=arg.ospath, tag='M')
        path2vec_features = path2vec_features[path2vec_features.files[0]]
        pathways_idx = np.array([node2idx_path2vec[v] for v, idx in pathway_dict.items()
                                 if v in node2idx_path2vec])
        P = path2vec_features[pathways_idx, :]
        tmp = [idx for v, idx in ec_dict.items() if v in node2idx_pathway2ec]
        ec_idx = np.array([idx for idx in tmp if len(np.where(pathway2ec_idx == idx)[0]) > 0])
        E = path2vec_features[ec_idx, :]

        # constrain the feature space to [0, 1] to avoid negative values
        min_rho = np.min(P)
        max_rho = np.max(P)
        P = P - min_rho
        P = P / (max_rho - min_rho)
        P = P / np.linalg.norm(P, axis=1)[:, np.newaxis]
        min_rho = np.min(E)
        max_rho = np.max(E)
        E = E - min_rho
        E = E / (max_rho - min_rho)
        E = E / np.linalg.norm(E, axis=1)[:, np.newaxis]

        # building A and B matrices
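        # A and B are the HIN adjacency restricted to pathway nodes and EC
        # nodes respectively, with zeroed diagonals and row-normalized entries
        # (NaNs from empty rows replaced by 0).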
        lil_matrix.setdiag(Adj, 0)
        A = Adj[pathways_idx[:, None], pathways_idx]
        A = A / A.sum(1)
        A = np.nan_to_num(A)
        B = Adj[ec_idx[:, None], ec_idx]
        B = B / B.sum(1)
        B = np.nan_to_num(B)

        ## train size
        if arg.ssample_input_size < 1:
            # add white noise to M
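            # (zero out a random arg.ssample_input_size fraction of M's rows)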
            train_size = labels_components.shape[0] * arg.ssample_input_size
            idx = np.random.choice(a=np.arange(labels_components.shape[0]), size=int(train_size), replace=False)
            labels_components = labels_components.toarray()
            labels_components[idx] = np.zeros((idx.shape[0], labels_components.shape[1]))
        if arg.white_links:
            if arg.ssample_input_size < 1:
                # add white noise to A
                train_size = A.shape[0] * arg.ssample_input_size
                idx = np.random.choice(a=np.arange(A.shape[0]), size=int(train_size), replace=False)
                A = lil_matrix(A).toarray()
                tmp = np.zeros((idx.shape[0], A.shape[0]))
                A[idx] = tmp
                A[:, idx] = tmp.T
                # add white noise to B
                train_size = B.shape[0] * arg.ssample_input_size
                idx = np.random.choice(a=np.arange(B.shape[0]), size=int(train_size), replace=False)
                B = lil_matrix(B).toarray()
                tmp = np.zeros((idx.shape[0], B.shape[0]))
                B[idx] = tmp
                B[:, idx] = tmp.T

        # save files
        print('\t>> Saving files...')
        save_data(data=lil_matrix(labels_components), file_name=arg.M_name, save_path=arg.dspath, tag="M", mode="wb")
        save_data(data=lil_matrix(P), file_name=arg.P_name, save_path=arg.dspath, tag="P", mode="wb")
        save_data(data=lil_matrix(E), file_name=arg.E_name, save_path=arg.dspath, tag="E", mode="wb")
        save_data(data=lil_matrix(A), file_name=arg.A_name, save_path=arg.dspath, tag="A", mode="wb")
        save_data(data=lil_matrix(B), file_name=arg.B_name, save_path=arg.dspath, tag="B", mode="wb")
        print('\t>> Done...')

    ##########################################################################################################
    ######################                     TRAIN USING triUMPF                      ######################
    ##########################################################################################################

    if arg.train:
        print('\n{0})- Training {1} dataset using triUMPF model...'.format(steps, arg.y_name))
        steps = steps + 1

        # load files
        print('\t>> Loading files...')
        labels_components, W, H, P, E, A, B, X, y = None, None, None, None, None, None, None, None, None

        if arg.no_decomposition:
            W = load_data(file_name=arg.W_name, load_path=arg.mdpath, tag='W')
            H = load_data(file_name=arg.H_name, load_path=arg.mdpath, tag='H')
        else:
            labels_components = load_data(file_name=arg.M_name, load_path=arg.dspath, tag='M')
        if arg.fit_features:
            P = load_data(file_name=arg.P_name, load_path=arg.dspath, tag='P')
            E = load_data(file_name=arg.E_name, load_path=arg.dspath, tag='E')
        if arg.fit_comm:
            if not arg.fit_features:
                P = load_data(file_name=arg.P_name, load_path=arg.dspath, tag='P')
                E = load_data(file_name=arg.E_name, load_path=arg.dspath, tag='E')
            X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag='X')
            y = load_data(file_name=arg.y_name, load_path=arg.dspath, tag='y')
            A = load_data(file_name=arg.A_name, load_path=arg.dspath, tag='A')
            B = load_data(file_name=arg.B_name, load_path=arg.dspath, tag='B')

        model = triUMPF(num_components=arg.num_components, num_communities_p=arg.num_communities_p,
                        num_communities_e=arg.num_communities_e, proxy_order_p=arg.proxy_order_p,
                        proxy_order_e=arg.proxy_order_e, mu_omega=arg.mu_omega, mu_gamma=arg.mu_gamma,
                        fit_features=arg.fit_features, fit_comm=arg.fit_comm, fit_pure_comm=arg.fit_pure_comm,
                        normalize_input_feature=arg.normalize_input_feature,
                        binarize_input_feature=arg.binarize_input_feature,
                        use_external_features=arg.use_external_features, cutting_point=arg.cutting_point,
                        fit_intercept=arg.fit_intercept, alpha=arg.alpha, beta=arg.beta, rho=arg.rho,
                        lambdas=arg.lambdas, eps=arg.eps, early_stop=arg.early_stop, penalty=arg.penalty,
                        alpha_elastic=arg.alpha_elastic, l1_ratio=arg.l1_ratio, loss_threshold=arg.loss_threshold,
                        decision_threshold=arg.decision_threshold, subsample_input_size=arg.ssample_input_size,
                        subsample_labels_size=arg.ssample_label_size, learning_type=arg.learning_type, lr=arg.lr,
                        lr0=arg.lr0, delay_factor=arg.delay_factor, forgetting_rate=arg.forgetting_rate,
                        batch=arg.batch, max_inner_iter=arg.max_inner_iter, num_epochs=arg.num_epochs,
                        num_jobs=arg.num_jobs, display_interval=arg.display_interval, shuffle=arg.shuffle,
                        random_state=arg.random_state, log_path=arg.logpath)
        model.fit(M=labels_components, W=W, H=H, X=X, y=y, P=P, E=E, A=A, B=B, model_name=arg.model_name,
                  model_path=arg.mdpath, result_path=arg.rspath, display_params=display_params)

    ##########################################################################################################
    ######################                    PREDICT USING triUMPF                     ######################
    ##########################################################################################################

    if arg.predict:
        print('\n{0})- Predicting using a pre-trained triUMPF model...'.format(steps))
        if arg.pathway_report:
            print('\t>> Loading biocyc object...')
            # load a biocyc file
            data_object = load_data(file_name=arg.object_name, load_path=arg.ospath, tag='the biocyc object',
                                    print_tag=False)
            pathway_dict = data_object["pathway_id"]
            pathway_common_names = dict((pidx, data_object['processed_kb']['metacyc'][5][pid][0][1])
                                        for pid, pidx in pathway_dict.items()
                                        if pid in data_object['processed_kb']['metacyc'][5])
            ec_dict = data_object['ec_id']
            del data_object
            pathway_dict = dict((idx, id) for id, idx in pathway_dict.items())
            ec_dict = dict((idx, id) for id, idx in ec_dict.items())
            labels_components = load_data(file_name=arg.pathway2ec_name, load_path=arg.ospath, tag='M')
            print('\t>> Loading label to component mapping file object...')
            pathway2ec_idx = load_data(file_name=arg.pathway2ec_idx_name, load_path=arg.ospath, print_tag=False)
            pathway2ec_idx = list(pathway2ec_idx)
            tmp = list(ec_dict.keys())
            ec_dict = dict((idx, ec_dict[tmp.index(ec)]) for idx, ec in enumerate(pathway2ec_idx))
            if arg.extract_pf:
                X, sample_ids = parse_files(ec_dict=ec_dict, input_folder=arg.dsfolder, rsfolder=arg.rsfolder,
                                            rspath=arg.rspath, num_jobs=arg.num_jobs)
                print('\t>> Storing X and sample_ids...')
                save_data(data=X, file_name=arg.file_name + '_X.pkl', save_path=arg.dspath,
                          tag='the pf dataset (X)', mode='w+b', print_tag=False)
                save_data(data=sample_ids, file_name=arg.file_name + '_ids.pkl', save_path=arg.dspath,
                          tag='samples ids', mode='w+b', print_tag=False)
                if arg.build_features:
                    # load a hin file
                    print('\t>> Loading heterogeneous information network file...')
                    hin = load_data(file_name=arg.hin_name, load_path=arg.ospath,
                                    tag='heterogeneous information network',
                                    print_tag=False)
                    # get pathway2ec mapping
                    node2idx_pathway2ec = [node[0] for node in hin.nodes(data=True)]
                    del hin
                    print('\t>> Loading path2vec_features file...')
                    path2vec_features = np.load(file=os.path.join(arg.mdpath, arg.features_name))
                    __build_features(X=X, pathwat_dict=pathway_dict, ec_dict=ec_dict,
                                     labels_components=labels_components,
                                     node2idx_pathway2ec=node2idx_pathway2ec,
                                     path2vec_features=path2vec_features,
                                     file_name=arg.file_name, dspath=arg.dspath,
                                     batch_size=arg.batch, num_jobs=arg.num_jobs)
        # load files
        print('\t>> Loading necessary files...')
        X = load_data(file_name=arg.X_name, load_path=arg.dspath, tag="X")
        sample_ids = np.arange(X.shape[0])
        if arg.samples_ids in os.listdir(arg.dspath):
            sample_ids = load_data(file_name=arg.samples_ids, load_path=arg.dspath, tag="samples ids")

        # load model
        model = load_data(file_name=arg.model_name + '.pkl', load_path=arg.mdpath, tag='triUMPF model')

        # predict
        y_pred = model.predict(X=X.toarray(), estimate_prob=False, apply_t_criterion=arg.apply_tcriterion,
                               adaptive_beta=arg.adaptive_beta, decision_threshold=arg.decision_threshold,
                               top_k=arg.top_k, batch_size=arg.batch, num_jobs=arg.num_jobs)
        # labels prediction score
        y_pred_score = model.predict(X=X.toarray(), estimate_prob=True, apply_t_criterion=arg.apply_tcriterion,
                                     adaptive_beta=arg.adaptive_beta, decision_threshold=arg.decision_threshold,
                                     top_k=arg.top_k, batch_size=arg.batch, num_jobs=arg.num_jobs)

        if arg.pathway_report:
            print('\t>> Synthesizing pathway reports...')
            synthesize_report(X=X[:, :arg.cutting_point], sample_ids=sample_ids,
                              y_pred=y_pred, y_dict_ids=pathway_dict, y_common_name=pathway_common_names,
                              component_dict=ec_dict, labels_components=labels_components, y_pred_score=y_pred_score,
                              batch_size=arg.batch, num_jobs=arg.num_jobs, rsfolder=arg.rsfolder, rspath=arg.rspath,
                              dspath=arg.dspath, file_name=arg.file_name + '_triumpf')
        else:
            print('\t>> Storing predictions (label index) to: {0:s}'.format(arg.file_name + '_triumpf_y.pkl'))
            save_data(data=y_pred, file_name=arg.file_name + "_triumpf_y.pkl", save_path=arg.dspath,
                      mode="wb", print_tag=False)
Example #5
File: train.py Project: hallamlab/cbt
def __train(arg):
    # Setup the number of operations to employ
    steps = 1
    # Whether to display parameters at every operation
    display_params = True

    ##########################################################################################################
    ######################                            TRAIN                             ######################
    ##########################################################################################################

    if arg.train:
        print('\t>> Loading files...')
        dictionary = load_data(file_name=arg.vocab_name,
                               load_path=arg.dspath,
                               tag="dictionary",
                               print_tag=False)
        X = load_data(file_name=arg.X_name,
                      load_path=arg.dspath,
                      tag="X",
                      print_tag=False)
        M = None
        features = None
        if arg.use_supplement:
            M = load_data(file_name=arg.M_name,
                          load_path=arg.dspath,
                          tag="supplementary components")
            M = M.toarray()
        if arg.use_features:
            features = load_data(file_name=arg.features_name,
                                 load_path=arg.dspath,
                                 tag="features")

        if arg.soap:
            print('\n{0})- Training using SOAP model...'.format(steps))
            steps = steps + 1
            model_name = 'soap_' + arg.model_name
            model = SOAP(vocab=dictionary.token2id,
                         num_components=arg.num_components,
                         alpha_mu=arg.alpha_mu,
                         alpha_sigma=arg.alpha_sigma,
                         alpha_phi=arg.alpha_phi,
                         gamma=arg.gamma,
                         kappa=arg.kappa,
                         xi=arg.xi,
                         varpi=arg.varpi,
                         optimization_method=arg.opt_method,
                         cost_threshold=arg.cost_threshold,
                         component_threshold=arg.component_threshold,
                         max_sampling=arg.max_sampling,
                         subsample_input_size=arg.subsample_input_size,
                         batch=arg.batch,
                         num_epochs=arg.num_epochs,
                         max_inner_iter=arg.max_inner_iter,
                         top_k=arg.top_k,
                         collapse2ctm=arg.collapse2ctm,
                         use_features=arg.use_features,
                         num_jobs=arg.num_jobs,
                         display_interval=arg.display_interval,
                         shuffle=arg.shuffle,
                         forgetting_rate=arg.forgetting_rate,
                         delay_factor=arg.delay_factor,
                         random_state=arg.random_state,
                         log_path=arg.logpath)
            model.fit(X=X,
                      M=M,
                      features=features,
                      model_name=model_name,
                      model_path=arg.mdpath,
                      result_path=arg.rspath,
                      display_params=display_params)

        if arg.spreat:
            print('\n{0})- Training using SPREAT model...'.format(steps))
            steps = steps + 1
            model_name = 'spreat_' + arg.model_name
            model = SPREAT(vocab=dictionary.token2id,
                           num_components=arg.num_components,
                           alpha_mu=arg.alpha_mu,
                           alpha_sigma=arg.alpha_sigma,
                           alpha_phi=arg.alpha_phi,
                           gamma=arg.gamma,
                           kappa=arg.kappa,
                           xi=arg.xi,
                           varpi=arg.varpi,
                           optimization_method=arg.opt_method,
                           cost_threshold=arg.cost_threshold,
                           component_threshold=arg.component_threshold,
                           max_sampling=arg.max_sampling,
                           subsample_input_size=arg.subsample_input_size,
                           batch=arg.batch,
                           num_epochs=arg.num_epochs,
                           max_inner_iter=arg.max_inner_iter,
                           top_k=arg.top_k,
                           collapse2ctm=arg.collapse2ctm,
                           use_features=arg.use_features,
                           num_jobs=arg.num_jobs,
                           display_interval=arg.display_interval,
                           shuffle=arg.shuffle,
                           forgetting_rate=arg.forgetting_rate,
                           delay_factor=arg.delay_factor,
                           random_state=arg.random_state,
                           log_path=arg.logpath)
            model.fit(X=X,
                      M=M,
                      features=features,
                      model_name=model_name,
                      model_path=arg.mdpath,
                      result_path=arg.rspath,
                      display_params=display_params)

        if arg.ctm:
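            # The CTM baseline is fit on X alone; it uses neither the supplementary
            # matrix M nor the pathway features.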
            print('\n{0})- Training using CTM model...'.format(steps))
            steps = steps + 1
            model_name = 'ctm_' + arg.model_name
            model = CTM(vocab=dictionary.token2id,
                        num_components=arg.num_components,
                        alpha_mu=arg.alpha_mu,
                        alpha_sigma=arg.alpha_sigma,
                        alpha_beta=arg.alpha_phi,
                        optimization_method=arg.opt_method,
                        cost_threshold=arg.cost_threshold,
                        component_threshold=arg.component_threshold,
                        subsample_input_size=arg.subsample_input_size,
                        batch=arg.batch,
                        num_epochs=arg.num_epochs,
                        max_inner_iter=arg.max_inner_iter,
                        num_jobs=arg.num_jobs,
                        display_interval=arg.display_interval,
                        shuffle=arg.shuffle,
                        forgetting_rate=arg.forgetting_rate,
                        delay_factor=arg.delay_factor,
                        random_state=arg.random_state,
                        log_path=arg.logpath)
            model.fit(X=X,
                      model_name=model_name,
                      model_path=arg.mdpath,
                      result_path=arg.rspath,
                      display_params=display_params)

        if arg.lda:
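            # The sklearn LDA baseline is trained manually below with a mini-batch
            # loop over partial_fit rather than a single fit() call.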
            print(
                '\n{0})- Training using LDA (sklearn) model...'.format(steps))
            steps = steps + 1
            model_name = 'sklda_' + arg.model_name
            model = skLDA(n_components=arg.num_components,
                          learning_method='batch',
                          learning_decay=arg.delay_factor,
                          learning_offset=arg.forgetting_rate,
                          max_iter=1,
                          batch_size=arg.batch,
                          evaluate_every=arg.display_interval,
                          perp_tol=arg.cost_threshold,
                          mean_change_tol=arg.component_threshold,
                          max_doc_update_iter=arg.max_inner_iter,
                          n_jobs=arg.num_jobs,
                          verbose=0,
                          random_state=arg.random_state)
            print('\t>> Training by LDA model...')
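            # Each epoch subsamples a fraction (arg.subsample_input_size) of the rows
            # of X and feeds consecutive batches of size arg.batch to partial_fit;
            # the per-epoch wall-clock time and cost are appended to a text file.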
            n_epochs = arg.num_epochs + 1
            old_bound = np.inf
            num_samples = int(X.shape[0] * arg.subsample_input_size)
            list_batches = np.arange(start=0, stop=num_samples, step=arg.batch)
            cost_file_name = model_name + "_cost.txt"
            save_data('',
                      file_name=cost_file_name,
                      save_path=arg.rspath,
                      mode='w',
                      w_string=True,
                      print_tag=False)
            for epoch in np.arange(start=1, stop=n_epochs):
                desc = '\t   {0:d})- Epoch count ({0:d}/{1:d})...'.format(
                    epoch, n_epochs - 1)
                print(desc)
                idx = np.random.choice(X.shape[0], num_samples, False)
                start_epoch = time.time()
                X_tmp = X[idx, :]
                for bidx, batch in enumerate(list_batches):
                    desc = '\t       --> Training: {0:.2f}%...'.format(
                        ((bidx + 1) / len(list_batches)) * 100)
                    if (bidx + 1) == len(list_batches):
                        print(desc)
                    else:
                        print(desc, end="\r")
                    model.partial_fit(X=X_tmp[batch:batch + arg.batch])
                end_epoch = time.time()
                new_bound = -model.score(X=X_tmp) / X.shape[1]
                new_bound = np.log(new_bound)
                print('\t\t  ## Epoch {0} took {1} seconds...'.format(
                    epoch, round(end_epoch - start_epoch, 3)))
                data = '{0}\t{1}\t{2}\n'.format(
                    epoch, round(end_epoch - start_epoch, 3), new_bound)
                save_data(data=data,
                          file_name=cost_file_name,
                          save_path=arg.rspath,
                          mode='a',
                          w_string=True,
                          print_tag=False)
                print('\t\t  --> New cost: {0:.4f}; Old cost: {1:.4f}'.format(
                    new_bound, old_bound))
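                # Checkpoint the model and its phi matrix whenever the cost does not
                # increase, and always write a "_final" copy at the last epoch.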
                if new_bound <= old_bound or epoch == n_epochs - 1:
                    print('\t\t  --> Storing the LDA phi to: {0:s}'.format(
                        model_name + '_phi.npz'))
                    np.savez(os.path.join(arg.mdpath, model_name + '_phi.npz'),
                             model.components_)
                    print(
                        '\t\t  --> Storing the LDA (sklearn) model to: {0:s}'.
                        format(model_name + '.pkl'))
                    save_data(data=model,
                              file_name=model_name + '.pkl',
                              save_path=arg.mdpath,
                              mode="wb",
                              print_tag=False)
                    if epoch == n_epochs - 1:
                        print('\t\t  --> Storing the LDA phi to: {0:s}'.format(
                            model_name + '_phi_final.npz'))
                        np.savez(
                            os.path.join(arg.mdpath,
                                         model_name + '_phi_final.npz'),
                            model.components_)
                        print(
                            '\t\t  --> Storing the LDA (sklearn) model to: {0:s}'
                            .format(model_name + '_final.pkl'))
                        save_data(data=model,
                                  file_name=model_name + '_final.pkl',
                                  save_path=arg.mdpath,
                                  mode="wb",
                                  print_tag=False)
                    old_bound = new_bound
        display_params = False

    ##########################################################################################################
    ######################                           EVALUATE                           ######################
    ##########################################################################################################

    if arg.evaluate:
        print('\t>> Loading files...')
        dictionary = load_data(file_name=arg.vocab_name,
                               load_path=arg.dspath,
                               tag="vocabulary",
                               print_tag=False)
        X = load_data(file_name=arg.X_name,
                      load_path=arg.dspath,
                      tag="X",
                      print_tag=False)
        corpus = load_data(file_name=arg.text_name,
                           load_path=arg.dspath,
                           tag="X (a list of strings)",
                           print_tag=False)
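        # Map the bag-of-words ids in the corpus back to vocabulary terms;
        # gensim's CoherenceModel expects tokenized texts.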
        data = [[dictionary[i] for i, j in item] for item in corpus]

        M = None
        features = None
        if arg.use_supplement:
            M = load_data(file_name=arg.M_name,
                          load_path=arg.dspath,
                          tag="supplementary components")
            M = M.toarray()
        if arg.use_features:
            features = load_data(file_name=arg.features_name,
                                 load_path=arg.dspath,
                                 tag="features")

        if arg.soap:
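            # Evaluation reports the average log predictive score on X and four
            # topic-coherence measures (u_mass, c_v, c_uci, c_npmi) over the
            # top-k pathways of each component.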
            print('\n{0})- Evaluating SOAP model...'.format(steps))
            steps = steps + 1
            model_name = 'soap_' + arg.model_name + '.pkl'
            file_name = 'soap_' + arg.model_name + '_score.txt'
            print('\t>> Loading SOAP model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='SOAP model',
                              print_tag=False)
            score = model.predictive_distribution(X=X,
                                                  M=M,
                                                  features=features,
                                                  cal_average=arg.cal_average,
                                                  batch_size=arg.batch,
                                                  num_jobs=arg.num_jobs)
            print("\t>> Average log predictive score: {0:.4f}".format(score))
            save_data(data="# Average log predictive score: {0:.10f}\n".format(
                score),
                      file_name=file_name,
                      save_path=arg.rspath,
                      tag="log predictive score",
                      mode='w',
                      w_string=True,
                      print_tag=False)
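            # Top-k pathways per component (largest phi entries), mapped to
            # vocabulary terms before computing coherence.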
            components = np.argsort(-model.phi)[:, :arg.top_k]
            components = [[dictionary[i] for i in item] for item in components]
            for cr in ['u_mass', 'c_v', 'c_uci', 'c_npmi']:
                cm = CoherenceModel(texts=data,
                                    topics=components,
                                    corpus=corpus,
                                    dictionary=dictionary,
                                    coherence=cr)
                coherence = cm.get_coherence()
                print("\t>> Average coherence ({0}) score: {1:.4f}".format(
                    cr, coherence))
                save_data(
                    data="# Average coherence ({0}) score: {1:.4f}\n".format(
                        cr, coherence),
                    file_name=file_name,
                    save_path=arg.rspath,
                    tag="coherence score",
                    mode='a',
                    w_string=True,
                    print_tag=False)

        if arg.spreat:
            print('\n{0})- Evaluating SPREAT model...'.format(steps))
            steps = steps + 1
            model_name = 'spreat_' + arg.model_name + '.pkl'
            file_name = 'spreat_' + arg.model_name + '_score.txt'
            print('\t>> Loading SPREAT model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='SPREAT model',
                              print_tag=False)
            score = model.predictive_distribution(X=X,
                                                  M=M,
                                                  features=features,
                                                  cal_average=arg.cal_average,
                                                  batch_size=arg.batch,
                                                  num_jobs=arg.num_jobs)
            print("\t>> Average log predictive score: {0:.4f}".format(score))
            save_data(data="# Average log predictive score: {0:.10f}\n".format(
                score),
                      file_name=file_name,
                      save_path=arg.rspath,
                      tag="log predictive score",
                      mode='w',
                      w_string=True,
                      print_tag=False)
            components = np.argsort(-model.phi)[:, :arg.top_k]
            components = [[dictionary[i] for i in item] for item in components]
            for cr in ['u_mass', 'c_v', 'c_uci', 'c_npmi']:
                cm = CoherenceModel(texts=data,
                                    topics=components,
                                    corpus=corpus,
                                    dictionary=dictionary,
                                    coherence=cr)
                coherence = cm.get_coherence()
                print("\t>> Average coherence ({0}) score: {1:.4f}".format(
                    cr, coherence))
                save_data(
                    data="# Average coherence ({0}) score: {1:.4f}\n".format(
                        cr, coherence),
                    file_name=file_name,
                    save_path=arg.rspath,
                    tag="coherence score",
                    mode='a',
                    w_string=True,
                    print_tag=False)

        if arg.ctm:
            print('\n{0})- Evaluating CTM model...'.format(steps))
            steps = steps + 1
            model_name = 'ctm_' + arg.model_name + '.pkl'
            file_name = 'ctm_' + arg.model_name + '_score.txt'
            print('\t>> Loading CTM model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='CTM model',
                              print_tag=False)
            score = model.predictive_distribution(X=X,
                                                  cal_average=arg.cal_average,
                                                  batch_size=arg.batch,
                                                  num_jobs=arg.num_jobs)
            print("\t>> Average log predictive score: {0:.4f}".format(score))
            save_data(data="# Average log predictive score: {0:.10f}\n".format(
                score),
                      file_name=file_name,
                      save_path=arg.rspath,
                      tag="log predictive score",
                      mode='w',
                      w_string=True,
                      print_tag=False)
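            # The CTM model exposes its component matrix as model.omega (rather than
            # phi); its top-k entries are mapped to vocabulary terms as above.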
            components = np.argsort(-model.omega)[:, :arg.top_k]
            components = [[dictionary[i] for i in item] for item in components]
            for cr in ['u_mass', 'c_v', 'c_uci', 'c_npmi']:
                cm = CoherenceModel(texts=data,
                                    topics=components,
                                    corpus=corpus,
                                    dictionary=dictionary,
                                    coherence=cr)
                coherence = cm.get_coherence()
                print("\t>> Average coherence ({0}) score: {1:.4f}".format(
                    cr, coherence))
                save_data(
                    data="# Average coherence ({0}) score: {1:.4f}\n".format(
                        cr, coherence),
                    file_name=file_name,
                    save_path=arg.rspath,
                    tag="coherence score",
                    mode='a',
                    w_string=True,
                    print_tag=False)

        if arg.lda:
            print('\n{0})- Evaluating LDA model...'.format(steps))
            steps = steps + 1
            model_name = 'sklda_' + arg.model_name + '.pkl'
            file_name = 'sklda_' + arg.model_name + '_score.txt'
            print('\t>> Loading LDA model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='LDA model',
                              print_tag=False)
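            # The sklearn LDA model has no predictive_distribution method, so the
            # log predictive score is assembled manually from the document-topic
            # distribution and the row-normalized topic-word matrix.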
            model.components_ /= model.components_.sum(1)[:, np.newaxis]
            component_distribution = model.transform(X=X)
            score = 0.0
            for idx in np.arange(X.shape[0]):
                feature_idx = X[idx].indices
                temp = np.multiply(component_distribution[idx][:, np.newaxis],
                                   model.components_[:, feature_idx])
                score += np.sum(temp)
            if arg.cal_average:
                score = score / X.shape[0]
            score = np.log(score + np.finfo(float).eps)
            print("\t>> Average log predictive score: {0:.4f}".format(score))
            save_data(data="# Average log predictive score: {0:.10f}\n".format(
                score),
                      file_name=file_name,
                      save_path=arg.rspath,
                      tag="log predictive score",
                      mode='w',
                      w_string=True,
                      print_tag=False)
            components = np.argsort(-model.components_)[:, :arg.top_k]
            components = [[dictionary[i] for i in item] for item in components]
            for cr in ['u_mass', 'c_v', 'c_uci', 'c_npmi']:
                cm = CoherenceModel(texts=data,
                                    topics=components,
                                    corpus=corpus,
                                    dictionary=dictionary,
                                    coherence=cr)
                coherence = cm.get_coherence()
                print("\t>> Average coherence ({0}) score: {1:.4f}".format(
                    cr, coherence))
                save_data(
                    data="# Average coherence ({0}) score: {1:.4f}\n".format(
                        cr, coherence),
                    file_name=file_name,
                    save_path=arg.rspath,
                    tag="coherence score",
                    mode='a',
                    w_string=True,
                    print_tag=False)

    ##########################################################################################################
    ######################                           TRANSFORM                          ######################
    ##########################################################################################################

    if arg.transform:
        print('\t>> Loading files...')
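        # Transform projects X onto the learned components of each selected model
        # and stores the transformed matrix under arg.dspath.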
        X = load_data(file_name=arg.X_name,
                      load_path=arg.dspath,
                      tag="X",
                      print_tag=False)

        M = None
        features = None
        if arg.use_supplement:
            M = load_data(file_name=arg.M_name,
                          load_path=arg.dspath,
                          tag="supplementary components")
            M = M.toarray()
        if arg.use_features:
            features = load_data(file_name=arg.features_name,
                                 load_path=arg.dspath,
                                 tag="features")

        if arg.soap:
            print('\n{0})- Transforming {1} using a pre-trained SOAP model...'.
                  format(steps, arg.X_name))
            steps = steps + 1
            model_name = 'soap_' + arg.model_name + '.pkl'
            file_name = 'soap_' + arg.file_name + '.pkl'
            print('\t>> Loading SOAP model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='SOAP model',
                              print_tag=False)
            X = model.transform(X=X,
                                M=M,
                                features=features,
                                batch_size=arg.batch,
                                num_jobs=arg.num_jobs)
            save_data(data=X,
                      file_name=file_name,
                      save_path=arg.dspath,
                      tag="transformed X",
                      mode='wb',
                      print_tag=True)

        if arg.spreat:
            print(
                '\n{0})- Transforming {1} using a pre-trained SPREAT model...'.
                format(steps, arg.X_name))
            steps = steps + 1
            model_name = 'spreat_' + arg.model_name + '.pkl'
            file_name = 'spreat_' + arg.file_name + '.pkl'
            print('\t>> Loading SPREAT model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='SPREAT model',
                              print_tag=False)
            X = model.transform(X=X,
                                M=M,
                                features=features,
                                batch_size=arg.batch,
                                num_jobs=arg.num_jobs)
            save_data(data=X,
                      file_name=file_name,
                      save_path=arg.dspath,
                      tag="transformed X",
                      mode='wb',
                      print_tag=True)

        if arg.ctm:
            print('\n{0})- Transforming {1} using a pre-trained CTM model...'.
                  format(steps, arg.X_name))
            steps = steps + 1
            model_name = 'ctm_' + arg.model_name + '.pkl'
            file_name = 'ctm_' + arg.file_name + '.pkl'
            print('\t>> Loading CTM model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='CTM model',
                              print_tag=False)
            X = model.transform(X=X,
                                batch_size=arg.batch,
                                num_jobs=arg.num_jobs)
            save_data(data=X,
                      file_name=file_name,
                      save_path=arg.dspath,
                      tag="transformed X",
                      mode='wb',
                      print_tag=True)

        if arg.lda:
            print('\n{0})- Transforming {1} using a pre-trained LDA model...'.
                  format(steps, arg.X_name))
            steps = steps + 1
            model_name = 'sklda_' + arg.model_name + '.pkl'
            file_name = 'sklda_' + arg.file_name + '.pkl'
            print('\t>> Loading LDA model...')
            model = load_data(file_name=model_name,
                              load_path=arg.mdpath,
                              tag='LDA model',
                              print_tag=False)
            X = model.transform(X=X)
            save_data(data=X,
                      file_name=file_name,
                      save_path=arg.dspath,
                      tag="transformed X",
                      mode='wb',
                      print_tag=True)