Example #1
File: train.py Project: yyht/rrws
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, control_variate):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.pcfg_path, self.model_folder)
            util.save_control_variate(control_variate, self.model_folder)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_to_true_history.append(
                util.get_q_error(self.true_generative_model,
                                 inference_network))
            self.q_error_to_model_history.append(
                util.get_q_error(generative_model, inference_network))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}'.format(
                    iteration, self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1]))
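Most of the examples on this page pass an in-memory object and a destination path to util.save_object. The listing does not include util's source, so the sketch below is only an assumption: a minimal pickle-based save_object/load_object pair of the kind these projects appear to use (each repository ships its own util module, and a few examples further down use different signatures).

# Assumed sketch only: a typical pickle-based save_object/load_object pair.
# Each project above defines its own util module, so details may differ.
import os
import pickle

def save_object(obj, path):
    # Create the parent directory if needed, then pickle the object to disk.
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def load_object(path):
    # Counterpart used by some examples: unpickle the object stored at path.
    with open(path, 'rb') as f:
        return pickle.load(f)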
Example #2
def main():
    batch_size = 2
    pcfg_path = './pcfgs/astronomers_pcfg.json'
    generative_model, inference_network, true_generative_model = \
        util.init_models(pcfg_path)
    obss = [true_generative_model.sample_obs() for _ in range(batch_size)]

    num_mc_samples = 100
    num_particles_list = [2, 5, 10, 20, 50, 100]

    vimco_grad = np.zeros((len(num_particles_list), 2))
    vimco_one_grad = np.zeros((len(num_particles_list), 2))
    reinforce_grad = np.zeros((len(num_particles_list), 2))
    reinforce_one_grad = np.zeros((len(num_particles_list), 2))
    two_grad = np.zeros((len(num_particles_list), 2))
    log_evidence_stats = np.zeros((len(num_particles_list), 2))
    log_evidence_grad = np.zeros((len(num_particles_list), 2))
    wake_phi_loss_grad = np.zeros((len(num_particles_list), 2))
    log_Q_grad = np.zeros((len(num_particles_list), 2))
    sleep_loss_grad = np.zeros((len(num_particles_list), 2))

    for i, num_particles in enumerate(num_particles_list):
        util.print_with_time('num_particles = {}'.format(num_particles))
        (vimco_grad[i], vimco_one_grad[i], reinforce_grad[i],
         reinforce_one_grad[i], two_grad[i], log_evidence_stats[i],
         log_evidence_grad[i], wake_phi_loss_grad[i], log_Q_grad[i],
         sleep_loss_grad[i]) = get_mean_stds(
            generative_model, inference_network, num_mc_samples, obss,
            num_particles)

    util.save_object([
        vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
        two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
        log_Q_grad, sleep_loss_grad],
        './variance_analysis/data.pkl')
Example #3
def train_and_save(classifier_list, classifier_name_list, training_data):
    results = train_multiple_classifiers(classifier_list, classifier_name_list,
                                         training_data)
    util.save_object(results, RESULTS_PATH)
    util.save_classifier_list(classifier_list, classifier_name_list,
                              CLASSIFIERS_AND_RESULTS_DIR_PATH)
    return results
Example #4
    def __call__(self, iteration, theta_loss, phi_loss, generative_model,
                 inference_network, memory, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}'.format(
                    iteration, theta_loss, phi_loss))
            self.theta_loss_history.append(theta_loss)
            self.phi_loss_history.append(phi_loss)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration, memory)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
            # TODO
            # self.memory_error_history.append(util.get_memory_error(
            #     self.true_generative_model, memory, generative_model,
            #     self.test_obss))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example #5
File: train.py Project: yyht/rrws
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_filename(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.pcfg_path, self.model_folder)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_to_true_history.append(
                util.get_q_error(self.true_generative_model,
                                 inference_network))
            self.q_error_to_model_history.append(
                util.get_q_error(generative_model, inference_network))
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = {:.3f}, '
                'q_error_to_model = {:.3f}'.format(
                    iteration, self.p_error_history[-1],
                    self.q_error_to_true_history[-1],
                    self.q_error_to_model_history[-1]))
Example #6
def run(args):
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_filename(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        load_or_init_models(args.load_model_folder, args.pcfg_path)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(generative_model.grammar)

    # train
    if args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              true_generative_model, args.batch_size,
                              args.num_iterations, args.num_particles,
                              train_callback)
    elif args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               true_generative_model, args.batch_size,
                               args.num_iterations, args.num_particles,
                               train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         true_generative_model, args.batch_size,
                         args.num_iterations, args.num_particles,
                         train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(
            args.pcfg_path, model_folder, true_generative_model,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          true_generative_model, args.batch_size,
                          args.num_iterations, args.num_particles,
                          train_callback)

    # save models and stats
    util.save_models(generative_model, inference_network, args.pcfg_path,
                     model_folder)
    stats_filename = util.get_stats_filename(model_folder)
    util.save_object(train_callback, stats_filename)
Example #7
def pre_sentence_token(df, lower=True, makevocab=True):
    if makevocab:
        vocab_counter = Counter()
    all_titles = []
    all_contents = []
    file_i = 1
    for index, row in df.iterrows():
        try:
            title = str(row['title']).strip("'<>() ")
            content = str(row['content']).strip("'<>() ")
            if lower:
                title = title.lower()
                content = content.lower()
            title_words = word_tokenizer(title)
            if len(title_words)==1:
                print('ignore id: {} sentence...,due to title length is 1'.format(row['id']))
                continue
            if makevocab:
                vocab_counter.update(title_words)
            contents = []  # holds each sentence of the content as a list of words
            content_sentences = sentence_tokenizer(content)
            for content_s in content_sentences:
                content_words = word_tokenizer(content_s)
                if makevocab:
                    vocab_counter.update(content_words)
                contents.append(content_words)
            ### build vocab
            if title_words and contents:
                # both title and content are non-empty
                all_titles.append(title_words)
                all_contents.append(contents)
            if index % 5000 == 0:
                print('has deal with %d sentence...' % index)
            # write one file every 100000 rows
            if index % 100000 == 0:
                data = {}
                data['article'] = all_contents
                data['abstract'] = all_titles
                util.save_object(data, './preprocessed/train/all_data_{}.pickle'.format(file_i))
                file_i += 1
                # clear data and the lists
                all_contents = []
                all_titles = []
                data.clear()
        except Exception:
            print('id :{} sentence can not deal with, because the sentence has exception...'.format(row['id']))
            continue
    # save data
    data = {}
    data['article'] = all_contents
    data['abstract'] = all_titles
    util.save_object(data, './preprocessed/train/all_data_{}.pickle'.format(file_i))
    print("Writing vocab file...")
    with open('./preprocessed/all_data_vocab_200000', 'w') as writer:
        for word, count in vocab_counter.most_common(200000):
            writer.write(word + ' ' + str(count) + '\n')
    nlp_tokenizer.close()
Example #8
 def __init__(self, filename_data="data/sample_conversations.json", filename_pkl="autocomplete_state.pkl", load=False):
     if load:
         self.load_from_file(filename_pkl)
         return
     self.tt = Trie()
     data = util.read_input(filename_data)
     for line, count in util.get_customer_service_phrases(data).items():
         for i in range(count):
             self.tt.add(line)
     util.save_object(self.tt, filename_pkl)
Example #9
 def on_grabar(self, event):
     fichero = dialogo_grabar_fic(self, "pgc", "Archivos pgc (*.pgc)|*.pgc|Todos los archivos (*.*)|*", "")
     if fichero:
         try:
             busy = wx.BusyInfo("Crabando...", self)
             util.save_object(self.data, fichero)
             # with open(fichero, 'w') as f:
             #     pickle.dump(self.data.encode("zlib").encode('hex'), f)
         except IOError as e:
             wx.MessageBox(str(e), APLICACION, wx.ICON_ERROR, self)
         finally:
Example #10
    def __call__(self, iteration, wake_theta_loss, wake_phi_loss, elbo,
                 generative_model, inference_network, optimizer_theta,
                 optimizer_phi):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} losses: theta = {:.3f}, phi = {:.3f}, elbo = '
                '{:.3f}'.format(iteration, wake_theta_loss, wake_phi_loss,
                                elbo))
            self.wake_theta_loss_history.append(wake_theta_loss)
            self.wake_phi_loss_history.append(wake_phi_loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)

            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                wake_theta_loss, elbo = losses.get_wake_theta_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_theta_loss.backward()
                theta_grads = [
                    p.grad.clone() for p in generative_model.parameters()
                ]

                inference_network.zero_grad()
                wake_phi_loss = losses.get_wake_phi_loss(
                    generative_model, inference_network, self.test_obs,
                    self.num_particles)
                wake_phi_loss.backward()
                phi_grads = [p.grad for p in inference_network.parameters()]

                stats.update(theta_grads + phi_grads)
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}'.format(
                    iteration, self.log_p_history[-1], self.kl_history[-1]))
Example #11
def read_from_spacy_and_save():
    print('Reading embeddings from spacy (GloVe)')
    print('Train features')
    train_nlp = [tn.spacy_english_model(item) for item in train_corpus]
    util.save_object(
        train_nlp, CLASSIFIERS_AND_RESULTS_DIR_PATH + 'train_nlp_' +
        str(CLASSIFIER_ITERATION) + '.pkl')
    train_glove_features = np.array([item.vector for item in train_nlp])
    print('Test features')
    test_nlp = [tn.spacy_english_model(item) for item in test_corpus]
    util.save_object(
        test_nlp, CLASSIFIERS_AND_RESULTS_DIR_PATH + 'test_nlp_' +
        str(CLASSIFIER_ITERATION) + '.pkl')
    test_glove_features = np.array([item.vector for item in test_nlp])
    return train_glove_features, test_glove_features
Example #12
    def reformat_and_save_data_for_fasttext(self):
        ft_train_data_formatted = ''
        for i in range(0, len(self.train_corpus)):
            if i in self.data_df.index:
                ft_train_data_formatted += '__label__' + self.train_label_names[i] + \
                                           ' ' + self.train_corpus[i] + '\n'
        util.save_object(ft_train_data_formatted,
                         self.TRAIN_DATA_FOR_FASTTEXT_PATH)

        ft_test_data_formatted = ''
        for i in range(0, len(self.test_corpus)):
            if i in self.data_df.index:
                ft_test_data_formatted += '__label__' + self.test_label_names[
                    i] + ' ' + self.test_corpus[i] + '\n'
        util.save_object(ft_test_data_formatted,
                         self.TEST_DATA_FOR_FASTTEXT_PATH)
Example #13
def main():
    num_mixtures = 20
    temp = np.arange(num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    softmax_multiplier = 0.5
    args = argparse.Namespace(
        init_mixture_logits=np.array(
            list(reversed(2 * np.arange(num_mixtures)))),
        softmax_multiplier=softmax_multiplier,
        device=torch.device('cpu'),
        num_mixtures=num_mixtures,
        relaxed_one_hot=False,
        temperature=None,
        true_mixture_logits=np.log(true_p_mixture_probs) / softmax_multiplier)
    batch_size = 2
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    obs = true_generative_model.sample_obs(batch_size)

    num_mc_samples = 100
    num_particles_list = [2, 5, 10, 20, 50, 100]

    vimco_grad = np.zeros((len(num_particles_list), 2))
    vimco_one_grad = np.zeros((len(num_particles_list), 2))
    reinforce_grad = np.zeros((len(num_particles_list), 2))
    reinforce_one_grad = np.zeros((len(num_particles_list), 2))
    two_grad = np.zeros((len(num_particles_list), 2))
    log_evidence_stats = np.zeros((len(num_particles_list), 2))
    log_evidence_grad = np.zeros((len(num_particles_list), 2))
    wake_phi_loss_grad = np.zeros((len(num_particles_list), 2))
    log_Q_grad = np.zeros((len(num_particles_list), 2))
    sleep_loss_grad = np.zeros((len(num_particles_list), 2))

    for i, num_particles in enumerate(num_particles_list):
        util.print_with_time('num_particles = {}'.format(num_particles))
        (vimco_grad[i], vimco_one_grad[i], reinforce_grad[i],
         reinforce_one_grad[i], two_grad[i], log_evidence_stats[i],
         log_evidence_grad[i], wake_phi_loss_grad[i], log_Q_grad[i],
         sleep_loss_grad[i]) = get_mean_stds(generative_model,
                                             inference_network, num_mc_samples,
                                             obs, num_particles)

    util.save_object([
        vimco_grad, vimco_one_grad, reinforce_grad, reinforce_one_grad,
        two_grad, log_evidence_stats, log_evidence_grad, wake_phi_loss_grad,
        log_Q_grad, sleep_loss_grad
    ], './variance_analysis/data.pkl')
Example #14
def train_gensim_basic(model_name, nb_matches):
    # loads data

    heroes, victories = preprocess_data(nb_matches)
    print("data shape : ", heroes.shape, victories.shape)

    model, sentences = gensim_modeling.preprocess_for_keras(heroes, victories)
    picks, victories = gensim_modeling.add_heroes_to_get_team(
        model, sentences, victories)
    victories = np.concatenate([victories, victories])

    embedded_size = 1000
    print(picks.shape, victories.shape)

    m = model_name(embedded_size)
    model_name = str(model_name)
    m.summary()

    opt = Adam(lr=0.0001)
    metrics = ['binary_accuracy']

    t = time.time()
    tb = TensorBoard(log_dir=f"./logs/{t}_" + model_name,
                     write_graph=True,
                     write_images=True)

    callbacks = [
        ModelCheckpoint(filepath='models/' + model_name, period=100), tb
    ]

    m.compile(optimizer=opt, loss='binary_crossentropy', metrics=metrics)
    history = None
    try:
        history = m.fit(x=picks,
                        y=victories,
                        batch_size=64,
                        callbacks=callbacks,
                        verbose=1,
                        epochs=500,
                        validation_split=0.15)

    except KeyboardInterrupt:
        print('\n Saving history of training to : ' + model_name + '_history')
        save_object(history, model_name + '_history')
Example #15
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_path = util.get_stats_path(self.save_dir)
            util.save_object(self, stats_path)
            util.save_checkpoint(self.save_dir,
                                 iteration,
                                 generative_model=generative_model,
                                 inference_network=inference_network)

        if iteration % self.eval_interval == 0:
            log_p, kl = eval_gen_inf(generative_model, inference_network,
                                     self.test_data_loader,
                                     self.eval_num_particles)
            _, renyi = eval_gen_inf_alpha(generative_model, inference_network,
                                          self.test_data_loader,
                                          self.eval_num_particles, self.alpha)
            self.log_p_history.append(log_p)
            self.kl_history.append(kl)
            self.renyi_history.append(renyi)

            stats = util.OnlineMeanStd()
            for _ in range(10):
                generative_model.zero_grad()
                inference_network.zero_grad()
                loss, elbo = losses.get_thermo_alpha_loss(
                    generative_model, inference_network, self.test_obs,
                    self.partition, self.num_particles, self.alpha,
                    self.integration)
                loss.backward()
                stats.update([p.grad for p in generative_model.parameters()] +
                             [p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1].item())
            util.print_with_time(
                'Iteration {} log_p = {:.3f}, kl = {:.3f}, renyi = {:.3f}'.
                format(iteration, self.log_p_history[-1], self.kl_history[-1],
                       self.renyi_history[-1]))
Example #16
def test_trie():
    # some basic functions
    tt = autocomplete.Trie()
    tt.add("Hello world")
    assert len(tt) == 1
    assert "Hello world" in tt
    tt.add("Hello World")
    assert len(tt) == 2
    assert "Hello World" in tt
    assert "Hello world" in tt
    assert "hello world - i'm not supposed to be in the trie" not in tt
    assert "Hello World aaaaaaaaaaaaaa" not in tt
    tt.add("Hello World")
    assert len(tt) == 2
    assert "Hello World" in tt
    assert "Hello world" in tt
    assert "hello world" not in tt
    assert "" not in tt
    assert "H" not in tt
    assert "" not in tt
    assert tt.__contains__("Hello", check_end=False)
    tt.clear()
    assert len(tt) == 0
    print("basic tests cleared")

    # set equivalence and memory constraints
    data = util.read_input("data/sample_conversations.json")
    all_convos_set = set()
    for line, count in util.get_customer_service_phrases(data).items():
        for i in range(count):
            # test duplication
            all_convos_set.add(line)
            tt.add(line)

    for line in all_convos_set:
        assert line in tt

    assert len(tt) == len(all_convos_set)
    print("I have %d phrases saved now" % len(tt))
    print("large tests cleared")
    util.save_object(tt, "test_autocomplete_state.pkl")
Example #17
def build_topic_model(data, stop_nfile, num_topics = 50, save = True):
    stop_words = util.stopwordslist(stop_nfile)
    corpora_documents = []  # tokenized corpus
    for index, row in data.iterrows():
        doc = str(row['doc']).strip()
        doc_seg = list(jieba.cut(doc))
        doc_seg_no_stop = [word for word in doc_seg if word not in stop_words]
        corpora_documents.append(doc_seg_no_stop)
        if index%3000 == 0:
            print('deal with sentence %d'%index)
    corpora_dict = corpora.Dictionary(corpora_documents)
    if save:
        util.save_object(corpora_documents, corpus_docs_seg_path)
        util.save_object(corpora_dict, corpora_dict_path)
    # corpora_documents = load_object('./data/docs_seg.pickle')
    # corpora_dict = load_object('./data/dict.pickle')

    corpus = [corpora_dict.doc2bow(doc) for doc in corpora_documents]
    # each element of corpus is (word_id, freq): the frequency of that word in the doc
    # save corpus
    if save:
        corpora.MmCorpus.serialize(corpus_path,corpus)
    # load corpus
    # corpus = corpora.MmCorpus('./data/corpus.mm')

    # tf-idf model
    tfidf_model = models.TfidfModel(corpus)
    print('tf-idf model finish...')

    corpus_tfidf = tfidf_model[corpus]
    # lda model
    lda_model = models.LdaModel(corpus_tfidf, id2word=corpora_dict, num_topics=num_topics)
    print('lda model finish...')
    # lsi model
    # corpus_tfidf = tfidf_model[corpus]
    lsi_model = models.LsiModel(corpus_tfidf,id2word=corpora_dict,num_topics=num_topics)
    print('lsi model finish...')
    if save:
        tfidf_model.save(tfidf_path)
        lda_model.save(lda_path)
        lsi_model.save(lsi_path)
Example #18
def train_multiple_classifiers(classifier_list, classifier_name_list,
                               training_data: TrainingData,
                               CLASSIFIERS_AND_RESULTS_DIR_PATH,
                               CLASSIFIER_ITERATION, RESULTS_PATH):
    if len(classifier_list) != len(classifier_name_list):
        print(
            "Classifier list length and classifier name list length must be equal!"
        )
        return

    results_list = []
    for i in range(0, len(classifier_list)):
        results = train_classifier_and_display_results(classifier_list[i],
                                                       classifier_name_list[i],
                                                       training_data)
        results_list.append(results)

        util.save_object(
            results, CLASSIFIERS_AND_RESULTS_DIR_PATH +
            util.convert_name_to_filename(classifier_name_list[i]) + '_' +
            str(CLASSIFIER_ITERATION) + '_results.pkl')
        util.save_object(
            classifier_list[i], CLASSIFIERS_AND_RESULTS_DIR_PATH +
            util.convert_name_to_filename(classifier_name_list[i]) + '_' +
            str(CLASSIFIER_ITERATION) + '.pkl')
    util.save_object(results_list, RESULTS_PATH)
    return results_list
Example #19
    def __call__(self, iteration, loss, elbo, generative_model,
                 inference_network, optimizer):
        if iteration % self.logging_interval == 0:
            util.print_with_time(
                'Iteration {} loss = {:.3f}, elbo = {:.3f}'.format(
                    iteration, loss, elbo))
            self.loss_history.append(loss)
            self.elbo_history.append(elbo)

        if iteration % self.checkpoint_interval == 0:
            stats_filename = util.get_stats_path(self.model_folder)
            util.save_object(self, stats_filename)
            util.save_models(generative_model, inference_network,
                             self.model_folder, iteration)

        if iteration % self.eval_interval == 0:
            self.p_error_history.append(
                util.get_p_error(self.true_generative_model, generative_model))
            self.q_error_history.append(
                util.get_q_error(self.true_generative_model, inference_network,
                                 self.test_obss))
            stats = util.OnlineMeanStd()
            for _ in range(10):
                inference_network.zero_grad()
                if self.train_mode == 'vimco':
                    loss, elbo = losses.get_vimco_loss(generative_model,
                                                       inference_network,
                                                       self.test_obss,
                                                       self.num_particles)
                elif self.train_mode == 'reinforce':
                    loss, elbo = losses.get_reinforce_loss(
                        generative_model, inference_network, self.test_obss,
                        self.num_particles)
                loss.backward()
                stats.update([p.grad for p in inference_network.parameters()])
            self.grad_std_history.append(stats.avg_of_means_stds()[1])
            util.print_with_time(
                'Iteration {} p_error = {:.3f}, q_error_to_true = '
                '{:.3f}'.format(iteration, self.p_error_history[-1],
                                self.q_error_history[-1]))
Example #20
 def get_features(self):
     print('Get features:', self.feature_extraction_method)
     if self.feature_extraction_method == FeatureExtractionMethod.BOW:
         return get_simple_bag_of_words_features(self.train_corpus,
                                                 self.test_corpus)
     elif self.feature_extraction_method == FeatureExtractionMethod.TF_IDF:
         return get_tf_idf_features(self.train_corpus, self.test_corpus)
     elif self.feature_extraction_method == FeatureExtractionMethod.WORD2VEC:
         if self.should_load_embedding_model:
             print('Loading embedding model from disk')
             self.embedding_model = util.load_object(
                 self.WORD2VEC_MODEL_SAVE_PATH)
         else:
             print('Calculating embeddings')
             self.embedding_model = get_word2vec_trained_model(
                 self.tokenized_test, self.NUM_OF_VEC_FEATURES)
             util.save_object(
                 self.embedding_model,
                 self.CLASSIFIERS_AND_RESULTS_DIR_PATH + 'w2v_model_' +
                 str(self.classifier_iter) + '.pkl')
         return self.get_document_embeddings_from_word2vec()
     elif self.feature_extraction_method == FeatureExtractionMethod.FASTTEXT:
         if self.should_load_embedding_model:
             print('Loading embedding model from disk')
             self.embedding_model = fasttext.load_model(
                 self.FAST_TEXT_SAVE_PATH)
         else:
             print('Calculating embeddings')
             if not os.path.exists(self.TRAIN_DATA_FOR_FASTTEXT_PATH):
                 self.reformat_and_save_data_for_fasttext()
             self.embedding_model = train_fasttext_model(
                 self.TRAIN_DATA_FOR_FASTTEXT_PATH,
                 self.NUM_OF_VEC_FEATURES,
                 epoch=100)
             self.embedding_model.save_model(self.FAST_TEXT_SAVE_PATH)
         return self.get_document_embeddings_from_fasttext()
     else:
         print('No such feature extraction method:',
               self.feature_extraction_method)
Example #21
def train_basic(model_name, model_type):
    # loads data

    heroes, victories = preprocess_data()
    print("data shape : ", heroes.shape, victories.shape)

    picks, victories = to_hero_index_and_augmentation(heroes, victories)

    embedded_size = 1000
    m = model_type(embedded_size)
    m.summary()

    opt = Adam(lr=0.00001)
    metrics = ['binary_accuracy']

    t = time.time()
    tb = TensorBoard(log_dir=f"./logs/{t}_" + model_name,
                     write_graph=True,
                     write_images=True)

    callbacks = [
        ModelCheckpoint(filepath='models/' + model_name, period=100), tb
    ]

    m.compile(optimizer=opt, loss='binary_crossentropy', metrics=metrics)
    history = None
    try:
        history = m.fit(x=picks,
                        y=victories,
                        batch_size=128,
                        callbacks=callbacks,
                        verbose=1,
                        epochs=50,
                        validation_split=0.15)

    except KeyboardInterrupt:
        print('\n Saving history of training to : ' + model_name + '_history')
        save_object(history, model_name + '_history')
Example #22
def main(dirname, starting_seq_num, nb_of_seq):

    nb_of_hundreds_by_seq = 200

    current_seq_num = starting_seq_num
    #those are currents limits parsed (up to for the first, end of last batch for second)
    # current_seq_num = 3302933279
    # current_seq_num = 3251588525

    for i in range(nb_of_seq):
        try:
            a = current_seq_num
            batch, current_seq_num = createMatchList(current_seq_num,
                                                     nb_of_hundreds_by_seq,
                                                     False)
            print(current_seq_num)
            t = time.strftime("%d_%b_%H:%M", time.gmtime())
            filename = dirname + f'/seq_start_{a}_seq_end_{current_seq_num}_nbhundreds{nb_of_hundreds_by_seq}' + t + '.pkl'
            util.save_object(batch, filename)
        except APIError:
            traceback.print_exc()
            print(f'Something went wrong on {i}')
            pass
Example #23
 def save(self):
     from util import save_object
     save_object(self,self.savestate)
Example #24
name = b'a'
email = b'*****@*****.**'
# 2000-01-01T00:00:00+0000
date = b'946684800 +0000'
author_date = date
author_email = email
author_name = name
committer_date = date
committer_email = email
committer_name = name
message = b'a'
# ASCII hex of parents.
parents = ()

# Blob.
blob_sha_ascii, blob_sha = util.save_object(b'blob', blob_content)
# Check sha matches Git.
blob_sha_git = util.get_git_hash_object(b'blob', blob_content)
assert blob_sha_ascii == blob_sha_git

# Tree.
tree_sha_ascii, tree_sha, tree_content = util.save_tree_object(blob_mode, blob_basename, blob_sha)
# Check sha matches Git.
tree_sha_git = util.get_git_hash_object(b'tree', tree_content)
assert tree_sha_ascii == tree_sha_git

# Commit.
commit_sha_ascii, commit_sha, commit_content = util.save_commit_object(
        tree_sha_ascii, parents,
        author_name, author_email, author_date,
        committer_name, committer_email, committer_date,
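Note that Example #24 uses a util.save_object with a different, Git-style signature: it takes an object type and raw bytes and returns the object's SHA-1 as both ASCII hex and raw digest, which the assertions compare against Git's own hash. Git hashes the header "<type> <size>\0" followed by the content; the sketch below is only an illustration of that convention, not the project's actual helper (which presumably also writes the object to a store).

# Illustrative only: computing a Git-compatible object hash with hashlib.
import hashlib

def git_object_sha(obj_type, content):
    # Git's object id is SHA-1 over b"<type> <len>\x00" + content.
    header = obj_type + b' ' + str(len(content)).encode('ascii') + b'\x00'
    digest = hashlib.sha1(header + content)
    return digest.hexdigest().encode('ascii'), digest.digest()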
Example #25
# tokenize corpus
tokenized_train = [tn.tokenizer.tokenize(text) for text in train_corpus]
tokenized_test = [tn.tokenizer.tokenize(text) for text in test_corpus]

# # # FastText
# # Train and save FastText
ft_num_features = 1000
ft_model = FastText(tokenized_train,
                    size=ft_num_features,
                    window=20,
                    min_count=2,
                    sample=1e-3,
                    sg=1,
                    iter=5,
                    workers=10)
util.save_object(ft_model, FAST_TEXT_SAVE_PATH)
# # Load FastText
# ft_model = util.load_object(FAST_TEXT_SAVE_PATH)

# generate averaged word vector features from word2vec model
avg_ft_train_features = document_vectorize(corpus=tokenized_train,
                                           model=ft_model,
                                           num_features=ft_num_features)
avg_ft_test_features = document_vectorize(corpus=tokenized_test,
                                          model=ft_model,
                                          num_features=ft_num_features)

print('FastText model:> Train features shape:', avg_ft_train_features.shape,
      ' Test features shape:', avg_ft_test_features.shape)

# # pack data in one class
Example #26
alpha_norm = Alpha(1, {0: 7}, {0: 8})
alpha_reduce = Alpha(1, {0: 7}, {0: 8})

for edge in alpha_norm.parameters[0][0]:
    alpha_norm.parameters[0][0][edge].requires_grad = False
for edge in alpha_reduce.parameters[0][0]:
    alpha_reduce.parameters[0][0][edge].requires_grad = False

# Set to DARTS Alpha Normal
alpha_norm.parameters[0][0][(0, 2)][2] = 1
alpha_norm.parameters[0][0][(0, 3)][2] = 1
alpha_norm.parameters[0][0][(0, 4)][2] = 1
alpha_norm.parameters[0][0][(1, 2)][2] = 1
alpha_norm.parameters[0][0][(1, 3)][2] = 1
alpha_norm.parameters[0][0][(1, 4)][8] = 1
alpha_norm.parameters[0][0][(1, 5)][8] = 1
alpha_norm.parameters[0][0][(2, 5)][5] = 1

# Set to DARTS Alpha Reduce
alpha_reduce.parameters[0][0][(0, 2)][1] = 1
alpha_reduce.parameters[0][0][(0, 4)][1] = 1
alpha_reduce.parameters[0][0][(1, 2)][1] = 1
alpha_reduce.parameters[0][0][(1, 3)][1] = 1
alpha_reduce.parameters[0][0][(1, 5)][1] = 1
alpha_reduce.parameters[0][0][(2, 3)][8] = 1
alpha_reduce.parameters[0][0][(2, 4)][8] = 1
alpha_reduce.parameters[0][0][(2, 5)][8] = 1
util.save_object(alpha_norm, "darts_alpha/best/alpha_normal.pkl")
util.save_object(alpha_reduce, "darts_alpha/best/alpha_reduce.pkl")
Example #27
 def pickle(self):
     save_object(self.pickle_path, self.criteria)
Example #28
data_df = load_20newsgroups()

train_corpus, test_corpus, train_label_names, \
test_label_names = train_test_split(np.array(data_df['Clean Article']),
                                    np.array(data_df['Target Name']),
                                    test_size=0.33, random_state=42)

# Reformat and save data for FastText
ft_train_data_formatted = ''
for i in range(0, len(train_corpus)):
    if i in data_df.index:
        ft_train_data_formatted += '__label__' + train_label_names[i] + ' ' + \
                                   train_corpus[i] + '\n'
util.save_object(
    ft_train_data_formatted, TOPIC_CLASSIFICATION_DATA_PATH +
    DATASET_NAME_20newsgroups + '_fasttext_train_formatted.txt')

ft_test_data_formatted = ''
for i in range(0, len(test_corpus)):
    if i in data_df.index:
        ft_test_data_formatted += '__label__' + test_label_names[i] + ' ' + \
                                  test_corpus[i] + '\n'
util.save_object(
    ft_test_data_formatted, TOPIC_CLASSIFICATION_DATA_PATH +
    DATASET_NAME_20newsgroups + '_fasttext_test_formatted.txt')

# Load data for FastText
# ft_data_formatted = util.load_object(TOPIC_CLASSIFICATION_DATA_PATH +
#                                      DATASET_NAME_20newsgroups +
#                                      '_fasttext_formatted.txt')
Example #29
File: run.py Project: yyht/rrws
def run(args):
    # set up args
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        args.cuda = True
    else:
        device = torch.device('cpu')
        args.cuda = False
    if args.train_mode == 'thermo' or args.train_mode == 'thermo_wake':
        partition = util.get_partition(args.num_partitions,
                                       args.partition_type, args.log_beta_min,
                                       device)
    util.print_with_time('device = {}'.format(device))
    util.print_with_time(str(args))

    # save args
    save_dir = util.get_save_dir()
    args_path = util.get_args_path(save_dir)
    util.save_object(args, args_path)

    # data
    binarized_mnist_train, binarized_mnist_valid, binarized_mnist_test = \
        data.load_binarized_mnist(where=args.where)
    data_loader = data.get_data_loader(binarized_mnist_train, args.batch_size,
                                       device)
    valid_data_loader = data.get_data_loader(binarized_mnist_valid,
                                             args.valid_batch_size, device)
    test_data_loader = data.get_data_loader(binarized_mnist_test,
                                            args.test_batch_size, device)
    train_obs_mean = torch.tensor(np.mean(binarized_mnist_train, axis=0),
                                  device=device,
                                  dtype=torch.float)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network = util.init_models(
        train_obs_mean, args.architecture, device)

    # optim
    optim_kwargs = {'lr': args.learning_rate}

    # train
    if args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            save_dir, args.num_particles * args.batch_size, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               data_loader, args.num_iterations,
                               args.num_particles, optim_kwargs,
                               train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_wake(generative_model, inference_network, data_loader,
                              args.num_iterations, args.num_particles,
                              optim_kwargs, train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            save_dir, args.num_particles, args.train_mode, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         data_loader, args.num_iterations, args.num_particles,
                         optim_kwargs, train_callback)
    elif args.train_mode == 'thermo':
        train_callback = train.TrainThermoCallback(
            save_dir, args.num_particles, partition, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo(generative_model, inference_network, data_loader,
                           args.num_iterations, args.num_particles, partition,
                           optim_kwargs, train_callback)
    elif args.train_mode == 'thermo_wake':
        train_callback = train.TrainThermoWakeCallback(
            save_dir, args.num_particles, test_data_loader,
            args.eval_num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_thermo_wake(generative_model, inference_network,
                                data_loader, args.num_iterations,
                                args.num_particles, partition, optim_kwargs,
                                train_callback)

    # eval validation
    train_callback.valid_log_p, train_callback.valid_kl = train.eval_gen_inf(
        generative_model, inference_network, valid_data_loader,
        args.eval_num_particles)

    # save models and stats
    util.save_checkpoint(save_dir,
                         iteration=None,
                         generative_model=generative_model,
                         inference_network=inference_network)
    stats_path = util.get_stats_path(save_dir)
    util.save_object(train_callback, stats_path)
Example #30
def bulid_candidate_words(data, stop_nfile, candidate_save_path, candidata_pos={}, first_sentence_count=30, last_sentence_count=20):
    # columns: ID, title, document text
    stop_words = util.stopwordslist(stop_nfile)
    # load corpus and model
    corpus_dict = util.load_object(corpora_dict_path)
    corpus = corpora.MmCorpus(corpus_path)
    tfidf_model = models.TfidfModel.load(tfidf_path)
    lda_model = models.LdaModel.load(lda_path)
    lsi_model = models.LsiModel.load(lsi_path)

    candidate_words = []
    for index, row in data.iterrows():
        title = str(row['title']).strip()
        doc = str(row['doc']).strip()
        candidate_word = {}  # candidate words for this row: key is the word, value is its feature list (10 selected features)
        # doc
        words_doc = list(pseg.cut(doc, HMM=True)) #[(word, flag)]
        # title
        words_title = list(pseg.cut(title, HMM=True))

        # remove stop words
        words_doc = [(word, pos) for word,pos in words_doc if word not in stop_words]
        words_title = [(word, pos) for word,pos in words_title if word not in stop_words]

        doc_len = len(words_doc)  # doc length after stop-word removal
        title_len = len(words_title)
        for word_index,(word,pos) in enumerate(words_doc):
            if pos in candidata_pos and len(word) > 1:
                # last three features: features[-3] doc length, features[-2] first occurrence position, features[-1] last occurrence position
                if word in candidate_word:
                    word_features = candidate_word[word]
                    word_features[-1] = (word_index+1)
                    candidate_word[word] = word_features
                    continue
                else:
                    features = [0] * 14
                    features[-3] = doc_len
                    # feature 1: part of speech
                    features[0] = candidata_pos[pos]
                    # feature 2: position of the candidate word's first occurrence
                    if doc_len == 0:
                        firoc = 0.
                    else:
                        firoc = (word_index+1)/float(doc_len)
                    features[1] = firoc
                    features[-2] = (word_index+1) # first occurrence position
                    # feature 3: length of the candidate word
                    features[2] = len(word)
                    # feature 4: whether the candidate word consists only of digits or letters
                    if util.is_contain_char_num(word):
                        features[3] = 1
                    # feature 5: tf-idf of the candidate word
                    id = corpus_dict.token2id.get(word, len(corpus_dict.token2id)+1)
                    if id == len(corpus_dict.token2id)+1:
                        features[4] = 1e-8
                    else:
                        for (w_id, tfidf) in tfidf_model[corpus[index]]:
                            if id == w_id:
                                features[4] = tfidf
                                break
                    # feature 6: occurrences of the candidate word in the first sentence
                    first_sentence = words_doc[:first_sentence_count]
                    features[5] = util.get_count_sentence(word,first_sentence)
                    # feature 7: occurrences of the candidate word in the last sentence [-20:]
                    last_sentence = words_doc[-last_sentence_count:]
                    features[6] = util.get_count_sentence(word,last_sentence)
                    # features 8, 9: LDA/LSI similarity between the candidate word's topic distribution and the document's
                    single_list = [word]
                    word_corpus = tfidf_model[corpus_dict.doc2bow(single_list)]
                    features[7] = get_topic_sim(lda_model,word_corpus,corpus[index])
                    features[8] = get_topic_sim(lsi_model,word_corpus,corpus[index])
                    # feature 11: word span, computed from the first/last occurrence positions and the doc length
                    candidate_word[word] = features

        for word_index, (word, pos) in enumerate(words_title):
            if pos in candidata_pos and len(word) > 1:
                if word in candidate_word:
                    word_features = candidate_word[word]
                    # feature 10: whether the word appears in the title
                    word_features[9] = 1
                    candidate_word[word] = word_features
                else:
                    features = [0] * 14
                    features[-3] = title_len
                    # feature 1: part of speech
                    features[0] = candidata_pos[pos]
                    # feature 2: position of the candidate word's first occurrence
                    if title_len == 0:
                        firoc = 0.
                    else:
                        firoc = (word_index + 1) / float(title_len)
                    features[1] = firoc
                    features[-2] = (word_index + 1)  # first occurrence position
                    # feature 3: length of the candidate word
                    features[2] = len(word)
                    # feature 4: whether the candidate word consists only of digits or letters
                    if util.is_contain_char_num(word):
                        features[3] = 1
                    # feature 5: tf-idf of the candidate word
                    id = corpus_dict.token2id.get(word, len(corpus_dict.token2id) + 1)
                    if id == len(corpus_dict.token2id) + 1:
                        features[4] = 1e-8
                    else:
                        for (w_id, tfidf) in tfidf_model[corpus[index]]:
                            if id == w_id:
                                features[4] = tfidf
                                break
                    # feature 6: occurrences of the candidate word in the first sentence
                    first_sentence = words_doc[:first_sentence_count]
                    features[5] = util.get_count_sentence(word, first_sentence)
                    # feature 7: occurrences of the candidate word in the last sentence [-20:]
                    last_sentence = words_doc[-last_sentence_count:]
                    features[6] = util.get_count_sentence(word, last_sentence)
                    # features 8, 9: LDA/LSI similarity between the candidate word's topic distribution and the document's
                    single_list = [word]
                    word_corpus = tfidf_model[corpus_dict.doc2bow(single_list)]
                    features[7] = get_topic_sim(lda_model, word_corpus, corpus[index])
                    features[8] = get_topic_sim(lsi_model, word_corpus, corpus[index])
                    # feature 10: whether the word appears in the title
                    features[9] = 1
                    # feature 11: word span, computed from the first/last occurrence positions and the doc length
                    candidate_word[word] = features
        candidate_words.append(candidate_word)
        # save
        if index % 2000 == 0:
            print('deal with sentence %d' % index)

    # data['candidate_words'] = candidate_words
    # data.to_csv(data_candidate_path, sep='\001', header=None, index=None)
    util.save_object(candidate_words,candidate_save_path)
Example #31
def run(args):
    # set up args
    args.device = None
    if args.cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    args.num_mixtures = 20
    if args.init_near:
        args.init_mixture_logits = np.ones(args.num_mixtures)
    else:
        args.init_mixture_logits = np.array(
            list(reversed(2 * np.arange(args.num_mixtures))))
    args.softmax_multiplier = 0.5
    if args.train_mode == 'concrete':
        args.relaxed_one_hot = True
        args.temperature = 3
    else:
        args.relaxed_one_hot = False
        args.temperature = None
    temp = np.arange(args.num_mixtures) + 5
    true_p_mixture_probs = temp / np.sum(temp)
    args.true_mixture_logits = \
        np.log(true_p_mixture_probs) / args.softmax_multiplier
    util.print_with_time(str(args))

    # save args
    model_folder = util.get_model_folder()
    args_filename = util.get_args_path(model_folder)
    util.save_object(args, args_filename)

    # init models
    util.set_seed(args.seed)
    generative_model, inference_network, true_generative_model = \
        util.init_models(args)
    if args.train_mode == 'relax':
        control_variate = models.ControlVariate(args.num_mixtures)

    # init dataloader
    obss_data_loader = torch.utils.data.DataLoader(
        true_generative_model.sample_obs(args.num_obss),
        batch_size=args.batch_size,
        shuffle=True)

    # train
    if args.train_mode == 'mws':
        train_callback = train.TrainMWSCallback(model_folder,
                                                true_generative_model,
                                                args.logging_interval,
                                                args.checkpoint_interval,
                                                args.eval_interval)
        train.train_mws(generative_model, inference_network, obss_data_loader,
                        args.num_iterations, args.mws_memory_size,
                        train_callback)
    if args.train_mode == 'ws':
        train_callback = train.TrainWakeSleepCallback(
            model_folder, true_generative_model,
            args.batch_size * args.num_particles, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_wake_sleep(generative_model, inference_network,
                               obss_data_loader, args.num_iterations,
                               args.num_particles, train_callback)
    elif args.train_mode == 'ww':
        train_callback = train.TrainWakeWakeCallback(model_folder,
                                                     true_generative_model,
                                                     args.num_particles,
                                                     args.logging_interval,
                                                     args.checkpoint_interval,
                                                     args.eval_interval)
        train.train_wake_wake(generative_model, inference_network,
                              obss_data_loader, args.num_iterations,
                              args.num_particles, train_callback)
    elif args.train_mode == 'dww':
        train_callback = train.TrainDefensiveWakeWakeCallback(
            model_folder, true_generative_model, args.num_particles, 0.2,
            args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_defensive_wake_wake(0.2, generative_model,
                                        inference_network, obss_data_loader,
                                        args.num_iterations,
                                        args.num_particles, train_callback)
    elif args.train_mode == 'reinforce' or args.train_mode == 'vimco':
        train_callback = train.TrainIwaeCallback(
            model_folder, true_generative_model, args.num_particles,
            args.train_mode, args.logging_interval, args.checkpoint_interval,
            args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'concrete':
        train_callback = train.TrainConcreteCallback(
            model_folder, true_generative_model, args.num_particles,
            args.num_iterations, args.logging_interval,
            args.checkpoint_interval, args.eval_interval)
        train.train_iwae(args.train_mode, generative_model, inference_network,
                         obss_data_loader, args.num_iterations,
                         args.num_particles, train_callback)
    elif args.train_mode == 'relax':
        train_callback = train.TrainRelaxCallback(model_folder,
                                                  true_generative_model,
                                                  args.num_particles,
                                                  args.logging_interval,
                                                  args.checkpoint_interval,
                                                  args.eval_interval)
        train.train_relax(generative_model, inference_network, control_variate,
                          obss_data_loader, args.num_iterations,
                          args.num_particles, train_callback)

    # save models and stats
    util.save_models(generative_model, inference_network, model_folder)
    if args.train_mode == 'relax':
        util.save_control_variate(control_variate, model_folder)
    stats_filename = util.get_stats_path(model_folder)
    util.save_object(train_callback, stats_filename)