예제 #1
0
def onpol_sim(worker, subj_list, output_prefix, model_path):
    """Run ``Simulator.simulate_env`` once per entry found in ``subj_list``.

    The worker's parameters are first replaced by whatever
    ``OptML.get_variables`` returns for them, then a model is loaded into a
    fresh TensorFlow session via ``load_model``. For each entry, a new
    output directory is created, the entry is dumped to ``config.csv`` in
    that directory, and a simulation is run against a ``bandit_evn`` built
    from the entry's ``'prop0'`` and ``'prop1'`` values.

    Args:
        worker: Object exposing ``get_params()`` and ``set_params()``; it is
            handed unchanged to ``Simulator.simulate_env``.
        subj_list: Sized, integer-indexable collection. Each element is
            iterated, and each item must support ``.copy()`` and lookup of
            the keys ``'id'``, ``'block'``, ``'choices'``, ``'prop0'`` and
            ``'prop1'``.
        output_prefix: String prepended to every output directory name.
        model_path: Forwarded to ``load_model``.

    Returns:
        None. Results are whatever ``Simulator.simulate_env`` writes under
        each output directory (NOTE(review): confirm against Simulator).
    """
    params = OptML.get_variables(worker.get_params())
    worker.set_params(params)

    # device_count={'GPU': 0} requests a session with no GPU devices.
    session_config = tf.ConfigProto(device_count={'GPU': 0})
    with tf.Session(config=session_config) as sess:

        DLogger.logger().debug("loading mode....")
        load_model(sess, model_path)
        DLogger.logger().debug("finished loading mode.")

        # Index-based access is kept on purpose: only len() and integer
        # indexing of subj_list are relied upon.
        for subj_idx in range(len(subj_list)):
            param_values = sess.run(params)
            DLogger.logger().debug("parameters {}".format(param_values))

            for trial_cfg in subj_list[subj_idx]:
                message = "subject {} trial {}".format(
                    trial_cfg['id'], trial_cfg['block'])
                DLogger.logger().debug(message)
                choices = trial_cfg['choices']

                # Each entry gets its own directory, named with a fresh
                # 7-character token from id_generator.
                output_path = ''.join(
                    (output_prefix, 'sim_', id_generator(size=7), '/'))
                if not os.path.exists(output_path):
                    os.makedirs(output_path)

                # Persist a one-row copy of the entry next to the results;
                # 'option' is blanked and 'N' mirrors the entry's 'id'.
                cfg_row = trial_cfg.copy()
                cfg_row['option'] = {}
                cfg_row['N'] = trial_cfg['id']
                cfg_frame = pd.DataFrame(cfg_row, index=[0])
                cfg_frame.to_csv(output_path + 'config.csv', index=False)

                env = bandit_evn(trial_cfg['prop0'],
                                 trial_cfg['prop1'],
                                 init_state=None,
                                 init_action=-1,
                                 init_reward=0)

                # The two returned values are unused; the unpacking is kept
                # so a non-pair return still fails as it did before.
                _, _ = Simulator.simulate_env(
                    sess, worker, output_path, choices, env, greedy=False)
예제 #2
0
파일: off_sims.py 프로젝트: ritwik7/rnn_beh
    # tf.reset_default_graph()
    # worker = GQL.get_instance(2, 2, {})
    # group_log_sigma2, group_mu, ind_log_sigma2, ind_mu = OptMIX.get_variables(worker.get_params())
    # worker.set_params(group_mu)
    # simulate_oscci_mix_QL('gql-mix-opt', 'GQL', worker)
    #
    # tf.reset_default_graph()
    # worker = GQL.get_instance(2, 10, {})
    # group_log_sigma2, group_mu, ind_log_sigma2, ind_mu = OptMIX.get_variables(worker.get_params())
    # worker.set_params(group_mu)
    # simulate_oscci_mix_QL('gql10d-mix-opt', 'GQL10D', worker)

    ############ ML graphs #############################
    tf.reset_default_graph()
    worker = QL.get_instance_without_pser(2, 0.1, 0.2)
    worker.set_params(OptML.get_variables(worker.get_params()))
    simulate_mix_QL('ql-ml-opt', 'QL', worker)

    tf.reset_default_graph()
    worker = QL.get_instance_with_pser(2, 0.1, 0.2, 0.2)
    worker.set_params(OptML.get_variables(worker.get_params()))
    simulate_mix_QL('qlp-ml-opt', 'QLP', worker)

    tf.reset_default_graph()
    worker = GQL.get_instance(2, 2, {})
    worker.set_params(OptML.get_variables(worker.get_params()))
    simulate_mix_QL('gql-ml-opt', 'GQL', worker)

    tf.reset_default_graph()
    worker = GQL.get_instance(2, 1, {})
    worker.set_params(OptML.get_variables(worker.get_params()))
예제 #3
0
 def test_and_save(sess, test, output_folder):
     """Forward to ``OptML.test_and_save`` using the enclosing ``worker``.

     The first positional argument is always the empty string and the third
     is always ``None``; ``worker`` is not a parameter and is looked up in
     the surrounding scope at call time.

     Args:
         sess: Passed through as the fourth positional argument.
         test: Passed through as the fifth positional argument.
         output_folder: Passed through as the second positional argument.

     Returns:
         Whatever ``OptML.test_and_save`` returns.
     """
     result = OptML.test_and_save(
         "",
         output_folder,
         None,
         sess,
         test,
         worker,
     )
     return result