コード例 #1
0
def main(input_directory="./nips-abstract", number_of_topics=10,
         training_iterations=1):
    """Train a CTM-style model with ``variational_bayes`` on a fixed corpus.

    Reads ``train.dat`` (one document per line) and ``voc.dat`` (the first
    whitespace-separated token of each non-blank line is a vocabulary word)
    from ``input_directory``, initializes
    ``variational_bayes.VariationalBayes`` and runs ``learning()``
    ``training_iterations`` times, printing the value it returns and the
    elapsed wall-clock time.

    Args:
        input_directory: corpus directory; a trailing "/" is ignored.
        number_of_topics: number of topics passed to the inferencer.
        training_iterations: number of ``learning()`` calls.

    Raises:
        ValueError: if ``voc.dat`` contains no words.
    """
    import os
    import time

    # parameter set 1
    input_directory = input_directory.rstrip("/")

    # Document: one stripped, lower-cased document per line.
    train_docs_path = os.path.join(input_directory, 'train.dat')
    with open(train_docs_path, 'r') as input_doc_stream:
        train_docs = [line.strip().lower() for line in input_doc_stream]
    print("successfully load all training docs from %s..." %
          (os.path.abspath(train_docs_path)))

    # Vocabulary: unique first tokens, lower-cased. Blank lines are skipped
    # (split()[0] would raise IndexError on them). The list is sorted so the
    # word order is reproducible; list(set(...)) depends on string hashing,
    # which Python 3 randomizes per process.
    vocabulary_path = os.path.join(input_directory, 'voc.dat')
    with open(vocabulary_path, 'r') as input_voc_stream:
        vocab = sorted(set(line.strip().lower().split()[0]
                           for line in input_voc_stream if line.strip()))
    print("successfully load all the words from %s..." %
          (os.path.abspath(vocabulary_path)))
    if not vocab:
        raise ValueError("no vocabulary words found in %s" % vocabulary_path)

    # parameter set 2
    alpha_mu = 0
    alpha_sigma = 1
    alpha_beta = 1.0 / len(vocab)

    import variational_bayes
    ctm_inferencer = variational_bayes.VariationalBayes()

    ctm_inferencer._initialize(train_docs, vocab, number_of_topics, alpha_mu,
                               alpha_sigma, alpha_beta)

    for iteration in range(training_iterations):
        clock = time.time()
        log_likelihood = ctm_inferencer.learning()
        clock = time.time() - clock
        # Report the values measured above; the type of learning()'s return
        # value is not visible here, so it is formatted with %s.
        print("iteration %d: learning() returned %s in %.2f seconds" %
              (iteration + 1, log_likelihood, clock))
コード例 #2
0
ファイル: launch_train.py プロジェクト: Sandy4321/PyLDA-1
def main():
    """Train an LDA model with a vocabulary tree prior from CLI options.

    Reads ``train.dat`` (one document per line), ``voc.dat`` (first token of
    each non-blank line is a word) and the tree prior files
    ``<tree_name>.wn.*`` / ``<tree_name>.hyperparams`` from
    ``options.input_directory``. Writes ``option.txt``, periodic
    ``exp_beta-<n>`` / ``model-<n>`` snapshots and a final ``model-<n>``
    into a timestamped directory under
    ``options.output_directory/<corpus_name>/``.

    Raises:
        ValueError: if a required option is missing or not positive.
    """
    try:
        import cPickle as pickle
    except ImportError:  # Python 3 has no cPickle
        import pickle

    options = parse_args()

    # parameter set 2
    # Explicit checks rather than ``assert``: asserts vanish under ``-O``.
    if not options.number_of_topics > 0:
        raise ValueError("number_of_topics must be positive")
    number_of_topics = options.number_of_topics
    if not options.training_iterations > 0:
        raise ValueError("training_iterations must be positive")
    training_iterations = options.training_iterations

    # parameter set 3: alpha_alpha defaults to 1 / number_of_topics.
    alpha_alpha = 1.0 / number_of_topics
    if options.alpha_alpha > 0:
        alpha_alpha = options.alpha_alpha

    # parameter set 4
    inference_mode = options.inference_mode

    # parameter set 5
    if not options.snapshot_interval > 0:
        raise ValueError("snapshot_interval must be positive")
    snapshot_interval = options.snapshot_interval

    # parameter set 1
    if options.input_directory is None:
        raise ValueError("input_directory is required")
    if options.output_directory is None:
        raise ValueError("output_directory is required")
    if options.tree_name is None:
        raise ValueError("tree_name is required")
    tree_name = options.tree_name

    input_directory = options.input_directory.rstrip("/")
    corpus_name = os.path.basename(input_directory)

    output_directory = options.output_directory
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)
    output_directory = os.path.join(output_directory, corpus_name)
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # Document: one stripped, lower-cased document per line.
    train_docs_path = os.path.join(input_directory, 'train.dat')
    with open(train_docs_path, 'r') as input_doc_stream:
        train_docs = [line.strip().lower() for line in input_doc_stream]
    print("successfully load all training docs from %s..." % (os.path.abspath(train_docs_path)))

    # Vocabulary: unique first tokens; blank lines skipped; sorted so word
    # order does not depend on string hashing.
    vocabulary_path = os.path.join(input_directory, 'voc.dat')
    with open(vocabulary_path, 'r') as input_voc_stream:
        vocab = sorted(set(line.strip().lower().split()[0]
                           for line in input_voc_stream if line.strip()))
    print("successfully load all the words from %s..." % (os.path.abspath(vocabulary_path)))

    # create output directory, named after the timestamp and key settings
    now = datetime.datetime.now()
    suffix = now.strftime("%y%m%d-%H%M%S") + ""
    suffix += "-%s" % ("lda")
    suffix += "-I%d" % (training_iterations)
    suffix += "-S%d" % (snapshot_interval)
    suffix += "-K%d" % (number_of_topics)
    suffix += "-aa%f" % (alpha_alpha)
    suffix += "-im%d" % (inference_mode)
    suffix += "/"

    output_directory = os.path.join(output_directory, suffix)
    os.mkdir(os.path.abspath(output_directory))

    # store all the options to a file
    with open(os.path.join(output_directory, "option.txt"), 'w') as options_output_file:
        # parameter set 1
        options_output_file.write("input_directory=" + input_directory + "\n")
        options_output_file.write("corpus_name=" + corpus_name + "\n")
        options_output_file.write("tree_name=" + str(tree_name) + "\n")
        # parameter set 2
        options_output_file.write("number_of_iteration=%d\n" % (training_iterations))
        options_output_file.write("number_of_topics=" + str(number_of_topics) + "\n")
        # parameter set 3
        options_output_file.write("alpha_alpha=" + str(alpha_alpha) + "\n")
        # parameter set 4
        options_output_file.write("inference_mode=%d\n" % (inference_mode))

    print("========== ========== ========== ========== ==========")
    # parameter set 1
    print("output_directory=" + output_directory)
    print("input_directory=" + input_directory)
    print("corpus_name=" + corpus_name)
    print("tree prior file=" + str(tree_name))
    # parameter set 2
    print("training_iterations=%d" % (training_iterations))
    print("number_of_topics=" + str(number_of_topics))
    # parameter set 3
    print("alpha_alpha=" + str(alpha_alpha))
    # parameter set 4
    print("inference_mode=%d" % (inference_mode))
    print("========== ========== ========== ========== ==========")

    if inference_mode == 0:
        import hybrid
        lda_inferencer = hybrid.Hybrid()
    elif inference_mode == 1:
        # There is no Monte Carlo inferencer. Stop here instead of going on
        # without one, which would fail with a NameError below.
        sys.stderr.write("error: monte carlo inference method is not implemented yet...\n")
        return
    elif inference_mode == 2:
        import variational_bayes
        lda_inferencer = variational_bayes.VariationalBayes()
    else:
        sys.stderr.write("error: unrecognized inference mode %d...\n" % (inference_mode))
        return

    # initialize tree
    import priortree
    prior_tree = priortree.VocabTreePrior()
    prior_tree._initialize(os.path.join(input_directory, tree_name + ".wn.*"),
                           os.path.join(input_directory, tree_name + ".hyperparams"),
                           vocab)

    lda_inferencer._initialize(train_docs, vocab, prior_tree, number_of_topics, alpha_alpha)

    def dump_model():
        # Pickle the inferencer to model-<counter>; the file is closed (and
        # flushed) before returning.
        model_snapshot_path = os.path.join(output_directory, 'model-' + str(lda_inferencer._counter))
        with open(model_snapshot_path, 'wb') as model_stream:
            pickle.dump(lda_inferencer, model_stream)

    for _ in range(training_iterations):
        lda_inferencer.learning()

        if lda_inferencer._counter % snapshot_interval == 0:
            lda_inferencer.export_beta(os.path.join(output_directory, 'exp_beta-' + str(lda_inferencer._counter)))
            dump_model()

    dump_model()
コード例 #3
0
ファイル: launch_train.py プロジェクト: lu839684437/PyLDA
def main():
    """Train an LDA model from command-line options.

    Reads ``train.dat`` (one document per line) and ``voc.dat`` (first token
    of each non-blank line is a word) from ``options.input_directory``.
    Writes ``option.txt``, periodic ``exp_beta-<n>`` exports and a final
    pickled ``model-<n>`` into a timestamped directory under
    ``options.output_directory/<corpus_name>/``.

    Raises:
        ValueError: if a required option is missing or not positive, or if
            the vocabulary is empty.
    """
    try:
        import cPickle as pickle
    except ImportError:  # Python 3 has no cPickle
        import pickle

    options = parse_args()

    # parameter set 2
    # Explicit checks rather than ``assert``: asserts vanish under ``-O``.
    if not options.number_of_topics > 0:
        raise ValueError("number_of_topics must be positive")
    number_of_topics = options.number_of_topics
    if not options.training_iterations > 0:
        raise ValueError("training_iterations must be positive")
    training_iterations = options.training_iterations
    if not options.snapshot_interval > 0:
        raise ValueError("snapshot_interval must be positive")
    snapshot_interval = options.snapshot_interval

    # parameter set 4
    inference_mode = options.inference_mode

    # parameter set 1
    if options.input_directory is None:
        raise ValueError("input_directory is required")
    if options.output_directory is None:
        raise ValueError("output_directory is required")

    input_directory = options.input_directory.rstrip("/")
    corpus_name = os.path.basename(input_directory)

    output_directory = options.output_directory
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)
    output_directory = os.path.join(output_directory, corpus_name)
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # Document: one stripped, lower-cased document per line.
    train_docs_path = os.path.join(input_directory, 'train.dat')
    with open(train_docs_path, 'r') as input_doc_stream:
        train_docs = [line.strip().lower() for line in input_doc_stream]
    print("successfully load all training docs from %s..." % (
        os.path.abspath(train_docs_path)))

    # Vocabulary: unique first tokens; blank lines skipped; sorted so word
    # order does not depend on string hashing.
    vocabulary_path = os.path.join(input_directory, 'voc.dat')
    with open(vocabulary_path, 'r') as input_voc_stream:
        vocab = sorted(set(line.strip().lower().split()[0]
                           for line in input_voc_stream if line.strip()))
    print("successfully load all the words from %s..." % (
        os.path.abspath(vocabulary_path)))
    if not vocab:
        raise ValueError("no vocabulary words found in %s" % vocabulary_path)

    # parameter set 3: non-positive values fall back to 1/K and 1/V.
    alpha_alpha = 1.0 / number_of_topics
    if options.alpha_alpha > 0:
        alpha_alpha = options.alpha_alpha
    alpha_beta = options.alpha_beta
    if alpha_beta <= 0:
        alpha_beta = 1.0 / len(vocab)

    # create output directory, named after the timestamp and key settings
    now = datetime.datetime.now()
    suffix = now.strftime("%y%m%d-%H%M%S") + ""
    suffix += "-%s" % ("lda")
    suffix += "-I%d" % (training_iterations)
    suffix += "-S%d" % (snapshot_interval)
    suffix += "-K%d" % (number_of_topics)
    suffix += "-aa%f" % (alpha_alpha)
    suffix += "-ab%f" % (alpha_beta)
    suffix += "-im%d" % (inference_mode)
    suffix += "/"

    output_directory = os.path.join(output_directory, suffix)
    os.mkdir(os.path.abspath(output_directory))

    # store all the options to a file
    with open(os.path.join(output_directory, "option.txt"), 'w') as options_output_file:
        # parameter set 1
        options_output_file.write("input_directory=" + input_directory + "\n")
        options_output_file.write("corpus_name=" + corpus_name + "\n")
        # parameter set 2
        options_output_file.write("training_iterations=%d\n" %
                                  (training_iterations))
        options_output_file.write("snapshot_interval=" + str(snapshot_interval) +
                                  "\n")
        options_output_file.write("number_of_topics=" + str(number_of_topics) +
                                  "\n")
        # parameter set 3
        options_output_file.write("alpha_alpha=" + str(alpha_alpha) + "\n")
        options_output_file.write("alpha_beta=" + str(alpha_beta) + "\n")
        # parameter set 4
        options_output_file.write("inference_mode=%d\n" % (inference_mode))

    print("========== ========== ========== ========== ==========")
    # parameter set 1
    print("output_directory=" + output_directory)
    print("input_directory=" + input_directory)
    print("corpus_name=" + corpus_name)
    # parameter set 2
    print("training_iterations=%d" % (training_iterations))
    print("snapshot_interval=" + str(snapshot_interval))
    print("number_of_topics=" + str(number_of_topics))
    # parameter set 3
    print("alpha_alpha=" + str(alpha_alpha))
    print("alpha_beta=" + str(alpha_beta))
    # parameter set 4
    print("inference_mode=%d" % (inference_mode))
    print("========== ========== ========== ========== ==========")

    if inference_mode == 0:
        import hybrid
        lda_inferencer = hybrid.Hybrid()
    elif inference_mode == 1:
        import monte_carlo
        lda_inferencer = monte_carlo.MonteCarlo()
    elif inference_mode == 2:
        import variational_bayes
        lda_inferencer = variational_bayes.VariationalBayes()
    else:
        sys.stderr.write("error: unrecognized inference mode %d...\n" %
                         (inference_mode))
        return

    lda_inferencer._initialize(train_docs, vocab, number_of_topics,
                               alpha_alpha, alpha_beta)

    for _ in range(training_iterations):
        lda_inferencer.learning()

        if lda_inferencer._counter % snapshot_interval == 0:
            lda_inferencer.export_beta(
                os.path.join(output_directory,
                             'exp_beta-' + str(lda_inferencer._counter)))

    # Close (and flush) the pickle file deterministically.
    model_snapshot_path = os.path.join(output_directory,
                                       'model-' + str(lda_inferencer._counter))
    with open(model_snapshot_path, 'wb') as model_stream:
        pickle.dump(lda_inferencer, model_stream)
コード例 #4
0
def main():
    """Train a CTM with variational Bayes from command-line options.

    Reads ``train.dat`` (one document per line) and ``voc.dat`` (first token
    of each non-blank line is a word) from ``options.input_directory``.
    Writes ``option.txt``, periodic ``exp_beta-<n>`` / ``model-<n>``
    snapshots and a final ``model-<n>`` into a timestamped directory under
    ``options.output_directory/<corpus_name>/``.

    Raises:
        ValueError: if a required option is missing or not positive, or if
            the vocabulary is empty.
    """
    options = parse_args()

    # parameter set 2
    # Explicit checks rather than ``assert``: asserts vanish under ``-O``.
    if not options.number_of_topics > 0:
        raise ValueError("number_of_topics must be positive")
    number_of_topics = options.number_of_topics
    if not options.training_iterations > 0:
        raise ValueError("training_iterations must be positive")
    training_iterations = options.training_iterations
    if not options.snapshot_interval > 0:
        raise ValueError("snapshot_interval must be positive")
    snapshot_interval = options.snapshot_interval

    # parameter set 4
    optimization_method = options.optimization_method
    if optimization_method is None:
        optimization_method = "L-BFGS-B"
    number_of_processes = options.number_of_processes
    if number_of_processes <= 0:
        sys.stderr.write("invalid setting for number_of_processes, adjust to 1...\n")
        number_of_processes = 1

    # parameter set 1
    if options.input_directory is None:
        raise ValueError("input_directory is required")
    if options.output_directory is None:
        raise ValueError("output_directory is required")

    input_directory = options.input_directory.rstrip("/")
    corpus_name = os.path.basename(input_directory)

    output_directory = options.output_directory
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)
    output_directory = os.path.join(output_directory, corpus_name)
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # Document: one stripped, lower-cased document per line.
    train_docs_path = os.path.join(input_directory, 'train.dat')
    with open(train_docs_path, 'r') as input_doc_stream:
        train_docs = [line.strip().lower() for line in input_doc_stream]
    print("successfully load all training docs from %s..." % (os.path.abspath(train_docs_path)))

    # Vocabulary: unique first tokens; blank lines skipped; sorted so word
    # order does not depend on Python 3's per-process string hash seed.
    vocabulary_path = os.path.join(input_directory, 'voc.dat')
    with open(vocabulary_path, 'r') as input_voc_stream:
        vocab = sorted(set(line.strip().lower().split()[0]
                           for line in input_voc_stream if line.strip()))
    print("successfully load all the words from %s..." % (os.path.abspath(vocabulary_path)))
    if not vocab:
        raise ValueError("no vocabulary words found in %s" % vocabulary_path)

    # parameter set 3: non-positive alpha_sigma falls back to 1.0 and
    # non-positive alpha_beta to 1/V.
    alpha_mu = options.alpha_mu
    alpha_sigma = options.alpha_sigma
    if alpha_sigma <= 0:
        alpha_sigma = 1.0
    alpha_beta = options.alpha_beta
    if alpha_beta <= 0:
        alpha_beta = 1.0 / len(vocab)

    # create output directory, named after the timestamp and key settings
    now = datetime.datetime.now()
    suffix = now.strftime("%y%m%d-%H%M%S") + ""
    suffix += "-%s" % ("ctm")
    suffix += "-I%d" % (training_iterations)
    suffix += "-S%d" % (snapshot_interval)
    suffix += "-K%d" % (number_of_topics)
    suffix += "-am%g" % (alpha_mu)
    suffix += "-as%g" % (alpha_sigma)
    suffix += "-ab%g" % (alpha_beta)
    # optimization_method is never None here (defaulted above).
    suffix += "-%s" % (optimization_method.replace("-", "_"))
    suffix += "/"

    output_directory = os.path.join(output_directory, suffix)
    os.mkdir(os.path.abspath(output_directory))

    # store all the options to a file
    with open(os.path.join(output_directory, "option.txt"), 'w') as options_output_file:
        # parameter set 1
        options_output_file.write("input_directory=" + input_directory + "\n")
        options_output_file.write("corpus_name=" + corpus_name + "\n")
        # parameter set 2
        options_output_file.write("training_iterations=%d\n" % (training_iterations))
        options_output_file.write("snapshot_interval=" + str(snapshot_interval) + "\n")
        options_output_file.write("number_of_topics=" + str(number_of_topics) + "\n")
        # parameter set 3
        options_output_file.write("alpha_mu=" + str(alpha_mu) + "\n")
        options_output_file.write("alpha_sigma=" + str(alpha_sigma) + "\n")
        options_output_file.write("alpha_beta=" + str(alpha_beta) + "\n")
        # parameter set 4
        options_output_file.write("optimization_method=%s\n" % (optimization_method))
        options_output_file.write("number_of_processes=%d\n" % (number_of_processes))

    print("========== ========== ========== ========== ==========")
    # parameter set 1
    print("output_directory=" + output_directory)
    print("input_directory=" + input_directory)
    print("corpus_name=" + corpus_name)
    # parameter set 2
    print("training_iterations=%d" % (training_iterations))
    print("snapshot_interval=" + str(snapshot_interval))
    print("number_of_topics=" + str(number_of_topics))
    # parameter set 3
    print("alpha_mu=" + str(alpha_mu))
    print("alpha_sigma=" + str(alpha_sigma))
    print("alpha_beta=" + str(alpha_beta))
    # parameter set 4
    print("optimization_method=%s" % (optimization_method))
    print("number_of_processes=%d" % (number_of_processes))
    print("========== ========== ========== ========== ==========")

    # Only the variational Bayes inferencer is used for this model.
    import variational_bayes
    ctm_inferencer = variational_bayes.VariationalBayes(optimization_method)

    ctm_inferencer._initialize(train_docs, vocab, number_of_topics, alpha_mu, alpha_sigma, alpha_beta)

    def dump_model():
        # Pickle the inferencer to model-<counter>, closing the file.
        model_snapshot_path = os.path.join(output_directory, 'model-' + str(ctm_inferencer._counter))
        with open(model_snapshot_path, 'wb') as model_stream:
            pickle.dump(ctm_inferencer, model_stream)

    for _ in range(training_iterations):
        ctm_inferencer.learning(number_of_processes)

        if ctm_inferencer._counter % snapshot_interval == 0:
            ctm_inferencer.export_beta(os.path.join(output_directory, 'exp_beta-' + str(ctm_inferencer._counter)))
            dump_model()

    dump_model()
コード例 #5
0
ファイル: launch_train.py プロジェクト: omo03/PyIBP
def main():
    """Train an IBP latent-feature model from command-line options.

    Loads ``train.dat`` with ``numpy.loadtxt`` from
    ``options.input_directory`` and passes it through ``center_data``.
    Builds a Gibbs sampler (``inference_mode == 1``, chosen by
    ``sampler_type``) or a variational inferencer (``inference_mode == 2``).
    Writes ``option.txt``, periodic snapshots via ``export_snapshot`` and a
    final pickled ``model-<n>`` into a timestamped directory under
    ``options.output_directory/<dataset_name>/``.

    Raises:
        ValueError: if a required option is missing or not positive.
    """
    try:
        import cPickle as pickle
    except ImportError:  # Python 3 has no cPickle
        import pickle

    options = parse_args()

    # parameter set 2
    # Explicit checks rather than ``assert``: asserts vanish under ``-O``.
    if not options.training_iterations > 0:
        raise ValueError("training_iterations must be positive")
    training_iterations = options.training_iterations
    if not options.snapshot_interval > 0:
        raise ValueError("snapshot_interval must be positive")
    snapshot_interval = options.snapshot_interval

    # parameter set 4: mode-specific settings are read only for their mode.
    inference_mode = options.inference_mode
    if inference_mode == 1:
        sampler_type = options.sampler_type
    elif inference_mode == 2:
        if not options.truncation_level > 0:
            raise ValueError("truncation_level must be positive")
        truncation_level = options.truncation_level

    # parameter set 1
    if options.input_directory is None:
        raise ValueError("input_directory is required")
    if options.output_directory is None:
        raise ValueError("output_directory is required")

    input_directory = options.input_directory.rstrip("/")
    dataset_name = os.path.basename(input_directory)

    output_directory = options.output_directory
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)
    output_directory = os.path.join(output_directory, dataset_name)
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # Dataset
    train_file_path = os.path.join(input_directory, 'train.dat')
    train_data = numpy.loadtxt(train_file_path)
    train_data = center_data(train_data)
    print("successfully load all train_data from %s..." % (
        os.path.abspath(train_file_path)))

    # parameter set 3
    if not options.alpha_alpha > 0:
        raise ValueError("alpha_alpha must be positive")
    alpha_alpha = options.alpha_alpha
    if not options.sigma_a > 0:
        raise ValueError("sigma_a must be positive")
    sigma_a = options.sigma_a
    if not options.sigma_x > 0:
        raise ValueError("sigma_x must be positive")
    sigma_x = options.sigma_x

    # create output directory, named after the timestamp and key settings
    now = datetime.datetime.now()
    suffix = now.strftime("%y%m%d-%H%M%S") + ""
    suffix += "-%s" % ("ibp")
    suffix += "-I%d" % (training_iterations)
    suffix += "-S%d" % (snapshot_interval)
    suffix += "-a%f" % (alpha_alpha)
    suffix += "-sa%f" % (sigma_a)
    suffix += "-sx%f" % (sigma_x)
    suffix += "-im%d" % (inference_mode)
    if inference_mode == 1:
        suffix += "-st%s" % (sampler_type)
    elif inference_mode == 2:
        suffix += "-T%d" % (truncation_level)
    suffix += "/"

    output_directory = os.path.join(output_directory, suffix)
    os.mkdir(os.path.abspath(output_directory))

    # store all the options to a file
    with open(os.path.join(output_directory, "option.txt"), 'w') as options_output_file:
        # parameter set 1
        options_output_file.write("input_directory=" + input_directory + "\n")
        options_output_file.write("dataset_name=" + dataset_name + "\n")
        # parameter set 2
        options_output_file.write("training_iterations=%d\n" %
                                  (training_iterations))
        options_output_file.write("snapshot_interval=" + str(snapshot_interval) +
                                  "\n")
        # parameter set 3
        options_output_file.write("alpha_alpha=" + str(alpha_alpha) + "\n")
        options_output_file.write("sigma_a=" + str(sigma_a) + "\n")
        options_output_file.write("sigma_x=" + str(sigma_x) + "\n")
        # parameter set 4
        options_output_file.write("inference_mode=%d\n" % (inference_mode))
        if inference_mode == 1:
            options_output_file.write("sampler_type=%d\n" % (sampler_type))
        elif inference_mode == 2:
            options_output_file.write("truncation_level=%d\n" % (truncation_level))

    print("========== ========== ========== ========== ==========")
    # parameter set 1
    print("output_directory=" + output_directory)
    print("input_directory=" + input_directory)
    print("dataset_name=" + dataset_name)
    # parameter set 2
    print("training_iterations=%d" % (training_iterations))
    print("snapshot_interval=" + str(snapshot_interval))
    # parameter set 3
    print("alpha_alpha=" + str(alpha_alpha))
    print("sigma_a=" + str(sigma_a))
    print("sigma_x=" + str(sigma_x))
    # parameter set 4
    print("inference_mode=%d" % (inference_mode))
    if inference_mode == 1:
        print("sampler_type=%d" % (sampler_type))
    elif inference_mode == 2:
        print("truncation_level=%d" % (truncation_level))
    print("========== ========== ========== ========== ==========")

    if inference_mode == 1:
        if sampler_type == 1:
            import collapsed_gibbs
            ibp_inferencer = collapsed_gibbs.CollapsedGibbs()
        elif sampler_type == 2:
            import semicollapsed_gibbs
            ibp_inferencer = semicollapsed_gibbs.SemiCollapsedGibbs()
        elif sampler_type == 3:
            import uncollapsed_gibbs
            ibp_inferencer = uncollapsed_gibbs.UncollapsedGibbs()
        else:
            sys.stderr.write("error: unrecognized sampler type %d...\n" %
                             (sampler_type))
            return
        ibp_inferencer._initialize(train_data,
                                   alpha_alpha,
                                   sigma_a,
                                   sigma_x,
                                   initial_Z=None,
                                   A_prior=None)
    elif inference_mode == 2:
        import variational_bayes
        ibp_inferencer = variational_bayes.VariationalBayes()
        ibp_inferencer._initialize(train_data, truncation_level, alpha_alpha,
                                   sigma_a, sigma_x)
    else:
        sys.stderr.write("error: unrecognized inference mode %d...\n" %
                         (inference_mode))
        return

    for _ in range(training_iterations):
        log_likelihood = ibp_inferencer.learning()

        print("iteration: %i\tK: %i\tlikelihood: %f" % (
            ibp_inferencer._counter, ibp_inferencer._K, log_likelihood))

        if ibp_inferencer._counter % snapshot_interval == 0:
            ibp_inferencer.export_snapshot(output_directory)

    # Close (and flush) the pickle file deterministically.
    model_snapshot_path = os.path.join(output_directory,
                                       'model-' + str(ibp_inferencer._counter))
    with open(model_snapshot_path, 'wb') as model_stream:
        pickle.dump(ibp_inferencer, model_stream)