예제 #1
0
def train_model(prefix,
                interpreter: Interpreter,
                io_examples_tr,
                io_examples_val,
                io_examples_test,
                task_info,
                save=False,
                dir_path=None,
                last_recogniser_ti=None,
                last_rnn_ti=None,
                load=False):
    """Train the model for ``task_info`` and report its test accuracy.

    Builds data loaders for the train/validation/test examples, obtains the
    program and its trainable parameters from ``get_model``, trains them with
    ``interpreter.learn_neural_network_`` and evaluates the trained program on
    the test loader.

    The evaluations array returned by training is saved to
    ``<dir_path>/_<prefix>__<num_examples>evaluations_np.npy``.

    Args:
        prefix: Formatted into the name of the saved evaluations file.
        interpreter: Supplies the data loaders, the training loop and the
            accuracy evaluation.
        io_examples_tr: Training examples, passed to
            ``interpreter._get_data_loader``. Also used to count the number
            of examples for the output file name.
        io_examples_val: Validation examples, passed to
            ``interpreter._get_data_loader``.
        io_examples_test: Test examples, passed to
            ``interpreter._get_data_loader``.
        task_info: Task description. Tasks whose ``task_type`` is
            ``TaskType.Recognise`` use ``ProgramOutputType.SIGMOID``; all
            others use ``ProgramOutputType.INTEGER``.
        save: If true, call ``.save("<dir_path>/Models/")`` on every value of
            the trained ``new_fns_dict``.
        dir_path: Output directory; also forwarded to ``get_model``.
            NOTE(review): the evaluations file is always written under this
            directory, so the ``None`` default produces a ``"None/..."`` path.
            Callers presumably always pass a real directory — confirm.
        last_recogniser_ti: Forwarded to ``get_model``.
        last_rnn_ti: Forwarded to ``get_model``.
        load: Forwarded to ``get_model``.

    Returns:
        ``{"accuracy": acc}``, where ``acc`` is the value returned by
        ``interpreter._get_accuracy`` for the trained program on the test
        loader.
    """
    if task_info.task_type == TaskType.Recognise:
        output_type = ProgramOutputType.SIGMOID
    else:
        output_type = ProgramOutputType.INTEGER

    data_loader_tr = interpreter._get_data_loader(io_examples_tr)
    data_loader_val = interpreter._get_data_loader(io_examples_val)
    data_loader_test = interpreter._get_data_loader(io_examples_test)

    program, new_fns_dict, parameters = get_model(task_info,
                                                  dir_path,
                                                  last_recogniser_ti,
                                                  last_rnn_ti,
                                                  load=load)

    # learn_neural_network_ returns three values here. The second (a
    # validation accuracy) is not needed by this function.
    new_fns_dict, _, evaluations_np = interpreter.learn_neural_network_(
        program,
        output_type=output_type,
        new_fns_dict=new_fns_dict,
        trainable_parameters=list(parameters),
        data_loader_tr=data_loader_tr,
        data_loader_val=data_loader_val,
        data_loader_test=data_loader_test)

    max_accuracy_test = interpreter._get_accuracy(program, data_loader_test,
                                                  output_type, new_fns_dict)
    print(max_accuracy_test)

    # NOTE(review): a tuple presumably holds (inputs, labels) arrays;
    # otherwise the first element is presumably such a pair. Confirm against
    # callers.
    if isinstance(io_examples_tr, tuple):
        num_examples = io_examples_tr[0].shape[0]
    else:
        num_examples = io_examples_tr[0][0].shape[0]

    np.save(f"{dir_path}/_{prefix}__{num_examples}evaluations_np.npy",
            evaluations_np)

    if save:
        models_dir = f"{dir_path}/Models/"
        for fn in new_fns_dict.values():
            fn.save(models_dir)

    return {"accuracy": max_accuracy_test}
예제 #2
0
def train_summer(type,
                 interpreter: Interpreter,
                 io_examples_tr,
                 io_examples_val,
                 io_examples_test,
                 dir_path,
                 save=False):
    """Train a "summer" model variant and report its test accuracy.

    The model always uses ``ProgramOutputType.INTEGER`` output. It is trained
    with ``interpreter.learn_neural_network_`` and evaluated on the test
    loader.

    Args:
        type: Variant selector.
            - ``"sa"`` and ``"wt"`` both build the model with
              ``get_model_summer``; ``"wt"`` additionally passes
              ``load=True``.
            - Any other value builds it with ``get_model_summer_pnn``.
            The name shadows the builtin ``type`` inside this function; it is
            kept so keyword callers keep working.
        interpreter: Supplies the data loaders, the training loop and the
            accuracy evaluation.
        io_examples_tr: Training examples, passed to
            ``interpreter._get_data_loader``.
        io_examples_val: Validation examples, passed to
            ``interpreter._get_data_loader``.
        io_examples_test: Test examples, passed to
            ``interpreter._get_data_loader``.
        dir_path: Forwarded to the model builder; also the parent of the
            ``Models/`` directory used when ``save`` is true.
        save: If true, call ``.save("<dir_path>/Models/")`` on every value of
            the trained ``new_fns_dict``.

    Returns:
        ``{"accuracy": acc}``, where ``acc`` is the value returned by
        ``interpreter._get_accuracy`` for the trained program on the test
        loader.
    """
    output_type = ProgramOutputType.INTEGER

    if type in ("sa", "wt"):
        program, new_fns_dict, parameters = get_model_summer(
            dir_path, load=type == "wt")
    else:
        program, new_fns_dict, parameters = get_model_summer_pnn(dir_path)

    data_loader_tr = interpreter._get_data_loader(io_examples_tr)
    data_loader_val = interpreter._get_data_loader(io_examples_val)
    data_loader_test = interpreter._get_data_loader(io_examples_test)

    # learn_neural_network_ returns four values here. Only the updated
    # new_fns_dict is used. The evaluations array (4th value) is not persisted:
    # there is no `prefix` for the file name, and the `type` parameter shadows
    # the builtin needed to inspect io_examples_tr.
    new_fns_dict, _, _, _ = interpreter.learn_neural_network_(
        program,
        output_type=output_type,
        new_fns_dict=new_fns_dict,
        trainable_parameters=list(parameters),
        data_loader_tr=data_loader_tr,
        data_loader_val=data_loader_val,
        data_loader_test=data_loader_test)

    max_accuracy_test = interpreter._get_accuracy(program, data_loader_test,
                                                  output_type, new_fns_dict)
    print(max_accuracy_test)

    if save:
        models_dir = f"{dir_path}/Models/"
        for fn in new_fns_dict.values():
            fn.save(models_dir)
    return {"accuracy": max_accuracy_test}