Пример #1
0
def _default_test_option(test_type):
    """Return the default ``test_option`` for ``test_type`` (used by ``run``).

    Returns None for test types that have no default. A new object is built on
    every call, so the value is never shared between runs.
    """
    if test_type == 2:
        # Earlier alternative: [[0.21,0.63,5],[0.63,0.21,5],[0.12,0.72,5],[0.72,0.12,5]]
        return [[100, 0, 20]]
    elif test_type == 3:
        return (4, 0)
    elif test_type == 4:
        return (None, None)
    elif test_type == 5:
        return (5, 'K')
    elif test_type == 6:
        return (4, "AAxB", "xB")
    elif test_type == 7:
        return [3, 2, 1]
    return None


def _resolve_rand_seed(rand_type, now):
    """Return the integer random seed selected by ``rand_type``.

    rand_type == 0 -> ``now`` formatted as yymmddHHMMSS
    rand_type == 1 -> ``now`` formatted as yymmddHHMM
    otherwise      -> ``rand_type`` itself (a fixed, reproducible seed)

    The timestamp is parsed with int() rather than eval(): eval() treats a
    string starting with "0" (e.g. any year 2000-2009) as an octal literal,
    or rejects it as a syntax error.
    """
    if rand_type == 0:
        return int(now.strftime("%y%m%d%H%M%S"))
    if rand_type == 1:
        return int(now.strftime("%y%m%d%H%M"))
    return rand_type


def run(world, OS = "WIN", test_type = 7, test_option = None, num_test_run = 2, num_subjects = 1,
        multi_thread = False, en_logging = False, rand_type = 0,
        perfect_MB = True, perfect_cconv = True, CUthreshold = 0.4, CUNumsThreshold = 0.3, CUinScale = 1.0,
        tranf_scale = 0.451, learn_alpha = 1.0, auto_run = 2):
    """Configure, build and optionally run the SpaUN learning network.

    Most parameters are copied onto the module-level ``conf`` object before the
    network is built.

    world        -- not used in this function (presumably supplied by the
                    launching environment; confirm).
    OS           -- "WIN", "LIN_G" or "LIN" selects hard-coded
                    ``conf.root_path`` / ``conf.vis_filepath`` values.
                    "" disables the output and log files.
    test_type    -- index into ``vocabs.task_strs``; when ``test_option`` is
                    None, types 2-7 get a built-in default option.
    rand_type    -- 0: seed from the current time (to the second);
                    1: seed from the current time (to the minute);
                    any other value is used as the seed itself
                    (e.g. 110518233715).
    multi_thread -- falsy: turn off NodeThreadPool multithreading;
                    otherwise passed to NodeThreadPool.setNumThreads.
    en_logging   -- add a WriteToFileNode that logs "motor_go" and
                    "stimulus_cont" to ``conf.log_file``.
    perfect_MB / perfect_cconv -- when True, set conf.MB_mode = "ideal" /
                    conf.cconv_mode = "direct".
    auto_run     -- 0: build only;
                    1: reset the network and call simulator.run up to
                    conf.est_runtime;
                    >1: open the interactive viewer with play=conf.est_runtime.
    """
    learn_actions = 2
    if test_option is None:
        test_option = _default_test_option(test_type)

    conf.OS = OS
    conf.test_type = test_type
    conf.test_option = test_option
    conf.num_test_run = num_test_run
    conf.num_subjects = num_subjects
    conf.learn_alpha = learn_alpha
    conf.CUthreshold = CUthreshold
    conf.CUNumsThreshold = CUNumsThreshold
    conf.CUinScale = CUinScale
    conf.tranf_scale = tranf_scale
    conf.learn_actions = learn_actions

    ## ------------------------------------- DEFINE RULE SET ------------------------------------------
    # Each method's keyword defaults are the rule's match conditions; the body's
    # set()/learn() calls are the rule's effects (read by spa.bg.BasalGanglia).
    class SpaUNRules:
        def task_init(vis = "A"):
            set(ps_task = "X")

    # Rules for other tasks, disabled but kept for reference:
    #    def task_r_init(vis = "R"):
    #        set(ps_task = "R", ps_state = "TRANS1", ps_subtask = "MF")
    #    def task_r_store(ps_tasko = "R", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "TRANS1")
    #
    #    def task_v_init(vis = "V"):
    #        set(ps_task = "V", ps_state = "SKIP", ps_subtask = "NON")
    #    def task_v_tr1_2_skp(ps_tasko = "V", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "SKIP")
    #
    #    def task_f_init(vis = "F"):
    #        set(ps_task = "F", ps_state = "SKIP", ps_subtask = "NON")
    #    def task_f_tr1_2_tr2(ps_tasko = "F", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "TRANS2")
    #    def task_f_tr2_2_skp(ps_tasko = "F", ps_stateo = "TRANS2", scale = 0.5):
    #        set(ps_state = "SKIP")
    #
    #    def task_m_init(vis = "M"):
    #        set(ps_task = "M", ps_state = "TRANS1", ps_subtask = "MF")
    #    def task_m_keep_tr1(ps_tasko = "M", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "TRANS1")
    #    def task_m_set_fwd(vis = "P", ps_tasko = "M", scale = 0.5):
    #        set(ps_subtask = "MF")
    #    def task_m_set_bck(vis = "K", ps_tasko = "M", scale = 0.5):
    #        set(ps_subtask = "MB")
    #
    #    def task_r_init(vis = "R"):
    #        set(ps_task = "R", ps_state = "TRANS1", ps_subtask = "MF")
    #    def task_r_keep_tr1(ps_tasko = "R", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "TRANS1")
    #
    #    def task_a_init(vis = "A"):
    #        set(ps_task = "A", ps_state = "SKIP", ps_subtask = "NON")
    #    def task_a_tr1_2_tr2(ps_tasko = "A", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "TRANS2")
    #    def task_a_keep_tr2(ps_tasko = "A", ps_stateo = "TRANS2", scale = 0.5):
    #        set(ps_state = "TRANS2")
    #    def task_a_set_k(ps_tasko = "A", vis = "K", scale = 0.5):
    #        set(ps_subtask = "AK")
    #    def task_a_set_p(ps_tasko = "A", vis = "P", scale = 0.5):
    #        set(ps_subtask = "AP")
    #
    #    def task_c_init(vis = "C"):
    #        set(ps_task = "C", ps_state = "SKIP", ps_subtask = "NON")
    #    def task_c_set_cnt(ps_tasko = "C", ps_stateo = "TRANS1", scale = 0.5):
    #        set(ps_state = "CNT", ps_subtask = "CNT")
    #    def task_c_nomatch(ps_stateo = "CNT", scale = 1.0):
    ##        match(mem_MB2 != mem_MBCnt)
    #        set(ps_subtask = "CNT")
    #    def task_c_match(ps_stateo = "CNT", scale = 0.5):
    ##        match(mem_MB2 == mem_MBCnt)
    #        set(ps_subtask = "MF")
    #
    #    def task_w_init(vis = "W"):
    #        set(ps_task = "W", ps_state = "VIS", ps_subtask = "MF")
    #

        def task_qm(vis = "QM"):
            set(ps_task = "DEC")
        def task_skp_2_tr1(ps_stateo = "SKIP"):
            set(ps_state = "TRANS1")

        def task_l_init(vis = "TWO", ps_tasko = "X", scale = 0.5):
            set(ps_task = "L", ps_state = "LEARN", ps_subtask = "NON")

        # Generate one learnable action rule (task_l_a1, task_l_a2, ...) per
        # learn action; exec runs in the class namespace so each def becomes a
        # method of SpaUNRules.
        for i in range(conf.learn_actions):
            code = """def task_l_a%d(ps_stateo = "LEARN-TRANS1-TRANS2-SKIP", scale = %f, rand_weights = rand_weights):
                          learn(ps_statea = rand_weights, pred_error = vstr_pred_error)
                          set(ps_subtask = "A%d")""" % (i+1,0.35,i+1)
            exec(code)

    ## ------------------------------------- END DEFINE RULE SET ------------------------------------------
    ## ------------------------------------- DEFINE SPA NETWORK ------------------------------------------

    class SpaUNLearn(spa.core.SPA):
        dimensions = conf.num_dim
        align_hrrs = True

        stimulus = ControlModule()
        vis      = VisionModule()
        ps       = ProdSysBufferModule()
    #    mem      = MemoryBufferModule()
    #    trans    = TransformModule()
    #    enc      = EncodingModule()
    #    dec      = DecodingModule()
        vstr     = vStrModule()
        motor    = MotorModule()

        BG       = spa.bg.BasalGanglia(SpaUNRules(), pstc_input = 0.01)
        thalamus = spa.thalamus.Thalamus(bg = BG, pstc_route_input = 0.01, pstc_gate = 0.001, route_scale = 1, \
                                         pstc_inhibit = 0.01, pstc_output = 0.011, pstc_route_output = 0.01, \
                                         mutual_inhibit = 2, quick = False)

    ## ------------------------------------- END DEFINE SPA NETWORK ------------------------------------------

    # NOTE(review): when these flags are False, conf.MB_mode / conf.cconv_mode
    # keep whatever value conf already holds (possibly from a previous run()).
    if perfect_MB:
        conf.MB_mode = "ideal"
    if perfect_cconv:
        conf.cconv_mode = "direct"

    if OS == "WIN":
        conf.root_path    = "D:\\fchoo\\Documents\\My Dropbox\\SPA\\Code\\Spaun\\"
        conf.vis_filepath = "D:\\fchoo\\Documents\\My Dropbox\\SPA\\Code\\Digits\\Matlab\\"
    elif OS == "LIN_G":
        conf.root_path    = "/home/ctnuser/fchoo/code/"
        conf.vis_filepath = "/home/ctnuser/fchoo/code/Digits/Matlab/"
    elif OS == "LIN":
        conf.root_path    = "/home/fchoo/Dropbox/SPA/Code/data/"
        conf.vis_filepath = "/home/fchoo/Dropbox/SPA/Code/Digits/Matlab/"

    # Take a single timestamp so that a time-derived seed matches the timestamp
    # embedded in the output/log file names (the run can be reproduced from them).
    now = datetime.datetime.today()
    rand_seed = _resolve_rand_seed(rand_type, now)
    PDFTools.setSeed(rand_seed)
    random.seed(rand_seed)

    if not multi_thread:
        NodeThreadPool.turnOffMultithreading()
    else:
        NodeThreadPool.setNumThreads(multi_thread)

    datetime_str = now.strftime("%y%m%d%H%M%S")
    task_tag = vocabs.task_strs[test_type] + str(test_option)
    filename = "task_" + task_tag + "_" + datetime_str + ".txt"
    logname  = "log_"  + task_tag + "_" + datetime_str + ".csv"

    if OS != "":
        conf.out_file = conf.root_path + filename
    else:
        conf.out_file = ""

    if en_logging and OS != "":
        conf.log_file = conf.root_path + logname
    else:
        conf.log_file = ""

    # vocabs.subtask_strs is module state that outlives this call; only add the
    # learn-action labels that are missing so repeated run() calls in the same
    # session do not append duplicates.
    for i in range(learn_actions):
        action_str = "A" + str(i + 1)
        if action_str not in vocabs.subtask_strs:
            vocabs.subtask_strs.append(action_str)

    conf.vocab_data = vocabs.VocabData()
    spaun = SpaUNLearn()

    println(conf.est_runtime)

    # Set default vocabularies (for interactive mode)
    hrr.Vocabulary.defaults[conf.num_dim]                = conf.vocab_data.vis_vocab
    hrr.Vocabulary.defaults[conf.vocab_data.nums_dim]    = conf.vocab_data.nums_vocab
    hrr.Vocabulary.defaults[conf.vocab_data.state_dim]   = conf.vocab_data.state_vocab
    hrr.Vocabulary.defaults[conf.vocab_data.task_dim]    = conf.vocab_data.task_vocab
    hrr.Vocabulary.defaults[conf.vocab_data.subtask_dim] = conf.vocab_data.subtask_vocab

    # Raw visual output vocab (for debug): one entry per symbol listed in the
    # symbol-list CSV, paired row-for-row with the vectors in the mu CSV.
    vis_raw_vocab = hrr.Vocabulary(conf.vis_dim, max_similarity = 0.05, include_pairs = False)
    list_strs = read_csv(conf.vis_filepath + conf.sym_list_filename, True)
    raw_vecs  = read_csv(conf.vis_filepath + conf.mu_filename)
    for i, list_str in enumerate(list_strs):
        vis_raw_vocab.add(list_str[0], hrr.HRR(data = raw_vecs[i]))
    hrr.Vocabulary.defaults[conf.vis_dim] = vis_raw_vocab

    # NOTE(review): with en_logging=True and OS == "", conf.log_file is "" here;
    # confirm how WriteToFileNode handles an empty path.
    if en_logging:
        wtfNode = nef.WriteToFileNode("wtf", conf.log_file, spaun.net, \
                                      conf.vocab_data.vis_vocab, log_interval = 0.01, pstc = 0.01)
        wtfNode.addValueTermination("dec", "motor_go")
        wtfNode.addValueTermination("dec", "stimulus_cont")

    if auto_run > 1:
        spaun.net.view(play = conf.est_runtime)
    elif auto_run > 0:
        spaun.net.network.simulator.resetNetwork(False, False)
        spaun.net.network.simulator.run(0, conf.est_runtime, 0.001, False)