def rpni_check_model_example():
    import random

    from aalpy.SULs import MooreSUL
    from aalpy.learning_algs import run_RPNI
    from aalpy.oracles import StatePrefixEqOracle
    from aalpy.utils import generate_random_moore_machine, generate_random_dfa

    # a random DFA could be used instead of a Moore machine:
    # model = generate_random_dfa(num_states=5, alphabet=[1, 2, 3], num_accepting_states=2)
    model = generate_random_moore_machine(num_states=5, input_alphabet=[1, 2, 3], output_alphabet=['a', 'b'])

    input_al = model.get_input_alphabet()

    # generate random learning data: (input sequence, output produced by its last input)
    num_sequences = 1000
    data = []
    for _ in range(num_sequences):
        seq_len = random.randint(1, 20)
        random_seq = random.choices(input_al, k=seq_len)
        output = model.compute_output_seq(model.initial_state, random_seq)[-1]
        data.append((random_seq, output))

    rpni_model = run_RPNI(data, automaton_type='moore', print_info=True)

    rpni_model.make_input_complete('sink_state')

    # conformance check the passively learned model against the original model
    sul = MooreSUL(model)
    eq_oracle_2 = StatePrefixEqOracle(input_al, sul, walks_per_state=100)
    cex = eq_oracle_2.find_cex(rpni_model)
    if cex is None:
        print("Could not find a counterexample between the RPNI-model and the original model.")
    else:
        print('Counterexample found. Either RPNI data was incomplete, or there is a bug in the RPNI algorithm :o')
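# A minimal sketch of the data format run_RPNI expects (inferred from the example above, not from the
# library docs): each entry pairs an input sequence with the output observed after its last input.
# The sequences and labels below are hand-picked and purely illustrative.
def rpni_data_format_sketch():
    from aalpy.learning_algs import run_RPNI
    data = [([1], 'a'), ([2], 'b'), ([1, 1], 'b'), ([1, 2], 'a'), ([2, 1], 'a')]
    # passively learn a Moore machine from the labeled sequences
    return run_RPNI(data, automaton_type='moore', print_info=False)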
def extract_finite_state_transducer(rnn, input_alphabet, output_al, max_learning_rounds=10,
                                    formalism='mealy', print_level=2):
    assert formalism in ['mealy', 'moore']

    # map integer outputs of the RNN back to the abstract output alphabet
    outputs_2_ints = {integer: output for output, integer in tokenized_dict(output_al).items()}

    sul = RnnMealySUL(rnn, outputs_2_ints)

    eq_oracle = StatePrefixEqOracle(input_alphabet, sul, walks_per_state=150, walk_len=25)

    learned_automaton = run_Lstar(alphabet=input_alphabet, sul=sul, eq_oracle=eq_oracle,
                                  automaton_type=formalism, cache_and_non_det_check=False,
                                  max_learning_rounds=max_learning_rounds,
                                  suffix_closedness=True, print_level=print_level)

    return learned_automaton
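# Usage sketch (assumption: rnn_1, input_al and output_al are placeholders for a trained RNN and its
# alphabets, e.g. obtained as in conformance_check_2_RNNs further below):
# mealy_model = extract_finite_state_transducer(rnn_1, input_al, output_al, max_learning_rounds=25)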
def learn_python_class():
    """
    Learn a Mealy machine where inputs are methods and arguments of the class that serves as SUL.
    :return: Mealy machine
    """
    # class under learning (the class itself is passed; its methods become abstract inputs)
    mqtt = MockMqttExample

    input_al = [FunctionDecorator(mqtt.connect), FunctionDecorator(mqtt.disconnect),
                FunctionDecorator(mqtt.subscribe, 'topic'), FunctionDecorator(mqtt.unsubscribe, 'topic'),
                FunctionDecorator(mqtt.publish, 'topic')]

    sul = PyClassSUL(mqtt)

    eq_oracle = StatePrefixEqOracle(input_al, sul, walks_per_state=20, walk_len=20)

    mealy = run_Lstar(input_al, sul, eq_oracle=eq_oracle, automaton_type='mealy', cache_and_non_det_check=True)

    visualize_automaton(mealy)
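# Usage sketch (assumptions: MockMqttExample is the mocked client shipped with the examples, so no
# broker is needed, and Graphviz is available for visualize_automaton):
# learn_python_class()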
def random_mealy_example(alphabet_size, number_of_states, output_size=8):
    """
    Generate a random Mealy machine and learn it.
    :param alphabet_size: size of the input alphabet
    :param number_of_states: number of states in the generated Mealy machine
    :param output_size: size of the output alphabet
    :return: learned Mealy machine
    """
    alphabet = [*range(0, alphabet_size)]

    random_mealy = generate_random_mealy_machine(number_of_states, alphabet,
                                                 output_alphabet=list(range(output_size)))

    sul_mealy = MealySUL(random_mealy)

    random_walk_eq_oracle = RandomWalkEqOracle(alphabet, sul_mealy, 5000)
    state_origin_eq_oracle = StatePrefixEqOracle(alphabet, sul_mealy, walks_per_state=10, walk_len=15)

    learned_mealy = run_Lstar(alphabet, sul_mealy, random_walk_eq_oracle, automaton_type='mealy',
                              cex_processing='longest_prefix')

    return learned_mealy
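# Usage sketch (the sizes below are arbitrary; learning should finish quickly for machines of this size):
def random_mealy_usage_sketch():
    learned = random_mealy_example(alphabet_size=3, number_of_states=10, output_size=4)
    print(f'Learned Mealy machine has {len(learned.states)} states')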
def train_and_extract_tomita(tomita_grammar, acc_stop=1., loss_stop=0.005, load=False):
    tomita_alphabet = ["0", "1"]

    if not load:
        rnn = train_RNN_on_tomita_grammar(tomita_grammar, acc_stop=acc_stop, loss_stop=loss_stop)
    else:
        rnn = train_RNN_on_tomita_grammar(tomita_grammar, train=False)
        rnn.load(f"RNN_Models/tomita_{tomita_grammar}.model")

    sul = RnnBinarySUL(rnn)
    alphabet = tomita_alphabet

    state_eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=1000, walk_len=5)

    dfa = run_Lstar(alphabet=alphabet, sul=sul, eq_oracle=state_eq_oracle, automaton_type='dfa',
                    cache_and_non_det_check=True)

    save_automaton_to_file(dfa, f'LearnedAutomata/learned_tomita{tomita_grammar}')
    visualize_automaton(dfa)
def random_moore_example(alphabet_size, number_of_states, output_size=8):
    """
    Generate a random Moore machine and learn it.
    :param alphabet_size: size of the input alphabet
    :param number_of_states: number of states in the generated Moore machine
    :param output_size: size of the output alphabet
    :return: learned Moore machine
    """
    alphabet = [*range(0, alphabet_size)]

    random_moore = generate_random_moore_machine(number_of_states, alphabet,
                                                 output_alphabet=list(range(output_size)))

    sul_moore = MooreSUL(random_moore)

    state_origin_eq_oracle = StatePrefixEqOracle(alphabet, sul_moore, walks_per_state=15, walk_len=20)

    learned_moore = run_Lstar(alphabet, sul_moore, state_origin_eq_oracle, cex_processing='rs',
                              closing_strategy='single', automaton_type='moore', cache_and_non_det_check=True)

    return learned_moore
def test_all_configuration_combinations(self):
    angluin_example = get_Angluin_dfa()

    alphabet = angluin_example.get_input_alphabet()

    automata_type = ['dfa', 'mealy', 'moore']
    closing_strategies = ['shortest_first', 'longest_first', 'single']
    cex_processing = [None, 'longest_prefix', 'rs']
    suffix_closedness = [True, False]
    caching = [True, False]

    for automata in automata_type:
        for closing in closing_strategies:
            for cex in cex_processing:
                for suffix in suffix_closedness:
                    for cache in caching:
                        sul = DfaSUL(angluin_example)

                        random_walk_eq_oracle = RandomWalkEqOracle(alphabet, sul, 5000, reset_after_cex=True)
                        state_origin_eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=10, walk_len=50)
                        tran_cov_eq_oracle = TransitionFocusOracle(alphabet, sul, num_random_walks=200,
                                                                   walk_len=30, same_state_prob=0.3)
                        w_method_eq_oracle = WMethodEqOracle(alphabet, sul,
                                                             max_number_of_states=len(angluin_example.states))
                        random_W_method_eq_oracle = RandomWMethodEqOracle(alphabet, sul,
                                                                          walks_per_state=10, walk_len=50)
                        bf_exploration_eq_oracle = BreadthFirstExplorationEqOracle(alphabet, sul, 3)
                        random_word_eq_oracle = RandomWordEqOracle(alphabet, sul)
                        cache_based_eq_oracle = CacheBasedEqOracle(alphabet, sul)
                        kWayStateCoverageEqOracle = KWayStateCoverageEqOracle(alphabet, sul)

                        oracles = [random_walk_eq_oracle, random_word_eq_oracle, random_W_method_eq_oracle,
                                   kWayStateCoverageEqOracle, cache_based_eq_oracle, bf_exploration_eq_oracle,
                                   tran_cov_eq_oracle, w_method_eq_oracle, state_origin_eq_oracle]

                        if not cache:
                            oracles.remove(cache_based_eq_oracle)

                        for oracle in oracles:
                            sul = DfaSUL(angluin_example)
                            oracle.sul = sul

                            learned_model = run_Lstar(alphabet, sul, oracle, automaton_type=automata,
                                                      closing_strategy=closing, suffix_closedness=suffix,
                                                      cache_and_non_det_check=cache, cex_processing=cex,
                                                      print_level=0)

                            is_eq = self.prove_equivalence(learned_model)
                            if not is_eq:
                                print(oracle, automata)
                                assert False

    assert True
def random_dfa_example(alphabet_size, number_of_states, num_accepting_states=1):
    """
    Generate a random DFA and learn it.
    :param alphabet_size: size of the input alphabet
    :param number_of_states: number of states in the generated DFA
    :param num_accepting_states: number of accepting states
    :return: learned DFA
    """
    assert num_accepting_states <= number_of_states

    alphabet = list(string.ascii_letters[:26])[:alphabet_size]
    random_dfa = generate_random_dfa(number_of_states, alphabet, num_accepting_states)
    # visualize_automaton(random_dfa, path='correct')

    sul_dfa = DfaSUL(random_dfa)

    # examples of various equivalence oracles
    random_walk_eq_oracle = RandomWalkEqOracle(alphabet, sul_dfa, 5000)
    state_origin_eq_oracle = StatePrefixEqOracle(alphabet, sul_dfa, walks_per_state=10, walk_len=50)
    tran_cov_eq_oracle = TransitionFocusOracle(alphabet, sul_dfa, num_random_walks=200, walk_len=30,
                                               same_state_prob=0.3)
    w_method_eq_oracle = WMethodEqOracle(alphabet, sul_dfa, max_number_of_states=number_of_states)
    random_W_method_eq_oracle = RandomWMethodEqOracle(alphabet, sul_dfa, walks_per_state=10, walk_len=50)
    bf_exploration_eq_oracle = BreadthFirstExplorationEqOracle(alphabet, sul_dfa, 5)
    random_word_eq_oracle = RandomWordEqOracle(alphabet, sul_dfa)
    cache_based_eq_oracle = CacheBasedEqOracle(alphabet, sul_dfa)
    user_based_eq_oracle = UserInputEqOracle(alphabet, sul_dfa)
    kWayStateCoverageEqOracle = KWayStateCoverageEqOracle(alphabet, sul_dfa)

    learned_dfa = run_Lstar(alphabet, sul_dfa, random_walk_eq_oracle, automaton_type='dfa',
                            cache_and_non_det_check=False, cex_processing='rs')

    # visualize_automaton(learned_dfa)
    return learned_dfa
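# Note on the oracles above: only random_walk_eq_oracle is passed to run_Lstar; the others are
# instantiated for illustration and any of them can be swapped in, e.g. (sketch):
# learned_dfa = run_Lstar(alphabet, sul_dfa, random_W_method_eq_oracle, automaton_type='dfa', cex_processing='rs')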
def learn_date_validator():
    from aalpy.base import SUL
    from aalpy.utils import visualize_automaton, DateValidator
    from aalpy.oracles import StatePrefixEqOracle
    from aalpy.learning_algs import run_Lstar

    class DateSUL(SUL):
        """
        An example implementation of a system under learning that can be used to learn the behavior
        of the date validator.
        """

        def __init__(self):
            super().__init__()
            # DateValidator is a black-box class used for date string verification.
            # The format of the dates is '%d/%m/%Y'.
            # Its method is_date_accepted returns True if the date is accepted, False otherwise.
            self.dv = DateValidator()
            self.string = ""

        def pre(self):
            # reset the string used for testing
            self.string = ""

        def post(self):
            pass

        def step(self, letter):
            # add the input to the current string
            if letter is not None:
                self.string += str(letter)
            # test if the current string is accepted
            return self.dv.is_date_accepted(self.string)

    # instantiate the SUL
    sul = DateSUL()

    # define the input alphabet: digits 0-9 and the separator
    alphabet = list(range(0, 10)) + ['/']

    # define an equivalence oracle
    eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=500, walk_len=15)

    # run the learning algorithm
    learned_model = run_Lstar(alphabet, sul, eq_oracle, automaton_type='dfa')

    # visualize the automaton
    visualize_automaton(learned_model)
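# Usage sketch (assumption: Graphviz is available, since the learned DFA is visualized at the end):
# learn_date_validator()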
def tomita_example(tomita_number):
    """
    Learn a DFA for one of the Tomita grammars.
    :param tomita_number: number of the desired Tomita grammar
    :rtype: Dfa
    :return: DFA representing the Tomita grammar
    """
    from aalpy.SULs import TomitaSUL
    from aalpy.learning_algs import run_Lstar
    from aalpy.oracles import StatePrefixEqOracle

    tomita_sul = TomitaSUL(tomita_number)
    alphabet = [0, 1]
    state_origin_eq_oracle = StatePrefixEqOracle(alphabet, tomita_sul, walks_per_state=50, walk_len=10)

    learned_dfa = run_Lstar(alphabet, tomita_sul, state_origin_eq_oracle, automaton_type='dfa',
                            cache_and_non_det_check=True)

    return learned_dfa
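# Usage sketch (assumption: AALpy's TomitaSUL covers the seven classic Tomita grammars, numbered 1 to 7):
def learn_all_tomita_sketch():
    for grammar in range(1, 8):
        dfa = tomita_example(grammar)
        print(f'Tomita {grammar}: learned DFA with {len(dfa.states)} states')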
def regex_example(regex, alphabet):
    """
    Learn a regular expression.
    :param regex: regex to learn
    :param alphabet: alphabet of the regex
    :return: DFA representing the regex
    """
    from aalpy.SULs import RegexSUL
    from aalpy.oracles import StatePrefixEqOracle
    from aalpy.learning_algs import run_Lstar

    regex_sul = RegexSUL(regex)

    eq_oracle = StatePrefixEqOracle(alphabet, regex_sul, walks_per_state=2000, walk_len=15)

    learned_regex = run_Lstar(alphabet, regex_sul, eq_oracle, automaton_type='dfa')

    return learned_regex
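# Usage sketch (the regex and alphabet are arbitrary illustrations, not taken from the original examples):
def regex_usage_sketch():
    dfa = regex_example('ab(b|c)*', ['a', 'b', 'c'])
    print(f'Learned DFA with {len(dfa.states)} states')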
def test_eq_oracles(self):
    angluin_example = get_Angluin_dfa()

    alphabet = angluin_example.get_input_alphabet()

    automata_type = ['dfa', 'mealy', 'moore']

    for automata in automata_type:
        sul = DfaSUL(angluin_example)

        random_walk_eq_oracle = RandomWalkEqOracle(alphabet, sul, 5000, reset_after_cex=True)
        state_origin_eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=10, walk_len=50)
        tran_cov_eq_oracle = TransitionFocusOracle(alphabet, sul, num_random_walks=200, walk_len=30,
                                                   same_state_prob=0.3)
        w_method_eq_oracle = WMethodEqOracle(alphabet, sul, max_number_of_states=len(angluin_example.states))
        random_W_method_eq_oracle = RandomWMethodEqOracle(alphabet, sul, walks_per_state=10, walk_len=50)
        bf_exploration_eq_oracle = BreadthFirstExplorationEqOracle(alphabet, sul, 3)
        random_word_eq_oracle = RandomWordEqOracle(alphabet, sul)
        cache_based_eq_oracle = CacheBasedEqOracle(alphabet, sul)
        kWayStateCoverageEqOracle = KWayStateCoverageEqOracle(alphabet, sul)

        oracles = [random_walk_eq_oracle, random_word_eq_oracle, random_W_method_eq_oracle, w_method_eq_oracle,
                   kWayStateCoverageEqOracle, cache_based_eq_oracle, bf_exploration_eq_oracle, tran_cov_eq_oracle,
                   state_origin_eq_oracle]

        for oracle in oracles:
            sul = DfaSUL(angluin_example)
            oracle.sul = sul

            learned_model = run_Lstar(alphabet, sul, oracle, automaton_type=automata,
                                      cache_and_non_det_check=True, cex_processing=None, print_level=0)

            is_eq = self.prove_equivalence(learned_model)
            if not is_eq:
                print(oracle, automata)
                assert False

    assert True
def falsify_refinement_based_model():
    """
    Show how extensive coverage-based testing can be used to falsify a model returned by the
    refinement-based extraction approach.
    """
    rnn, alphabet, train_set = train_or_load_rnn('bp_1', num_layers=2, hidden_dim=50,
                                                 rnn_class=GRUNetwork, train=False)

    # initial examples for the Weiss et al. approach
    all_words = sorted(list(train_set.keys()), key=lambda x: len(x))
    pos = next((w for w in all_words if rnn.classify_word(w) is True), None)
    neg = next((w for w in all_words if rnn.classify_word(w) is False), None)
    starting_examples = [w for w in [pos, neg] if w is not None]

    # extract an automaton using white-box eq. queries
    rnn.renew()

    start_white_box = time.time()
    dfa_weiss = extract(rnn, time_limit=500, initial_split_depth=10, starting_examples=starting_examples)
    time_white_box = time.time() - start_white_box

    # make sure that internal states are back to initial
    rnn.renew()

    white_box_hyp = Weiss_to_AALpy_DFA_format(dfa_weiss)

    sul = RNN_BinarySUL_for_Weiss_Framework(rnn)

    # the StatePrefixEqOracle below overrides the TransitionFocusOracle; keep whichever should be used
    eq_oracle = TransitionFocusOracle(alphabet, sul, num_random_walks=1000, walk_len=20)
    eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=1500, walk_len=20)

    cex_set = set()
    for _ in range(10):
        start_time = time.time()
        cex = eq_oracle.find_cex(white_box_hyp)
        if not cex or tuple(cex) in cex_set:
            continue
        cex_set.add(tuple(cex))
        end_time = time.time() - start_time
        print(round(end_time, 2), "".join(cex))
from TrainAndExtract import train_RNN_on_tomita_grammar, train_and_extract_bp, train_RNN_and_extract_FSM

# learn and extract the Tomita 3 grammar;
# the same can be achieved with the train_and_extract_tomita function
rnn = train_RNN_on_tomita_grammar(tomita_grammar=3, acc_stop=1., loss_stop=0.005, train=True)

tomita_alphabet = ["0", "1"]

sul = RnnBinarySUL(rnn)
alphabet = tomita_alphabet

state_eq_oracle = StatePrefixEqOracle(alphabet, sul, walks_per_state=200, walk_len=6)

dfa = run_Lstar(alphabet=alphabet, sul=sul, eq_oracle=state_eq_oracle, automaton_type='dfa',
                cache_and_non_det_check=True)

save_automaton_to_file(dfa, f'RNN_Models/tomita{3}')
visualize_automaton(dfa)

# train and extract balanced parentheses
bp_model = train_and_extract_bp(path='TrainingDataAndAutomata/balanced()_2.txt', load=False)

print("Extracted model:")
print(bp_model)
# excerpt: assumes `benchmarks`, `exp`, `sul` (a SUL class), `run_times` and `mean` are defined
# or imported earlier in the script
benchmarks = benchmarks[:10]

caching_opt = [True, False]
closing_options = ['shortest_first', 'longest_first', 'single']
suffix_processing = ['all', 'single']
counter_example_processing = ['rs', 'longest_prefix', None]
e_closedness = ['prefix', 'suffix']

for b in benchmarks:
    automaton = load_automaton_from_file(f'{exp}/{b}', automaton_type='dfa')
    input_al = automaton.get_input_alphabet()

    sul_dfa = sul(automaton)

    state_origin_eq_oracle = StatePrefixEqOracle(input_al, sul_dfa, walks_per_state=5, walk_len=25)

    learned_dfa, data = run_Lstar(input_al, sul_dfa, state_origin_eq_oracle, automaton_type='dfa',
                                  cache_and_non_det_check=False, cex_processing='rs',
                                  return_data=True, print_level=0)

    run_times.append(data['total_time'])

print(run_times)
print(mean(run_times))
def conformance_check_2_RNNs(experiment='coffee'):
    """
    Show how learning-based testing can find differences between two trained RNNs.
    Both RNNs are trained with the same configuration, but their configurations could differ.
    :param experiment: either 'coffee' or 'mqtt'
    :return: cases of non-conformance between the trained RNNs
    """
    if experiment == 'coffee':
        mm, exp = get_coffee_machine(), experiment
    else:
        mm, exp = get_mqtt_mealy(), experiment

    input_al = mm.get_input_alphabet()
    output_al = {output for state in mm.states for output in state.output_fun.values()}

    train_seq, train_labels = generate_data_from_automaton(mm, input_al, num_examples=10000, lens=(2, 5, 8, 10))

    training_data = (train_seq, train_labels)

    rnn_1 = train_RNN_on_mealy_data(mm, data=training_data, ex_name=f'{exp}_1')
    rnn_2 = train_RNN_on_mealy_data(mm, data=training_data, ex_name=f'{exp}_2')

    learned_automaton_1 = extract_finite_state_transducer(rnn_1, input_al, output_al, max_learning_rounds=25)
    learned_automaton_2 = extract_finite_state_transducer(rnn_2, input_al, output_al, max_learning_rounds=25)

    sul = MealySUL(learned_automaton_1)
    sul2 = MealySUL(learned_automaton_2)

    # the StatePrefixEqOracle below overrides the LongCexEqOracle; keep whichever should be used
    eq_oracle = LongCexEqOracle(input_al, sul, num_walks=500, min_walk_len=1, max_walk_len=30,
                                reset_after_cex=True)
    eq_oracle = StatePrefixEqOracle(input_al, sul, walks_per_state=100, walk_len=20)

    cex_set = set()

    for i in range(200):
        cex = eq_oracle.find_cex(learned_automaton_2)
        if cex and tuple(cex) not in cex_set:
            print('--------------------------------------------------------------------------')
            print('Case of non-conformance between automata: ', cex)
            print('Model 1 : ', sul.query(cex))
            print('Model 2 : ', sul2.query(cex))
            cex_set.add(tuple(cex))

    return cex_set
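# Usage sketch (training two RNNs from scratch can take a while; 'coffee' and 'mqtt' are the two
# supported experiments):
# non_conformance_cases = conformance_check_2_RNNs(experiment='coffee')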
def retraining_based_on_non_conformance(ground_truth_model=get_coffee_machine(), num_rnns=2,
                                        num_training_samples=5000, samples_lens=(3, 6, 9, 12)):
    """
    :param ground_truth_model: correct model used for labeling cases of non-conformance
    :param num_rnns: number of RNNs to be trained and learned
    :param num_training_samples: initial number of training samples in the training data set
    :param samples_lens: lengths of the initial training data set samples
    :return: one RNN obtained after active retraining
    """
    assert num_rnns >= 2 and num_training_samples > 0

    input_al = ground_truth_model.get_input_alphabet()
    if isinstance(ground_truth_model, MealyMachine):
        output_al = {output for state in ground_truth_model.states for output in state.output_fun.values()}
    else:
        output_al = [False, True]

    # create initial training data
    train_seq, train_labels = generate_data_from_automaton(ground_truth_model, input_al,
                                                           num_examples=num_training_samples,
                                                           lens=samples_lens)

    # while the input-output behaviour of the trained neural networks differs
    iteration = 0
    while True:
        iteration += 1
        print(f'Learning/extraction round: {iteration}')
        trained_networks = []

        x_train, y_train, x_test, y_test = split_train_validation(train_seq, train_labels, 0.8, uniform=True)

        # train all neural networks with the same parameters
        for i in range(num_rnns):
            rnn = RNNClassifier(input_al, output_dim=len(output_al), num_layers=2, hidden_dim=40,
                                x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test,
                                batch_size=32, nn_type='GRU')

            print(f'Starting training of RNN {i}')
            rnn.train(epochs=150, stop_acc=1.0, stop_epochs=3, verbose=False)
            trained_networks.append(rnn)

        learned_automatons = []

        # extract an automaton from each neural network
        for i, rnn in enumerate(trained_networks):
            print(f'Starting extraction of the automaton from RNN {i}')
            learned_automaton = extract_finite_state_transducer(rnn, input_al, output_al,
                                                                max_learning_rounds=8, print_level=0)
            learned_automatons.append(learned_automaton)

        learned_automatons.sort(key=lambda x: len(x.states), reverse=True)

        # Select one automaton as a basis for conformance checking. You can also do conformance checking
        # with all pairs of learned automata.
        base_sul = MealySUL(learned_automatons[0])

        # select the eq. oracle; the StatePrefixEqOracle below overrides the LongCexEqOracle
        eq_oracle = LongCexEqOracle(input_al, base_sul, num_walks=500, min_walk_len=1, max_walk_len=30,
                                    reset_after_cex=True)
        eq_oracle = StatePrefixEqOracle(input_al, base_sul, walks_per_state=100, walk_len=50)

        cex_set = set()

        # try to find cases of non-conformance between the learned automata
        for la in learned_automatons[1:]:
            for i in range(200):
                cex = eq_oracle.find_cex(la)
                if cex:
                    cex_set.add(tuple(cex))

        # if there were no counterexamples between any learned automata, end the procedure
        if not cex_set:
            for i, la in enumerate(learned_automatons):
                print(f'Size of automata {i}: {len(la.states)}')
            print(learned_automatons[-1])
            print('No counterexamples between extracted automata found.')
            break

        # ask the ground truth model for correct labels
        new_x, new_y = label_sequences_with_correct_model(ground_truth_model, cex_set)

        print(f'Adding {len(cex_set)} new examples to training data.')
        new_x = tokenize(new_x, input_al)
        new_y = tokenize(new_y, output_al)

        train_seq.extend(new_x)
        train_labels.extend(new_y)
        print(f'Size of training data: {len(train_seq)}')
def retraining_based_on_ground_truth(ground_truth_model=get_coffee_machine(), num_train_samples=5000,
                                     lens=(3, 8, 10, 12, 15)):
    """
    :param ground_truth_model: correct model used for data generation and conformance checking
    :param num_train_samples: number of training samples for the initial data generation
    :param lens: lengths of the initial training data samples
    :return: trained RNN that conforms to the ground truth model
    """
    input_al = ground_truth_model.get_input_alphabet()
    if isinstance(ground_truth_model, MealyMachine):
        output_al = {output for state in ground_truth_model.states for output in state.output_fun.values()}
    else:
        output_al = [False, True]

    # create initial training data
    train_seq, train_labels = generate_data_from_automaton(ground_truth_model, input_al,
                                                           num_examples=num_train_samples, lens=lens)

    # while the input-output behaviour of the trained neural network and the ground truth model differs
    iteration = 0
    while True:
        iteration += 1

        # split the data set into training and validation subsets
        x_train, y_train, x_test, y_test = split_train_validation(train_seq, train_labels, 0.8, uniform=True)

        # train the neural network; this can be configured to train with different parameters
        rnn = RNNClassifier(input_al, output_dim=len(output_al), num_layers=2, hidden_dim=40,
                            x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test,
                            batch_size=32, nn_type='GRU')

        print(f'Starting training of the neural network for the {iteration}. time')

        # train the NN
        rnn.train(epochs=150, stop_acc=1.0, stop_epochs=3, verbose=0)

        # encode outputs
        outputs_2_ints = {integer: output for output, integer in tokenized_dict(output_al).items()}

        # use the RNN as the SUL
        sul = RnnMealySUL(rnn, outputs_2_ints)

        # select the eq. oracle; the StatePrefixEqOracle below overrides the LongCexEqOracle
        eq_oracle = LongCexEqOracle(input_al, sul, num_walks=500, min_walk_len=1, max_walk_len=30,
                                    reset_after_cex=True)
        eq_oracle = StatePrefixEqOracle(input_al, sul, walks_per_state=200, walk_len=20)

        cex_set = set()

        # try to find cases of non-conformance between the ground truth model and the trained RNN
        print('Searching for counterexample.')
        for i in range(200):
            # Conformance-check the ground truth model against the trained RNN.
            # Alternatively, one can extract an automaton from the RNN and then model check it against the GT.
            cex = eq_oracle.find_cex(ground_truth_model)
            if cex:
                cex_set.add(tuple(cex))

        # if no counterexamples were found, end the procedure
        if not cex_set:
            print('No counterexamples found between extracted automaton and neural network.')
            # extract an automaton from the RNN and print it
            final_model = run_Lstar(input_al, sul, eq_oracle, automaton_type='mealy', max_learning_rounds=15)
            print(final_model)
            return rnn

        # ask the ground truth model for correct labels
        new_x, new_y = label_sequences_with_correct_model(ground_truth_model, cex_set)

        print(f'Adding {len(cex_set)} new examples to training data.')
        new_x = tokenize(new_x, input_al)
        new_y = tokenize(new_y, output_al)

        train_seq.extend(new_x)
        train_labels.extend(new_y)
        print(f'Size of training data: {len(train_seq)}')
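# Usage sketch (get_coffee_machine() is the default ground-truth model; get_mqtt_mealy() could be
# used instead, as in conformance_check_2_RNNs above):
# conforming_rnn = retraining_based_on_ground_truth(get_coffee_machine(), num_train_samples=5000)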
def learn_with_mapper():
    train_seq, train_labels, input_al, output_al = generate_concrete_data_MQTT(num_examples=300000,
                                                                               num_rand_topics=2,
                                                                               lens=(1, 2, 3, 5, 8, 10, 12),
                                                                               uniform_concretion=True)

    x_train, y_train, x_test, y_test = split_train_validation(train_seq, train_labels, 0.8, uniform=True)

    # train_seq, train_labels = generate_data_based_on_characterization_set(mealy_machine)
    # x_train, y_train, x_test, y_test = split_train_validation(train_seq, train_labels, 0.8, uniform=True)

    rnn = RNNClassifier(input_al, output_dim=len(output_al), num_layers=5, hidden_dim=40,
                        x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test,
                        batch_size=32, nn_type='GRU')

    load = True
    ex_name = 'abstracted_mqtt'

    if not load:
        rnn.train(epochs=200, stop_acc=1.0, stop_epochs=3)
        rnn.save(f'RNN_Models/{ex_name}.rnn')
        with open(f'RNN_Models/{ex_name}.pickle', 'wb') as handle:
            pickle.dump((input_al, output_al), handle, protocol=pickle.HIGHEST_PROTOCOL)
    else:
        rnn.load(f'RNN_Models/{ex_name}.rnn')
        with open(f'RNN_Models/{ex_name}.pickle', 'rb') as handle:
            inp_out_tuple = pickle.load(handle)
        input_al, output_al = inp_out_tuple[0], inp_out_tuple[1]
        rnn.token_dict = dict((c, i) for i, c in enumerate(input_al))

    sul = Abstract_Mapper_MQTT_RNN_SUL(rnn, input_al, output_al)

    abstract_inputs = sul.abstract_inputs

    eq_oracle = StatePrefixEqOracle(abstract_inputs, sul, walks_per_state=100, walk_len=20)

    model = run_Lstar(abstract_inputs, sul, eq_oracle, automaton_type='mealy',
                      cache_and_non_det_check=True, suffix_closedness=False)

    visualize_automaton(model)

    return model
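# Usage sketch (assumption: pre-trained weights exist under RNN_Models/abstracted_mqtt.* because
# load is set to True inside the function):
# abstract_model = learn_with_mapper()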