def build_generators(t_ind):
    """Build the training and test generators for a task selected by index.

    Both generators of the returned pair are instances of the same class with
    identical keyword settings; they differ only in ``max_iter`` (25001 for the
    training generator, 2501 for the test generator).

    Args:
        t_ind: Integer task index. Supported values:
            0 -- ``generators.VarDelayedEstimationTask`` (n_loc=1, n_out=1)
            1 -- ``generators.VarDelayedEstimationTask`` (n_loc=2, n_out=2)
            2 -- ``generators.VarChangeDetectionTask`` (n_loc=1, n_out=1)
            4 -- ``generators.VarGatedDelayedEstimationTask`` (n_loc=2, n_out=1)
            6 -- ``generators.VarHarvey2012`` (n_out=1)
            8 -- ``generators.VarComparisonTask`` (n_loc=1, n_out=1)

    Returns:
        A ``(generator, test_generator)`` tuple.

    Raises:
        ValueError: If ``t_ind`` is not one of the supported indices.
    """
    # Settings passed identically to every task variant below.
    shared = dict(batch_size=50, n_in=50, stim_dur=25, max_delay=100,
                  resp_dur=25, spon_rate=0.1)

    if t_ind == 0:
        task_cls = generators.VarDelayedEstimationTask
        kwargs = dict(shared, n_loc=1, n_out=1, kappa=2.0, tr_cond='all_gains')
    elif t_ind == 1:
        task_cls = generators.VarDelayedEstimationTask
        kwargs = dict(shared, n_loc=2, n_out=2, kappa=2.0, tr_cond='all_gains')
    elif t_ind == 2:
        task_cls = generators.VarChangeDetectionTask
        kwargs = dict(shared, n_loc=1, n_out=1, kappa=2.0, tr_cond='all_gains')
    elif t_ind == 4:
        task_cls = generators.VarGatedDelayedEstimationTask
        kwargs = dict(shared, n_loc=2, n_out=1, kappa=2.0, tr_cond='all_gains')
    elif t_ind == 6:
        # This task is called without n_loc/kappa/tr_cond; it takes
        # sigtc and stim_rate instead.
        task_cls = generators.VarHarvey2012
        kwargs = dict(shared, n_out=1, sigtc=15.0, stim_rate=1.0)
    elif t_ind == 8:
        # Called with sig_tc instead of kappa.
        task_cls = generators.VarComparisonTask
        kwargs = dict(shared, n_loc=1, n_out=1, sig_tc=10.0, tr_cond='all_gains')
    else:
        # Without this branch an unsupported index reaches the return with
        # `generator` unbound and fails with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported task index t_ind=%r; expected one of 0, 1, 2, 4, 6, 8"
            % (t_ind,))

    generator = task_cls(max_iter=25001, **kwargs)
    test_generator = task_cls(max_iter=2501, **kwargs)
    return generator, test_generator
Пример #2
0
def build_generators(ExptDict):
    """Build the training and test generators described by an experiment config.

    ``ExptDict["task"]["task_id"]`` selects the generator class from the
    ``generators`` module. The training and test generators share every setting
    except ``max_iter`` (``tr_max_iter`` vs ``test_max_iter``) and, for the tasks
    that accept it, ``tr_cond`` (``ExptDict["tr_cond"]`` for training,
    ``ExptDict["test_cond"]`` for testing).

    Keys read for every task (a missing one raises ``KeyError``):
        ``ExptDict["task"]``: ``task_id``, ``n_loc``, ``n_out``
        ``ExptDict``: ``tr_cond``, ``test_cond``, ``n_in``, ``batch_size``,
        ``stim_dur``, ``delay_dur``, ``resp_dur``, ``kappa``, ``spon_rate``,
        ``tr_max_iter``, ``test_max_iter``

    Additional ``ExptDict["task"]`` keys, by task:
        ``VDE1``: ``max_delay``
        ``COMP``: ``sig_tc`` (optional, defaults to 10.0)
        ``Harvey2012`` / ``Harvey2012Dynamic``: ``sigtc``, ``stim_rate``
        ``Harvey2016``: ``sigtc``, ``stim_rate``, ``epoch_dur``, ``n_epochs``
        ``SINE``: ``alpha``

    Supported ``task_id`` values and the class each one builds:
        ``DE1``, ``DE2`` -> ``DelayedEstimationTask``
        ``CD1``, ``CD2`` -> ``ChangeDetectionTask``
        ``GDE2`` -> ``GatedDelayedEstimationTask``
        ``COMP`` -> ``ComparisonTask``
        ``VDE1`` -> ``VarDelayedEstimationTask``
        ``Harvey2012``, ``Harvey2012Dynamic``, ``Harvey2016`` -> same-named class
        ``SINE`` -> ``SineTask``

    Args:
        ExptDict: Experiment configuration mapping (see keys above).

    Returns:
        A ``(generator, test_generator)`` tuple.

    Raises:
        KeyError: If a required configuration key is missing.
        ValueError: If ``task_id`` is not one of the supported values.
    """
    task_cfg = ExptDict["task"]

    # Unpack common variables. They are read eagerly, even for tasks that
    # ignore some of them, so an incomplete config fails here for any task.
    task = task_cfg["task_id"]
    n_loc = task_cfg["n_loc"]
    n_out = task_cfg["n_out"]
    tr_cond = ExptDict["tr_cond"]
    test_cond = ExptDict["test_cond"]
    n_in = ExptDict["n_in"]
    batch_size = ExptDict["batch_size"]
    stim_dur = ExptDict["stim_dur"]
    delay_dur = ExptDict["delay_dur"]
    resp_dur = ExptDict["resp_dur"]
    kappa = ExptDict["kappa"]
    spon_rate = ExptDict["spon_rate"]
    tr_max_iter = ExptDict["tr_max_iter"]
    test_max_iter = ExptDict["test_max_iter"]

    # Constructor arguments shared by the DE*, CD* and GDE2 tasks.
    fixed_delay_kwargs = dict(batch_size=batch_size, n_loc=n_loc, n_in=n_in,
                              n_out=n_out, stim_dur=stim_dur,
                              delay_dur=delay_dur, resp_dur=resp_dur,
                              kappa=kappa, spon_rate=spon_rate)

    # Whether the task constructor takes a `tr_cond` argument; when it does,
    # the training and test generators receive different conditions.
    takes_cond = True

    if task in ('DE1', 'DE2'):
        task_cls = generators.DelayedEstimationTask
        kwargs = fixed_delay_kwargs
    elif task in ('CD1', 'CD2'):
        task_cls = generators.ChangeDetectionTask
        kwargs = fixed_delay_kwargs
    elif task == 'GDE2':
        task_cls = generators.GatedDelayedEstimationTask
        kwargs = fixed_delay_kwargs
    elif task == 'COMP':
        task_cls = generators.ComparisonTask
        # Previously hard-coded to 10.0; still the default when not configured.
        kwargs = dict(batch_size=batch_size, n_loc=n_loc, n_in=n_in,
                      n_out=n_out, stim_dur=stim_dur, delay_dur=delay_dur,
                      resp_dur=resp_dur, sig_tc=task_cfg.get("sig_tc", 10.0),
                      spon_rate=spon_rate)
    elif task == 'VDE1':
        max_delay = task_cfg["max_delay"]
        task_cls = generators.VarDelayedEstimationTask
        # Takes max_delay instead of delay_dur.
        kwargs = dict(batch_size=batch_size, n_loc=n_loc, n_in=n_in,
                      n_out=n_out, stim_dur=stim_dur, max_delay=max_delay,
                      resp_dur=resp_dur, kappa=kappa, spon_rate=spon_rate)
    elif task in ('Harvey2012', 'Harvey2012Dynamic'):
        sigtc = task_cfg["sigtc"]
        stim_rate = task_cfg["stim_rate"]
        takes_cond = False
        if task == 'Harvey2012':
            task_cls = generators.Harvey2012
        else:
            task_cls = generators.Harvey2012Dynamic
        kwargs = dict(batch_size=batch_size, n_in=n_in, n_out=n_out,
                      stim_dur=stim_dur, delay_dur=delay_dur,
                      resp_dur=resp_dur, sigtc=sigtc, stim_rate=stim_rate,
                      spon_rate=spon_rate)
    elif task == 'Harvey2016':
        sigtc = task_cfg["sigtc"]
        stim_rate = task_cfg["stim_rate"]
        epoch_dur = task_cfg["epoch_dur"]
        n_epochs = task_cfg["n_epochs"]
        takes_cond = False
        task_cls = generators.Harvey2016
        kwargs = dict(batch_size=batch_size, n_in=n_in, n_out=n_out,
                      n_epochs=n_epochs, epoch_dur=epoch_dur, sigtc=sigtc,
                      stim_rate=stim_rate, spon_rate=spon_rate)
    elif task == 'SINE':
        alpha = task_cfg["alpha"]
        takes_cond = False
        task_cls = generators.SineTask
        kwargs = dict(batch_size=batch_size, n_in=n_in, n_out=n_out,
                      stim_dur=stim_dur, delay_dur=delay_dur,
                      resp_dur=resp_dur, alpha=alpha)
    else:
        # Without this branch an unknown task_id reaches the return with
        # `generator` unbound and fails with an opaque UnboundLocalError.
        raise ValueError("Unknown task_id %r" % (task,))

    # dict() copies, so the shared kwargs are never mutated.
    train_kwargs = dict(kwargs, max_iter=tr_max_iter)
    test_kwargs = dict(kwargs, max_iter=test_max_iter)
    if takes_cond:
        train_kwargs["tr_cond"] = tr_cond
        test_kwargs["tr_cond"] = test_cond

    generator = task_cls(**train_kwargs)
    test_generator = task_cls(**test_kwargs)
    return generator, test_generator
Пример #3
0
    offdiag_val = args.sigma_val

    model_list = ['LeInitRecurrent', 'ResidualRecurrent', 'GRURecurrent']

    # Task and model parameters
    model = model_list[0]
    tr_cond = 'all_gains'
    test_cond = 'all_gains'
    n_hid = 500  # number of hidden units

    generator = generators.VarDelayedEstimationTask(max_iter=25001,
                                                    batch_size=50,
                                                    n_loc=1,
                                                    n_in=50,
                                                    n_out=1,
                                                    stim_dur=25,
                                                    max_delay=100,
                                                    resp_dur=25,
                                                    kappa=2.0,
                                                    spon_rate=0.1,
                                                    tr_cond=tr_cond)
    test_generator = generators.VarDelayedEstimationTask(max_iter=2501,
                                                         batch_size=50,
                                                         n_loc=1,
                                                         n_in=50,
                                                         n_out=1,
                                                         stim_dur=25,
                                                         max_delay=100,
                                                         resp_dur=25,
                                                         kappa=2.0,
                                                         spon_rate=0.1,