Ejemplo n.º 1
0
def getBaseLibrary():
    """Build and return the base function library.

    Returns:
        A fresh ``FnLibrary`` populated (via ``get_items_from_repo``) with the
        repository items ``compose``, ``map_l``, ``fold_l``, ``conv_l``,
        ``conv_g``, ``map_g``, ``fold_g``, ``zeros`` and ``repeat``.
    """
    item_names = [
        'compose',
        'map_l',
        'fold_l',
        'conv_l',
        'conv_g',
        'map_g',
        'fold_g',
        'zeros',
        'repeat',
    ]
    base_lib = FnLibrary()
    base_lib.addItems(get_items_from_repo(item_names))
    return base_lib
Ejemplo n.º 2
0
 def mkDefaultLib():
     """Build and return the default function library.

     Returns:
         A fresh ``FnLibrary`` populated (via ``get_items_from_repo``) with
         ``compose``, ``conv_g``, ``map_g``, ``fold_g``, ``zeros`` and
         ``repeat``. The ``map_l``, ``fold_l`` and ``conv_l`` items are
         intentionally not included.
     """
     default_lib = FnLibrary()
     wanted = ('compose', 'conv_g', 'map_g', 'fold_g', 'zeros', 'repeat')
     default_lib.addItems(get_items_from_repo(list(wanted)))
     return default_lib
Ejemplo n.º 3
0

if __name__ == '__main__':
    # Usage: <script> RESULTS_DIR
    # Fail with a usage message rather than an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("usage: {} RESULTS_DIR".format(sys.argv[0]))
    results_dir = sys.argv[1]
    # results_dir = "Results_maze_baselines"
    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(results_dir, exist_ok=True)

    # NOTE(review): these are module-level globals; the _train_* helpers
    # defined elsewhere in the file may read them, so they are kept even
    # where this block itself does not use them.
    just_testing = False
    epochs_cnn = 1  # 000 -- presumably shortened from 1000; only used by the disabled stage below
    epochs_nav = 10
    batch_size = 150

    lib = FnLibrary()
    lib.addItems(
        get_items_from_repo(
            ['flatten_2d_list', 'map_g', 'compose', 'repeat', 'conv_g']))

    # NOTE(review): presumably the _train_* helpers use this module-level
    # interpreter; its epochs are adjusted before each stage.
    interpreter = Interpreter(lib, epochs=1, batch_size=batch_size)
    # Stage 1 is disabled; later stages are handed names such as "s2t1_cnn"
    # (presumably artifacts saved by an earlier stage-1 run).
    # interpreter.epochs = epochs_cnn
    # res1 = _train_s2t1(results_dir)
    # print("res1: {}".format(res1["accuracy"]))

    interpreter.epochs = epochs_nav
    res2 = _train_s2t2(results_dir, "s2t1_cnn", "s2t1_mlp")
    print("res2: {}".format(res2["accuracy"]))

    interpreter.epochs = epochs_nav
    res3 = _train_s2t3(results_dir, "s2t2_cnn", "s2t2_mlp", "s2t2_conv_g")
    print("res3: {}".format(res3["accuracy"]))
Ejemplo n.º 4
0
    # When `save` is truthy, call .save(results_directory) on each of the
    # entries stored under name_cnn, name_mlp and name_rnn in
    # result["new_fns_dict"]. This block does not modify `result` itself;
    # it is returned either way.
    if save:
        result["new_fns_dict"][name_cnn].save(results_directory)
        result["new_fns_dict"][name_mlp].save(results_directory)
        result["new_fns_dict"][name_rnn].save(results_directory)
    return result


if __name__ == '__main__':
    # Usage: <script> RESULTS_DIR
    # Fail with a usage message rather than an IndexError traceback.
    if len(sys.argv) < 2:
        sys.exit("usage: {} RESULTS_DIR".format(sys.argv[0]))
    results_dir = sys.argv[1]
    # results_dir = "Results_maze_baselines"
    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(results_dir, exist_ok=True)

    # NOTE(review): these are module-level globals; the _train_* helpers
    # defined elsewhere in the file may read them, so they are kept even
    # where this block itself does not use them.
    just_testing = False
    epochs_cnn = 1000
    epochs_nav = 10
    batch_size = 150

    lib = FnLibrary()
    lib.addItems(get_items_from_repo(['flatten_2d_list', 'map_g', 'compose']))

    # NOTE(review): presumably the _train_* helpers use this module-level
    # interpreter; its epochs are adjusted before each stage.
    interpreter = Interpreter(lib, epochs=1, batch_size=batch_size)
    # Stage 1 is disabled; stage 2 is handed the names "s1t1_cnn" and
    # "s1t1_mlp" (presumably artifacts saved by an earlier stage-1 run).
    # interpreter.epochs = epochs_cnn
    # res1 = _train_s1t1(results_dir)
    # print("res1: {}".format(res1["accuracy"]))

    interpreter.epochs = epochs_nav
    res2 = _train_s1t2(results_dir, "s1t1_cnn", "s1t1_mlp")
    print("res2: {}".format(res2["accuracy"]))