Exemple #1
0
def get_synth_lib():
    """Build the synthesis library of list/graph combinators.

    Registers ``compose``, the list combinators (``map_l``, ``fold_l``,
    ``conv_l``), the graph combinators (``conv_g``, ``map_g``, ``fold_g``),
    ``zeros``, ``repeat`` (whose count argument has sort
    ``PPEnumSort(10, 10)``) and the ``regress_speed_mnist`` signature
    (real[1, 3, 32, 32] -> real[1, 2]). Every item is registered with a
    ``None`` implementation object; only the sorts are described here.

    Returns:
        FnLibrary: the populated library.
    """
    libSynth = FnLibrary()
    A = PPSortVar('A')
    B = PPSortVar('B')
    C = PPSortVar('C')

    # Sort of the count argument accepted by ``repeat``.
    repeatEnum = PPEnumSort(10, 10)

    libSynth.addItems([
        PPLibItem('compose', func(func(B, C), func(A, B), func(A, C)), None),
        PPLibItem('map_l', func(func(A, B), func(lst(A), lst(B))), None),
        PPLibItem('fold_l', func(func(B, A, B), B, func(lst(A), B)), None),
        PPLibItem('conv_l', func(func(lst(A), B), func(lst(A), lst(B))), None),
        PPLibItem('conv_g', func(func(lst(A), B), func(graph(A), graph(B))),
                  None),
        PPLibItem('map_g', func(func(A, B), func(graph(A), graph(B))), None),
        PPLibItem('fold_g', func(func(B, A, B), B, func(graph(A), B)), None),
        PPLibItem('zeros', func(PPDimVar('a'), mkRealTensorSort([1, 'a'])),
                  None),
        PPLibItem('repeat', func(repeatEnum, func(A, A), func(A, A)), None),
        PPLibItem(
            'regress_speed_mnist',
            func(mkRealTensorSort([1, 3, 32, 32]), mkRealTensorSort([1, 2])),
            None),

        # Disabled item:
        # PPLibItem('nav_mnist', func(mkGraphSort(mkRealTensorSort([1, 3, 32, 32])),
        #                             mkGraphSort(mkRealTensorSort([1, 2]))), None),
    ])

    return libSynth
Exemple #2
0
def getLib():
    """Build a small synthesis library of list combinators and typed stubs.

    Registers ``map``, ``fold``, ``conv``, ``compose``, ``repeat`` (count
    argument of sort ``PPEnumSort(2, 50)``), ``zeros`` and three placeholder
    functions: ``nn_fun_0`` (real[5] -> real[5]), ``nn_fun_1``
    (real[5] -> bool[5]) and ``nn_fun_2`` (bool[5] -> int[5]). All items are
    registered with a ``None`` implementation object.

    Returns:
        FnLibrary: the populated library.
    """
    libSynth = FnLibrary()
    A = PPSortVar('A')
    B = PPSortVar('B')
    C = PPSortVar('C')

    tr5 = mkRealTensorSort([5])
    tb5 = mkBoolTensorSort([5])
    ti5 = mkIntTensorSort([5])

    # Sort of the count argument accepted by ``repeat``.
    cnts = PPEnumSort(2, 50)

    libSynth.addItems([
        PPLibItem('map', func(func(A, B), func(lst(A), lst(B))), None),
        PPLibItem('fold', func(func(B, A, B), B, func(lst(A), B)), None),
        PPLibItem('conv', func(func(A, lst(A), A), func(lst(A), lst(A))),
                  None),
        PPLibItem('compose', func(func(B, C), func(A, B), func(A, C)), None),
        PPLibItem('repeat', func(cnts, func(A, A), func(A, A)), None),
        PPLibItem('zeros', func(PPDimVar('a'), mkRealTensorSort([1, 'a'])),
                  None),
        PPLibItem('nn_fun_0', func(tr5, tr5), None),
        PPLibItem('nn_fun_1', func(tr5, tb5), None),
        PPLibItem('nn_fun_2', func(tb5, ti5), None),
    ])
    return libSynth
Exemple #3
0
def getBaseLibrary():
    """Return a new FnLibrary holding the base combinators from the repository.

    The items are fetched by name with ``get_items_from_repo``.
    """
    library = FnLibrary()
    base_names = [
        'compose',
        'map_l',
        'fold_l',
        'conv_l',
        'conv_g',
        'map_g',
        'fold_g',
        'zeros',
        'repeat',
    ]
    base_items = get_items_from_repo(base_names)
    library.addItems(base_items)
    return library
Exemple #4
0
 def mkDefaultLib():
     """Return a new FnLibrary holding the default repository items.

     Contains ``compose``, the graph combinators ``conv_g``/``map_g``/``fold_g``,
     ``zeros`` and ``repeat``.
     """
     default_lib = FnLibrary()
     # The list combinators 'map_l', 'fold_l' and 'conv_l' are not included.
     selected = ['compose', 'conv_g', 'map_g', 'fold_g', 'zeros', 'repeat']
     default_lib.addItems(get_items_from_repo(selected))
     return default_lib
Exemple #5
0
def test3():
    """Synthesize a program of sort int[1, 2, 3] -> int[1, 2].

    The library holds a single item ``one`` with exactly that sort. The
    interpreter and IO examples are both ``None`` and evaluation is disabled;
    the solution and its score are printed.
    """
    src_sort = mkTensorSort(PPInt(), [1, 2, 3])
    dst_sort = mkTensorSort(PPInt(), [1, 2])

    library = FnLibrary()
    library.addItems([PPLibItem('one', mkFuncSort(src_sort, dst_sort), None)])

    target_sort = PPFuncSort([src_sort], dst_sort)
    synthesizer = SymbolicSynthesizer(None, library, target_sort, None)
    synthesizer.setEvaluate(False)

    solution, score = synthesizer.solve()
    for value in (solution, score):
        print(value)
Exemple #6
0
 def update_library(self, lib: FnLibrary,
                    task_result_single: TaskResultSingle, taskid):
     """Add the modules learned by a task's top solution to ``lib`` and save it.

     Does nothing when ``self.seq_settings.update_library`` is falsy or the
     task has no top solution.

     Args:
         lib: library extended in place with one item per unknown term.
         task_result_single: finished task result; its top solution is a
             ``(program, result_dict)`` pair, and
             ``result_dict['new_fns_dict']`` is indexed by each unknown's name
             to obtain the object stored in the new library item.
         taskid: identifier forwarded to ``lib.save1``.
     """
     if not self.seq_settings.update_library:
         return
     top_solution = task_result_single.get_top_solution_details()
     if top_solution is None:
         return

     prog, resDict = top_solution
     # Every unknown term in the program becomes a library item holding the
     # object produced for it.
     unk_sort_map: Dict[str, PPSort] = ASTUtils.getUnkNameSortMap(prog)
     lib_items = [
         PPLibItem(unk, unk_sort, resDict['new_fns_dict'][unk])
         for unk, unk_sort in unk_sort_map.items()
     ]
     if lib_items:
         lib.addItems(lib_items)
     # The library is saved even when no new items were added.
     lib.save1(self.getLibLocation(), taskid)
Exemple #7
0
def test5():
    """Synthesize ``map(recognise_5s, inputs)`` over a list of 1x1x28x28 images.

    The library contains a ``recognise_5s`` CNN (weights loaded from
    '../Interpreter/Models/is5_classifier.pth.tar') and a polymorphic ``map``.
    No interpreter or IO examples are used and evaluation is disabled; the
    found solution and its score are printed.
    """
    def mk_recognise_5s():
        res = NetCNN("recognise_5s",
                     input_ch=1,
                     output_dim=1,
                     output_activation=F.sigmoid)
        res.load('../Interpreter/Models/is5_classifier.pth.tar')
        return res

    libSynth = FnLibrary()

    t1 = PPSortVar('T1')
    t2 = PPSortVar('T2')

    libSynth.addItems([
        PPLibItem(
            'recognise_5s',
            mkFuncSort(mkTensorSort(PPReal(), ['a', 1, 28, 28]),
                       mkTensorSort(PPReal(), ['a', 1])), mk_recognise_5s()),
        PPLibItem(
            'map',
            mkFuncSort(mkFuncSort(t1, t2), mkListSort(t1), mkListSort(t2)),
            pp_map),
    ])

    ioExamples = None
    img = mkRealTensorSort([1, 1, 28, 28])
    imgList = mkListSort(img)
    isFive = mkRealTensorSort([1, 1])
    isFiveList = mkListSort(isFive)

    fnSort = mkFuncSort(imgList, isFiveList)

    interpreter = None
    # Target program: lambda inputs: map(lib.recognise_5s, inputs)

    # The same (None) value is passed for both the training and the
    # validation IO examples.
    solver = SymbolicSynthesizer(interpreter, libSynth, fnSort, ioExamples,
                                 ioExamples)
    solver.setEvaluate(False)
    # TODO: use "search" instead of "solve"
    solution, score = solver.solve()
    print(solution)
    print(score)
Exemple #8
0
def test1():
    """Check that searching for a name absent from the library finds nothing.

    The library only defines ``itob`` (int -> bool), while the target program
    references ``lib.itobX``; ``solver.search`` (called with 100, presumably a
    search budget) must therefore return -1.
    """
    int_sort, bool_sort = PPInt(), PPBool()

    library = FnLibrary()
    library.addItems([PPLibItem('itob', mkFuncSort(int_sort, bool_sort), None)])

    goal_sort = PPFuncSort([int_sort], bool_sort)
    solver = SymbolicSynthesizer(None, library, goal_sort, None)
    solver.setEvaluate(False)

    missing_prog = PPVar('lib.itobX')
    assert solver.search(missing_prog, 100) == -1
Exemple #9
0
def xtest2():
    """Check that ``lambda x1: lib.itob(x1)`` is found by the symbolic search.

    Presumably disabled via the ``x`` name prefix so the test runner does not
    collect it. ``solver.search`` is called with 100 (presumably a search
    budget) and must return a non-negative count.
    """
    int_sort = PPInt()
    bool_sort = PPBool()
    library = FnLibrary()
    library.addItems([PPLibItem('itob', mkFuncSort(int_sort, bool_sort), None)])

    goal_sort = PPFuncSort([int_sort], bool_sort)
    solver = SymbolicSynthesizer(None, library, goal_sort, None)
    solver.setEvaluate(False)

    params = [PPVarDecl(name='x1', sort=PPInt())]
    body = PPFuncApp(fn=PPVar('lib.itob'), args=[PPVar(name='x1')])
    target = PPLambda(params=params, body=body)

    found = solver.search(target, 100)
    assert found >= 0
Exemple #10
0
def test_zeros():
    """Run the symbolic synthesizer on a library containing only ``zeros``.

    ``zeros`` has sort ``a -> real[1, a]`` and the target sort is
    ``2 -> real[2]``. The IO examples come from ``get_batch_count_iseven``
    (digit 5, count_up_to=10; 100 training / 20 validation examples). The
    returned solution and score are not checked.
    """
    # IO Examples
    train, val = split_into_train_and_validation(0, 10)
    train_io_examples = get_batch_count_iseven(digits_to_count=[5],
                                               count_up_to=10,
                                               batch_size=100,
                                               digit_dictionary=train)
    val_io_examples = get_batch_count_iseven(digits_to_count=[5],
                                             count_up_to=10,
                                             batch_size=20,
                                             digit_dictionary=val)

    # Library
    libSynth = FnLibrary()

    libSynth.addItems([
        # NOTE(review): 'zeros' is registered with pp_map as its
        # implementation, whereas addImageFunctionsToLibrary registers 'zeros'
        # with pp_get_zeros -- confirm which one is intended here.
        PPLibItem('zeros', mkFuncSort(PPDimVar('a'),
                                      mkRealTensorSort([1, 'a'])), pp_map),
        # PPLibItem('zeros2', mkFuncSort(PPDimVar('a'), PPDimVar('b'), mkRealTensorSort(['a', 'b'])), pp_map),
    ])

    # NOTE(review): zeros produces real[1, a] while the target returns real[2];
    # confirm the expected target shape.
    fnSort = mkFuncSort(PPDimConst(2), mkRealTensorSort([2]))

    interpreter = Interpreter(libSynth)
    solver = SymbolicSynthesizer(interpreter, libSynth, fnSort,
                                 train_io_examples, val_io_examples)
    solver.setEvaluate(False)
    solution, score = solver.solve()
Exemple #11
0
def xtest4():
    """Synthesize int[1, 2, 3] -> int[3, 3, 3] from two shape-polymorphic items.

    ``one`` maps [a, b, c] to [a, b, c, c] (repeating the last dimension) and
    ``two`` maps [a, b, c, d] to [d, d, d]. The intended solution composes
    them: [1, 2, 3] -> [1, 2, 3, 3] -> [3, 3, 3]. Evaluation is disabled; the
    solution and score are printed.
    """
    t123 = mkTensorSort(PPInt(), [1, 2, 3])
    t333 = mkTensorSort(PPInt(), [3, 3, 3])
    tabc = mkTensorSort(PPInt(), ['a', 'b', 'c'])
    # The last dimension repeats 'c', as the name says. The previous
    # ['a', 'b', 'c', 'd'] was a copy of tabcd and left the 4th dimension of
    # one's output unconstrained, so [1, 2, 3] could not reach [3, 3, 3].
    tabcc = mkTensorSort(PPInt(), ['a', 'b', 'c', 'c'])
    tabcd = mkTensorSort(PPInt(), ['a', 'b', 'c', 'd'])
    tddd = mkTensorSort(PPInt(), ['d', 'd', 'd'])

    libSynth = FnLibrary()
    libSynth.addItems([
        PPLibItem('one', mkFuncSort(tabc, tabcc), None),
    ])
    libSynth.addItems([
        PPLibItem('two', mkFuncSort(tabcd, tddd), None),
    ])

    ioExamples = None

    fnSort = PPFuncSort([t123], t333)

    interpreter = None

    solver = SymbolicSynthesizer(interpreter, libSynth, fnSort, ioExamples)
    solver.setEvaluate(False)
    solution, score = solver.solve()
    print(solution)
    print(score)
def main():
    """Compare ``accuracy_test_np_model`` and ``accuracy_test_baseline_model``.

    IO examples come from ``MNISTDataProvider.get_batch_count_var_len`` for
    digit ``d1 = 1``. Both evaluations are repeated 10 times and the mean of
    each is printed.
    """
    # True selects the full-size configuration; False selects a small, quick
    # configuration.
    for_realz = False
    if for_realz:
        data_size_tr = 6000
        data_size_val = 2100
        list_lengths_tr = [2, 3, 4, 5]
        list_lengths_val = [6, 7, 8]
        num_epochs = 20
    else:
        # NOTE(review): the trailing comments presumably record earlier
        # full-size values; 12000 differs from the 6000 used above -- confirm.
        data_size_tr = 150  # 12000
        data_size_val = 150  # 2100
        list_lengths_tr = [1]  # [2, 3, 4, 5]
        list_lengths_val = [1]  # [6, 7, 8]
        num_epochs = 1  # 20

    lib = FnLibrary()
    addImageFunctionsToLibrary(lib, load_recognise_5s=False)
    interpreter = Interpreter(
        lib, batch_size=150,
        epochs=num_epochs)  #, evaluate_every_n_percent=70)  # 60)

    mnist_data_provider = MNISTDataProvider()
    # The test split (mnist_dict_test) is not used below.
    mnist_dict_train, mnist_dict_val, mnist_dict_test = mnist_data_provider.split_into_train_and_validation(
        0, 12, shuffleFirst=True)

    d1 = 1
    io_examples_tr = mnist_data_provider.get_batch_count_var_len(
        [d1],
        data_size_tr,
        mnist_dict_train,
        list_lengths=list_lengths_tr,
        return_count_int=False)
    io_examples_val = mnist_data_provider.get_batch_count_var_len(
        [d1],
        data_size_val,
        mnist_dict_val,
        list_lengths=list_lengths_val,
        return_count_int=False)
    # Run both evaluations 10 times on the same examples.
    acc_np = []
    acc_baseline = []
    for i in range(10):
        acc_np.append(
            accuracy_test_np_model(interpreter, io_examples_tr,
                                   io_examples_val))
        acc_baseline.append(
            accuracy_test_baseline_model(interpreter, io_examples_tr,
                                         io_examples_val))

    # NOTE(review): the values are collected in acc_* lists but printed as
    # "error" -- confirm what the accuracy_test_* helpers return.
    print("NP average error: {}".format(sum(acc_np) / len(acc_np)))
    print("New average error: {}".format(
        sum(acc_baseline) / len(acc_baseline)))
    # NOTE(review): the triple-quoted string opened on the next line continues
    # beyond this view.
    """
Exemple #13
0
def main(task_id, sequence_str, sequence_name):
    """Run the baseline on the first ``task_id + 1`` tasks of a task sequence.

    Reads the module-level ``settings`` dict for the debug flags.

    Args:
        task_id: zero-based index of the last task to run (printed one-based).
        sequence_str: sequence description parsed by ``get_sequence_from_string``.
        sequence_name: combined with ``task_id`` to form the run prefix.
    """
    # lib = mk_default_lib()
    base_lib = FnLibrary()
    addImageFunctionsToLibrary(base_lib, load_recognise_5s=False)

    cfg = get_task_settings(
        settings["dbg_mode"],
        settings["dbg_learn_parameters"],
        synthesizer=None,
    )

    sequence = get_sequence_from_string(sequence_str)
    print("running task {} of the following sequence:".format(task_id + 1))
    print_sequence(sequence)

    interp = Interpreter(base_lib, batch_size=150, epochs=cfg.epochs)

    run_prefix = "{}_{}".format(sequence_name, task_id)
    tasks_so_far = sequence[:task_id + 1]
    run_baseline_task(run_prefix, interp, cfg, tasks_so_far)
Exemple #14
0
def main(dir_path, type):
    """Run ``run_ss_classifier`` and then ``run_ss_summer`` with one interpreter.

    The module-level ``for_realz`` flag selects the full configuration
    (6000/2100 examples, 30 epochs, several training-data percentages) or a
    quick one (150/150 examples, 1 epoch, 100% of the training data).

    Args:
        dir_path: directory forwarded to both runners.
        type: forwarded unchanged to both runners. The name shadows the
            builtin but is kept for callers that pass it by keyword.
    """
    if not for_realz:
        exp_settings = {
            "data_size_tr": 150,
            "data_size_val": 150,
            "num_epochs": 1,
            "training_data_percentages": [100],
        }
    else:
        exp_settings = {
            "data_size_tr": 6000,
            "data_size_val": 2100,
            "num_epochs": 30,
            "training_data_percentages": [2, 10, 20, 50, 100],
        }

    library = FnLibrary()
    addImageFunctionsToLibrary(library, load_recognise_5s=False)
    interp = Interpreter(library, batch_size=150,
                         epochs=exp_settings["num_epochs"])

    run_ss_classifier(interp, type, exp_settings, dir_path)
    run_ss_summer(interp, type, exp_settings, dir_path)
def loadLibrary(libDirPath):
    """Load a library previously saved in directory ``libDirPath``.

    Reads ``lib.pickle`` from the directory. It holds a dict mapping each item
    name to a ``(sort, isNN)`` pair.
    - Neural items are restored with
      ``SaveableNNModule.create_and_load(libDirPath, name)`` and wrapped in a
      ``PPLibItem`` with the pickled sort.
    - All other items are taken from the repository via
      ``get_item_from_repo(name)``.

    Returns:
        FnLibrary or None: the loaded library, or ``None`` when
        ``lib.pickle`` does not exist in ``libDirPath``.
    """
    libFilePath = os.path.join(libDirPath, 'lib.pickle')

    if not os.path.isfile(libFilePath):
        return None

    # SECURITY: pickle.load can execute arbitrary code while unpickling; only
    # load libraries from trusted directories.
    with open(libFilePath, 'rb') as fh:
        libDict = pickle.load(fh)

    lib = FnLibrary()
    for name, (li, isNN) in libDict.items():
        if isNN:
            obj = SaveableNNModule.create_and_load(libDirPath, name)
            lib.addItem(PPLibItem(name, li, obj))
        else:
            # Non-neural items are rebuilt from the repository; the pickled
            # sort ``li`` is not used for them.
            lib.addItem(get_item_from_repo(name))

    return lib
Exemple #16
0
        result["new_fns_dict"][name_conv_g].save(results_directory)
    return result


if __name__ == '__main__':
    # Script entry point. Usage: <script> RESULTS_DIR
    # Builds an interpreter over repository combinators and runs the
    # _train_s2t2 and _train_s2t3 helpers, printing the accuracy of the former.
    # NOTE(review): the helpers are not passed `interpreter`, so they
    # presumably read it (and possibly other names set here) as module
    # globals -- confirm before renaming anything in this block.
    results_dir = str(sys.argv[1])
    # results_dir = "Results_maze_baselines"
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # NOTE(review): just_testing is never read here, and epochs_cnn is only
    # used by the commented-out stage-1 lines below.
    just_testing = False
    epochs_cnn = 1  # 000
    epochs_nav = 10
    batch_size = 150

    lib = FnLibrary()
    lib.addItems(
        get_items_from_repo(
            ['flatten_2d_list', 'map_g', 'compose', 'repeat', 'conv_g']))

    interpreter = Interpreter(lib, epochs=1, batch_size=batch_size)
    # interpreter.epochs = epochs_cnn
    # res1 = _train_s2t1(results_dir)
    # print("res1: {}".format(res1["accuracy"]))

    # The string arguments presumably name modules saved by earlier stages
    # (e.g. "s2t1_cnn") -- confirm against the _train_* helpers.
    interpreter.epochs = epochs_nav
    res2 = _train_s2t2(results_dir, "s2t1_cnn", "s2t1_mlp")
    print("res2: {}".format(res2["accuracy"]))

    # NOTE(review): res3 is computed but not printed in the visible code.
    interpreter.epochs = epochs_nav
    res3 = _train_s2t3(results_dir, "s2t2_cnn", "s2t2_mlp", "s2t2_conv_g")
Exemple #17
0
def test6():
    """Synthesize a program that sums ``recognise_5s`` outputs over an image list.

    The library provides:
    - a ``recognise_5s`` CNN, with weights loaded from
      '../Interpreter/Models/is5_classifier.pth.tar';
    - polymorphic ``map`` and ``reduce``;
    - an ``add`` on real 2-D tensors.

    IO examples come from ``get_batch_count_iseven`` (batch_size=20) on the
    validation split. They are passed as both training and validation
    examples. The solution and score are printed.
    """
    t = PPSortVar('T')
    t1 = PPSortVar('T1')
    t2 = PPSortVar('T2')

    def mk_recognise_5s():
        res = NetCNN("recognise_5s",
                     input_ch=1,
                     output_dim=1,
                     output_activation=F.sigmoid)
        res.load('../Interpreter/Models/is5_classifier.pth.tar')
        return res

    libSynth = FnLibrary()

    real_tensor_2d = mkTensorSort(PPReal(), ['a', 'b'])
    libSynth.addItems([
        PPLibItem(
            'recognise_5s',
            mkFuncSort(mkTensorSort(PPReal(), ['a', 1, 28, 28]),
                       mkTensorSort(PPReal(), ['a', 1])), mk_recognise_5s()),
        PPLibItem(
            'map',
            mkFuncSort(mkFuncSort(t1, t2), mkListSort(t1), mkListSort(t2)),
            pp_map),
        PPLibItem('reduce', mkFuncSort(mkFuncSort(t, t, t), mkListSort(t), t),
                  pp_reduce),
        PPLibItem('add',
                  mkFuncSort(real_tensor_2d, real_tensor_2d, real_tensor_2d),
                  lambda x, y: x + y),
    ])

    # Only the validation split is used (see the synthesizer call below).
    _, val = split_into_train_and_validation(0, 10)
    val_ioExamples = get_batch_count_iseven(digits_to_count=[5],
                                            count_up_to=10,
                                            batch_size=20,
                                            digit_dictionary=val)

    img = mkRealTensorSort([1, 1, 28, 28])
    imgList = mkListSort(img)

    sumOfFives = mkRealTensorSort([1, 1])

    fnSort = mkFuncSort(imgList, sumOfFives)

    interpreter = Interpreter(libSynth)
    """
    targetProg = 
        lambda inputs. 
            reduce( 
                add, 
                map(lib.recognise_5s, inputs))
    """
    # TODO: use "search" instead of "solve"
    # The validation examples double as the training examples.
    solver = SymbolicSynthesizer(interpreter, libSynth, fnSort, val_ioExamples,
                                 val_ioExamples)
    # solver.setEvaluate(False)
    solution, score = solver.solve()
    print(solution)
    print(score)
Exemple #18
0
def main():
    """Debug script for the ``count_digit_occ_5s`` task.

    Builds a partially specified program, the sort of its unknown term, a
    library and the target sort. It then prints representations of the
    program and of ``nn_fun_1``'s sort, and calls
    ``NeuralSynthesizer.is_evaluable`` on the program.

    NOTE(review): ``io_train``, ``io_val``, ``unkSortMap`` and ``fn_sort`` are
    built but not used, and the result of ``is_evaluable`` is discarded --
    confirm whether it should be printed or asserted.
    """
    io_train, io_val = get_io_examples_count_digit_occ(5, 1200, 1200)
    # Task Name: count_digit_occ_5s
    # prog is lib.compose applied to (nn_fun_3254,
    # lib.conv_l(lib.map_l(lib.nn_fun_1))). nn_fun_3254 is an unknown term of
    # sort List[List[Bool[1, 1]]] -> Real[1, 1].
    prog = PPFuncApp(
        fn=PPVar(name='lib.compose'),
        args=[
            PPTermUnk(
                name='nn_fun_3254',
                sort=PPFuncSort(args=[
                    PPListSort(param_sort=PPListSort(param_sort=PPTensorSort(
                        param_sort=PPBool(),
                        shape=[PPDimConst(
                            value=1), PPDimConst(value=1)])))
                ],
                                rtpe=PPTensorSort(param_sort=PPReal(),
                                                  shape=[
                                                      PPDimConst(value=1),
                                                      PPDimConst(value=1)
                                                  ]))),
            PPFuncApp(fn=PPVar(name='lib.conv_l'),
                      args=[
                          PPFuncApp(fn=PPVar(name='lib.map_l'),
                                    args=[PPVar(name='lib.nn_fun_1')])
                      ])
        ])
    # Sort of each unknown term in prog. It is written out again here, so keep
    # it in sync with the PPTermUnk sort above.
    unkSortMap = {
        'nn_fun_3254':
        PPFuncSort(args=[
            PPListSort(param_sort=PPListSort(param_sort=PPTensorSort(
                param_sort=PPBool(),
                shape=[PPDimConst(
                    value=1), PPDimConst(value=1)])))
        ],
                   rtpe=PPTensorSort(
                       param_sort=PPReal(),
                       shape=[PPDimConst(value=1),
                              PPDimConst(value=1)]))
    }
    # Library contents:
    # - combinators compose, repeat, map_l, fold_l, conv_l and zeros;
    # - two functions nn_fun_1 and nn_fun_2230, each of sort
    #   Real[1, 1, 28, 28] -> Bool[1, 1].
    # Every item has obj=None.
    lib = FnLibrary()
    lib.addItems([
        PPLibItem(name='compose',
                  sort=PPFuncSort(args=[
                      PPFuncSort(args=[PPSortVar(name='B')],
                                 rtpe=PPSortVar(name='C')),
                      PPFuncSort(args=[PPSortVar(name='A')],
                                 rtpe=PPSortVar(name='B'))
                  ],
                                  rtpe=PPFuncSort(args=[PPSortVar(name='A')],
                                                  rtpe=PPSortVar(name='C'))),
                  obj=None),
        PPLibItem(name='repeat',
                  sort=PPFuncSort(args=[
                      PPEnumSort(start=2, end=50),
                      PPFuncSort(args=[PPSortVar(name='A')],
                                 rtpe=PPSortVar(name='A'))
                  ],
                                  rtpe=PPFuncSort(args=[PPSortVar(name='A')],
                                                  rtpe=PPSortVar(name='A'))),
                  obj=None),
        PPLibItem(name='map_l',
                  sort=PPFuncSort(
                      args=[
                          PPFuncSort(args=[PPSortVar(name='A')],
                                     rtpe=PPSortVar(name='B'))
                      ],
                      rtpe=PPFuncSort(
                          args=[PPListSort(param_sort=PPSortVar(name='A'))],
                          rtpe=PPListSort(param_sort=PPSortVar(name='B')))),
                  obj=None),
        PPLibItem(
            name='fold_l',
            sort=PPFuncSort(args=[
                PPFuncSort(args=[PPSortVar(name='B'),
                                 PPSortVar(name='A')],
                           rtpe=PPSortVar(name='B')),
                PPSortVar(name='B')
            ],
                            rtpe=PPFuncSort(args=[
                                PPListSort(param_sort=PPSortVar(name='A'))
                            ],
                                            rtpe=PPSortVar(name='B'))),
            obj=None),
        PPLibItem(name='conv_l',
                  sort=PPFuncSort(
                      args=[
                          PPFuncSort(args=[
                              PPListSort(param_sort=PPSortVar(name='A'))
                          ],
                                     rtpe=PPSortVar(name='B'))
                      ],
                      rtpe=PPFuncSort(
                          args=[PPListSort(param_sort=PPSortVar(name='A'))],
                          rtpe=PPListSort(param_sort=PPSortVar(name='B')))),
                  obj=None),
        PPLibItem(name='zeros',
                  sort=PPFuncSort(args=[PPDimVar(name='a')],
                                  rtpe=PPTensorSort(param_sort=PPReal(),
                                                    shape=[
                                                        PPDimConst(value=1),
                                                        PPDimVar(name='a')
                                                    ])),
                  obj=None),
        PPLibItem(name='nn_fun_1',
                  sort=PPFuncSort(args=[
                      PPTensorSort(param_sort=PPReal(),
                                   shape=[
                                       PPDimConst(value=1),
                                       PPDimConst(value=1),
                                       PPDimConst(value=28),
                                       PPDimConst(value=28)
                                   ])
                  ],
                                  rtpe=PPTensorSort(param_sort=PPBool(),
                                                    shape=[
                                                        PPDimConst(value=1),
                                                        PPDimConst(value=1)
                                                    ])),
                  obj=None),
        PPLibItem(name='nn_fun_2230',
                  sort=PPFuncSort(args=[
                      PPTensorSort(param_sort=PPReal(),
                                   shape=[
                                       PPDimConst(value=1),
                                       PPDimConst(value=1),
                                       PPDimConst(value=28),
                                       PPDimConst(value=28)
                                   ])
                  ],
                                  rtpe=PPTensorSort(param_sort=PPBool(),
                                                    shape=[
                                                        PPDimConst(value=1),
                                                        PPDimConst(value=1)
                                                    ])),
                  obj=None)
    ])
    # Target sort: List[Real[1, 1, 28, 28]] -> Real[1, 1].
    fn_sort = PPFuncSort(args=[
        PPListSort(param_sort=PPTensorSort(param_sort=PPReal(),
                                           shape=[
                                               PPDimConst(value=1),
                                               PPDimConst(value=1),
                                               PPDimConst(value=28),
                                               PPDimConst(value=28)
                                           ]))
    ],
                         rtpe=PPTensorSort(
                             param_sort=PPReal(),
                             shape=[PPDimConst(value=1),
                                    PPDimConst(value=1)]))

    print(ReprUtils.repr_py_ann(prog))
    print(ReprUtils.repr_py_sort(lib.items['nn_fun_1'].sort))
    NeuralSynthesizer.is_evaluable(prog)
Exemple #19
0
def main():
    """Evaluate one unknown neural function on the ``classify_digits`` task.

    The program is just the unknown ``nn_fun_cs1cd_1`` of sort
    Real[1, 1, 28, 28] -> Bool[1, 10]. IO examples come from
    ``get_io_examples_classify_digits(2000, 200)``. The interpreter is created
    with 150 as its second positional argument and evaluates the program
    against those examples.
    """
    io_train, io_val = get_io_examples_classify_digits(2000, 200)

    # Small constructors; every call builds fresh sort objects so that no
    # instance is shared between the program, the unknown map and the library.
    def dims(*sizes):
        return [PPDimConst(value=size) for size in sizes]

    def var(name):
        return PPSortVar(name=name)

    def arrow(arg_sorts, ret_sort):
        return PPFuncSort(args=arg_sorts, rtpe=ret_sort)

    def image_sort():
        return PPTensorSort(param_sort=PPReal(), shape=dims(1, 1, 28, 28))

    def digits_sort():
        return PPTensorSort(param_sort=PPBool(), shape=dims(1, 10))

    def classifier_sort():
        return arrow([image_sort()], digits_sort())

    # Task Name: classify_digits
    unknown_name = 'nn_fun_cs1cd_1'
    prog = PPTermUnk(name=unknown_name, sort=classifier_sort())
    unk_sort_map = {unknown_name: classifier_sort()}

    lib = FnLibrary()
    lib.addItems([
        PPLibItem(name='compose',
                  sort=arrow([arrow([var('B')], var('C')),
                              arrow([var('A')], var('B'))],
                             arrow([var('A')], var('C'))),
                  obj=None),
        PPLibItem(name='repeat',
                  sort=arrow([PPEnumSort(start=8, end=10),
                              arrow([var('A')], var('A'))],
                             arrow([var('A')], var('A'))),
                  obj=None),
        PPLibItem(name='map_l',
                  sort=arrow([arrow([var('A')], var('B'))],
                             arrow([PPListSort(param_sort=var('A'))],
                                   PPListSort(param_sort=var('B')))),
                  obj=None),
        PPLibItem(name='fold_l',
                  sort=arrow([arrow([var('B'), var('A')], var('B')),
                              var('B')],
                             arrow([PPListSort(param_sort=var('A'))],
                                   var('B'))),
                  obj=None),
        PPLibItem(name='conv_l',
                  sort=arrow([arrow([PPListSort(param_sort=var('A'))],
                                    var('B'))],
                             arrow([PPListSort(param_sort=var('A'))],
                                   PPListSort(param_sort=var('B')))),
                  obj=None),
        PPLibItem(name='zeros',
                  sort=arrow([PPDimVar(name='a')],
                             PPTensorSort(param_sort=PPReal(),
                                          shape=[PPDimConst(value=1),
                                                 PPDimVar(name='a')])),
                  obj=None),
    ])

    fn_sort = classifier_sort()

    interp = Interpreter(lib, 150)
    interp.evaluate(program=prog,
                    output_type_s=fn_sort.rtpe,
                    unkSortMap=unk_sort_map,
                    io_examples_tr=io_train,
                    io_examples_val=io_val)
def addImageFunctionsToLibrary(libSynth: FnLibrary, load_recognise_5s: bool = True):
    """Add tensor/list helper functions (and optionally ``recognise_5s``) to ``libSynth``.

    Always adds ``add``, ``add1``, ``map``, ``map2d``, ``zeros``,
    ``reduce_general`` and ``reduce``. When ``load_recognise_5s`` is true, it
    also adds ``recognise_5s`` (built by ``mk_recognise_5s()``) and calls
    ``.eval()`` on it.

    Args:
        libSynth: library extended in place.
        load_recognise_5s: whether to build and register ``recognise_5s``.

    NOTE(review): the sort variables ``t``, ``t1``, ``t2`` and the factory
    ``mk_recognise_5s`` are not defined in this function; presumably they are
    module-level names -- confirm.
    """
    real_tensor_2d = mkTensorSort(PPReal(), ['a', 'b'])
    bool_tensor_2d = mkTensorSort(PPBool(), ['a', 'b'])
    libSynth.addItems([
        PPLibItem('add',
                  mkFuncSort(real_tensor_2d, real_tensor_2d, real_tensor_2d),
                  pp_add),
        # Same implementation as 'add', but the second operand is a bool tensor.
        PPLibItem('add1',
                  mkFuncSort(real_tensor_2d, bool_tensor_2d, real_tensor_2d),
                  pp_add),
        PPLibItem(
            'map',
            mkFuncSort(mkFuncSort(t1, t2), mkListSort(t1), mkListSort(t2)),
            pp_map),
        PPLibItem(
            'map2d',
            mkFuncSort(mkFuncSort(t1, t2), mkListSort(mkListSort(t1)),
                       mkListSort(mkListSort(t2))), pp_map2d),
        # question ^ should we transform map's definition into using vectors? is this not enough?
        # we don't know the type of the tensor output, w/o knowing the function.

        # PPLibItem('cat', mkFuncSort(mkTensorSort(PPReal(), ['a', 'b']), mkTensorSort(PPReal(), ['a', 'c']),
        #                            mkTensorSort(PPReal(), ['a', 'd'])), pp_cat),  # TODO: d = b + c
        # Question: can we write 'b+c'? I'm not sure if it's useful
        # Also, the input types don't have to be PPReal, but for not it should suffice to just leave it like this?
        # ^ It can accept a tuple of tensors of different shapes, but maybe we can restrict it to tuple of 2 for now.

        # PPLibItem('zeros', mkFuncSort(PPInt(), mkTensorSort(PPReal(), ['a', 'b']), mkTensorSort(PPReal(), ['a', 'c'])), pp_get_zeros),

        # PPLibItem('zeros', mkFuncSort(PPInt(), PPInt(), mkTensorSort(PPReal(), ['a', 'c'])), pp_get_zeros),
        # 4, [2, 5] -> [2, 4]
        # 7, [2, 5] -> [2, 7]
        # Question: How do we say that the ints are the same number, PPInt() == 'c'
        # Also, The input tensor type doesn't have to be PPReal, can be int or bool as well

        # Also, the input tensor can be of any type, doesn't need to be float
        # zeros: dimension a -> real tensor of shape [1, a].
        PPLibItem('zeros', mkFuncSort(PPDimVar('a'), mkRealTensorSort([1,
                                                                       'a'])),
                  pp_get_zeros),
        # reduce_general takes (step fn, list, initial value); reduce takes
        # (step fn, list). Both use pp_reduce.
        PPLibItem('reduce_general',
                  mkFuncSort(mkFuncSort(t, t1, t), mkListSort(t1), t, t),
                  pp_reduce),
        PPLibItem('reduce', mkFuncSort(mkFuncSort(t, t, t), mkListSort(t), t),
                  pp_reduce),
        # pp_get_zeros
        # PPLibItem('reduce_with_init_zeros', mkFuncSort(mkFuncSort(t, t1, t), mkListSort(t1), t), pp_reduce_w_zeros_init),
        # Question : the initializer is only optional. How do we encode this information?

        # The following are just test functions for evaluation, not properly typed.
        # ,PPLibItem('mult_range09', mkFuncSort(mkFuncSort(t, t1, t), mkListSort(t1), t, t), get_multiply_by_range09())
        # ,PPLibItem('argmax', mkFuncSort(mkFuncSort(t, t1, t), mkListSort(t1), t, t), argmax)

        # PPLibItem('split', mkFuncSort(PPImageSort(), mkListSort(PPImageSort())), split),
        # PPLibItem('join', mkFuncSort(mkListSort(PPImageSort()), PPImageSort()), None),
    ])
    if load_recognise_5s:
        libSynth.addItems([
            PPLibItem(
                'recognise_5s',
                mkFuncSort(mkTensorSort(PPReal(), ['a', 1, 28, 28]),
                           mkTensorSort(PPBool(), ['a', 1])),
                mk_recognise_5s())
        ])

        # set the neural libraries to evaluation mode
        # TODO: need to make sure we're properly switching between eval and train everywhere
        libSynth.recognise_5s.eval()