def train_model(prefix, interpreter: Interpreter, io_examples_tr, io_examples_val, io_examples_test, task_info,
                save=False, dir_path=None, last_recogniser_ti=None, last_rnn_ti=None, load=False):
    """Train the program for one task and return its accuracy on the test loader.

    Args:
        prefix: Text inserted into the file name of the saved evaluations array.
        interpreter: Interpreter used to build data loaders, train the program
            (``learn_neural_network_``) and measure accuracy (``_get_accuracy``).
        io_examples_tr: Training examples, passed to ``interpreter._get_data_loader``.
        io_examples_val: Validation examples, passed to ``interpreter._get_data_loader``.
        io_examples_test: Test examples, passed to ``interpreter._get_data_loader``.
        task_info: Task description. ``task_info.task_type`` selects the output
            type, and the object is forwarded to ``get_model``.
        save: If True, call ``.save("<dir_path>/Models/")`` on every value of the
            trained ``new_fns_dict``.
        dir_path: Directory forwarded to ``get_model`` and used as the destination
            of the evaluations ``.npy`` file and (when ``save``) the models.
            NOTE(review): the evaluations file is always written under
            ``dir_path``, so the default ``None`` produces a path starting with
            ``"None/"`` — callers presumably always pass a directory; confirm.
        last_recogniser_ti: Forwarded unchanged to ``get_model``.
        last_rnn_ti: Forwarded unchanged to ``get_model``.
        load: Forwarded unchanged to ``get_model``.

    Returns:
        dict: ``{"accuracy": max_accuracy_test}``, the test-loader accuracy of the
        program evaluated with the trained ``new_fns_dict``.
    """
    # Recognition tasks use a sigmoid output; every other task type an integer one.
    if task_info.task_type == TaskType.Recognise:
        output_type = ProgramOutputType.SIGMOID
    else:
        output_type = ProgramOutputType.INTEGER

    data_loader_tr = interpreter._get_data_loader(io_examples_tr)
    data_loader_val = interpreter._get_data_loader(io_examples_val)
    data_loader_test = interpreter._get_data_loader(io_examples_test)

    program, new_fns_dict, parameters = get_model(task_info, dir_path, last_recogniser_ti, last_rnn_ti, load=load)

    # NOTE(review): three return values are unpacked here, but train_summer unpacks
    # four from the same method; one of the two call sites cannot match the real
    # signature — confirm against Interpreter.learn_neural_network_.
    new_fns_dict, _max_accuracy_val, evaluations_np = interpreter.learn_neural_network_(
        program,
        output_type=output_type,
        new_fns_dict=new_fns_dict,
        trainable_parameters=list(parameters),
        data_loader_tr=data_loader_tr,
        data_loader_val=data_loader_val,
        data_loader_test=data_loader_test,
    )

    max_accuracy_test = interpreter._get_accuracy(program, data_loader_test, output_type, new_fns_dict)
    print(max_accuracy_test)

    # A tuple holds the arrays directly; otherwise the first element is itself a
    # sequence of arrays. Assumes the leading array dimension counts examples.
    # isinstance (rather than an exact type comparison) also accepts tuple subclasses.
    if isinstance(io_examples_tr, tuple):
        num_examples = io_examples_tr[0].shape[0]
    else:
        num_examples = io_examples_tr[0][0].shape[0]
    np.save("{}/_{}__{}evaluations_np.npy".format(dir_path, prefix, num_examples), evaluations_np)

    if save:
        for fn in new_fns_dict.values():
            fn.save("{}/Models/".format(dir_path))

    return {"accuracy": max_accuracy_test}
def train_summer(type, interpreter: Interpreter, io_examples_tr, io_examples_val, io_examples_test, dir_path,
                 save=False):
    """Train the summing program variant selected by ``type`` and return its test accuracy.

    Args:
        type: Model variant. ``"sa"`` builds the model with
            ``get_model_summer(dir_path, load=False)``; ``"wt"`` with
            ``get_model_summer(dir_path, load=True)``; any other value with
            ``get_model_summer_pnn(dir_path)``. The name shadows the builtin
            ``type`` inside this function; it is kept for caller compatibility.
        interpreter: Interpreter used to build data loaders, train the program
            (``learn_neural_network_``) and measure accuracy (``_get_accuracy``).
        io_examples_tr: Training examples, passed to ``interpreter._get_data_loader``.
        io_examples_val: Validation examples, passed to ``interpreter._get_data_loader``.
        io_examples_test: Test examples, passed to ``interpreter._get_data_loader``.
        dir_path: Directory forwarded to the model builder and used (when
            ``save``) as the parent of the ``Models/`` destination.
        save: If True, call ``.save("<dir_path>/Models/")`` on every value of the
            trained ``new_fns_dict``.

    Returns:
        dict: ``{"accuracy": max_accuracy_test}``, the test-loader accuracy of the
        program evaluated with the trained ``new_fns_dict``.
    """
    output_type = ProgramOutputType.INTEGER

    if type in ("sa", "wt"):
        program, new_fns_dict, parameters = get_model_summer(dir_path, load=type == "wt")
    else:
        program, new_fns_dict, parameters = get_model_summer_pnn(dir_path)

    data_loader_tr = interpreter._get_data_loader(io_examples_tr)
    data_loader_val = interpreter._get_data_loader(io_examples_val)
    data_loader_test = interpreter._get_data_loader(io_examples_test)

    # NOTE(review): four return values are unpacked here, but train_model unpacks
    # three from the same method; one of the two call sites cannot match the real
    # signature — confirm against Interpreter.learn_neural_network_. The fixed-arity
    # unpack is kept deliberately so a mismatch still raises instead of passing silently.
    new_fns_dict, _, _, _ = interpreter.learn_neural_network_(
        program,
        output_type=output_type,
        new_fns_dict=new_fns_dict,
        trainable_parameters=list(parameters),
        data_loader_tr=data_loader_tr,
        data_loader_val=data_loader_val,
        data_loader_test=data_loader_test,
    )

    max_accuracy_test = interpreter._get_accuracy(program, data_loader_test, output_type, new_fns_dict)
    print(max_accuracy_test)

    # NOTE: unlike train_model, the evaluations array is not written to disk here.
    # Re-enabling that would need a file-name prefix (this function has none) and
    # isinstance(...) instead of type(...), since ``type`` is shadowed by the parameter.

    if save:
        for fn in new_fns_dict.values():
            fn.save("{}/Models/".format(dir_path))

    return {"accuracy": max_accuracy_test}