def _generator(self, split_id, max_length_formula, max_length_trace,
               prepend_start_token, tree_pos_enc):
    """Yield encoded (formula, trace) examples read from a split's target file.

    The file ``<dataset_dir>/<targets[split_id]>.txt`` is consumed as
    consecutive line pairs: a formula line followed by its trace line.
    Reading stops at the first formula line that consists only of a newline,
    or when the file ends.

    Args:
        split_id: Key into ``self.targets`` selecting the file to read.
        max_length_formula: Pairs whose raw formula line (line terminator
            included) is longer than this are skipped. A negative value
            disables the check.
        max_length_trace: Same limit for the raw trace line.
        prepend_start_token: Forwarded to ``self.trace_vocab.encode``.
        tree_pos_enc: If truthy, additionally yield the formula's binary
            position list, zero-padded to a rectangular shape.

    Yields:
        ``(encoded_formula, encoded_trace)`` as ``tf.constant`` tensors, or
        ``(encoded_formula, padded_positions, encoded_trace)`` when
        ``tree_pos_enc`` is set (positions as ``tf.float32``).
    """
    target_file = path.join(self.dataset_dir, self.targets[split_id] + '.txt')
    with tf.io.gfile.GFile(target_file, 'r') as file:
        # expect formula\ntrace\n format
        for line_in in file:
            if line_in == '\n':
                return
            # A bare next(file) would raise StopIteration on a file that ends
            # after a formula line; inside a generator that surfaces as a
            # RuntimeError (PEP 479). Treat the dangling formula as end of data.
            line_out = next(file, None)  # get second line
            if line_out is None:
                return
            if max_length_formula >= 0 and len(line_in) > max_length_formula:
                continue
            if max_length_trace >= 0 and len(line_out) > max_length_trace:
                continue
            formula = ltl_parser.ltl_formula(line_in.strip(), 'network-polish')
            formula_tokens = formula.to_str('network-polish',
                                            spacing='all ops').split(' ')
            encoded_in = self.ltl_vocab.encode(formula_tokens)
            encoded_out = self.trace_vocab.encode(
                line_out.strip(), prepend_start_token=prepend_start_token)
            if tree_pos_enc:
                position_list = formula.binary_position_list(format='lbt',
                                                             add_first=True)
                # pad every position entry with zeros up to the longest one
                max_length = max(len(positions) for positions in position_list)
                padded_position_list = [
                    positions + [0] * (max_length - len(positions))
                    for positions in position_list
                ]
                yield (tf.constant(encoded_in),
                       tf.constant(padded_position_list, dtype=tf.float32),
                       tf.constant(encoded_out))
            else:
                yield (tf.constant(encoded_in), tf.constant(encoded_out))
def main():
    """Generate random propositional formulas with models and write the dataset.

    Formulas are drawn from ``spot.randltl`` with only ``false``, ``true``,
    ``not``, ``and`` and ``or`` enabled, filtered through a uniform
    ``DistributionGate`` on formula size, and handed to ``generate_model``
    through a persistent worker. Formulas for which no model is returned are
    skipped. The collected ``(formula, assignment)`` pairs are passed to
    ``split_and_write``.

    The worker is terminated on every exit path, including failures inside
    the generation loop.
    """
    args = parse_args()
    aps = list(map(chr, range(97, 97 + args.num_aps)))  # 'a', 'b', 'c', ...
    seed = 42
    formula_generator = spot.randltl(
        aps,
        seed=seed,
        tree_size=(1, args.max_size + 2),
        ltl_priorities=
        'false=1,true=1,not=1,F=0,G=0,X=0,equiv=0,implies=0,xor=0,R=0,U=0,W=0,M=0,and=1,or=1',
        simplify=0)
    gate = DistributionGate('formula size',
                            'uniform', (1, args.max_size),
                            args.num_examples,
                            start_calc_from=12,
                            alpha=args.alpha)
    worker = utils.PersistentWorker()
    try:
        worker_calls = 0
        samples = []
        total_samples = 0
        while total_samples < args.num_examples and not gate.full():
            formula_str_spot = next(formula_generator).to_str('lbt')
            formula_obj = ltl_parser.ltl_formula(formula_str_spot, 'lbt')
            polish_tokens = formula_obj.to_str('network-polish',
                                               spacing='all ops').split(' ')
            polish_pyaiger = spot_to_pyaiger(polish_tokens)
            if not gate.gate(formula_obj):
                continue
            # Recycle the worker every 10000 calls.
            # NOTE(review): assumes PersistentWorker.call() starts a fresh
            # worker after terminate() -- confirm.
            if worker_calls >= 10000:
                worker.terminate()
                worker_calls = 0
            finished, min_model = worker.call(generate_model,
                                              (polish_pyaiger, None), 60)
            worker_calls += 1
            assert finished
            if min_model is None:
                continue  # no unsat
            gate.update(formula_obj)
            assignments_spot = ''.join(pyaiger_to_spot(min_model.split(' ')))
            formula_spot = formula_obj.to_str('network-polish')
            samples.append((formula_spot, assignments_spot))
            total_samples += 1
            if total_samples % 10000 == 0:
                print(f'{total_samples/args.num_examples*100:5.1f}% complete')
                sys.stdout.flush()
        split_and_write(samples, args, seed, gate)
    finally:
        worker.terminate()
def calculate_accuracy(formulas_file, traces_file, targets_file, log_file,
                       sat_prob_file, polish, sem_desp_syn, per_size,
                       validator, log_level, **kwargs):
    """Classify predicted traces against their formulas and summarize the result.

    Formulas and traces are read in lockstep, one per line. Each pair ends up
    in one of the categories 'syntactically correct' (trace equals the
    target), 'only semantically correct' (a checker confirms the trace),
    'incorrect', 'invalid' (trace does not parse) or 'unknown' (no checker
    gave an answer).

    Args:
        formulas_file, traces_file, targets_file: Inputs, opened for reading
            through ``nice_open``. A target line of ``-`` means "no target".
        log_file, sat_prob_file: Outputs, opened for writing through
            ``nice_open``.
        polish: Selects the 'network-polish' format if truthy, otherwise
            'network-infix'.
        sem_desp_syn: If truthy, traces that match the target are still run
            through the semantic check.
        per_size: If truthy, counts are kept per ``formula_obj.size()`` and
            reduced with ``per_size_analysis(res, **kwargs)`` at the end.
        validator: 'spot', 'aalta' or 'both'.
        log_level: Verbosity threshold: 1 logs invalid, 2 incorrect,
            3 semantically correct, 4 syntactically correct cases. Unknown
            cases are logged whenever a log is available.
        **kwargs: Forwarded to ``per_size_analysis``.

    Returns:
        ``(res, res_str)``: the result dictionary (including 'total' and
        'correct') and a human-readable summary. With no input lines all
        percentages are reported as 0.

    Raises:
        RuntimeError: If ``sem_desp_syn`` is set and a trace that equals its
            target does not satisfy the formula.
    """
    categories = ('syntactically correct', 'only semantically correct',
                  'incorrect', 'invalid', 'unknown')
    use_aalta = validator in ('aalta', 'both')
    use_spot = validator in ('spot', 'both')
    formula_format = 'network-' + ('polish' if polish else 'infix')

    with nice_open(formulas_file, 'r') as formulas, \
            nice_open(traces_file, 'r') as traces, \
            nice_open(targets_file, 'r') as targets, \
            nice_open(log_file, 'w') as log, \
            nice_open(sat_prob_file, 'w') as sat_prob:
        line_num = 0
        tictoc = TicToc()
        if per_size:
            res = {category: {} for category in categories}

            def increment(key, formula_obj):
                size = formula_obj.size()
                res[key][size] = res[key].get(size, 0) + 1
        else:
            res = {category: 0 for category in categories}

            def increment(key, formula_obj):
                res[key] += 1

        if use_spot:
            import spot

        for formula_str, trace_str in zip(formulas, traces):
            formula_str, trace_str = formula_str.strip(), trace_str.strip()
            line_num += 1
            target_str = next(targets).strip() if targets else None
            if target_str == '-':  # no trace
                target_str = None
            formula_obj = ltl_parser.ltl_formula(formula_str,
                                                 format=formula_format)

            # trace valid syntactically?
            try:
                trace_obj = ltl_parser.ltl_trace(trace_str,
                                                 format=formula_format)
            except ltl_parser.ParseError as e:
                increment('invalid', formula_obj)
                if log and log_level >= 1:
                    log.write(
                        "INVALID {:d}\ninput (raw): {}\noutput (raw): {}\ntarget (raw): {}\nerror: {}\n\n"
                        .format(line_num, formula_str, trace_str, target_str,
                                e))
                continue

            # trace equal to target (if available)?
            if target_str:  # target available
                target_obj = ltl_parser.ltl_trace(target_str,
                                                  format=formula_format)
                if trace_obj.equal_to(target_obj, extended_eq=True):
                    increment('syntactically correct', formula_obj)
                    syntactically_correct = True
                    if log and log_level >= 4:
                        log.write(
                            "SYNTACTICALLY CORRECT {:d}\ninput : {}\noutput: {}\n\n"
                            .format(line_num, formula_obj.to_str('spot'),
                                    trace_obj.to_str('spot')))
                    if not sem_desp_syn:
                        continue
                else:
                    syntactically_correct = False
            else:
                target_obj = None
                syntactically_correct = None

            # sat problem
            sat_obj = encode_for_satisfiability(trace_obj, formula_obj)
            sat_formula = sat_obj.to_str('spot',
                                         spacing='all ops',
                                         full_parens=True)
            if sat_prob:
                sat_formula_conv = sat_formula.replace('1', 'True').replace(
                    '0', 'False').replace('!', '~')
                # NOTE(review): no separator is written between formulas --
                # confirm this is intended.
                sat_prob.write(sat_formula_conv)

            # aalta trace check
            if use_aalta:
                tictoc.tic()
                try:
                    aalta_result = aalta_wrapper.sat(sat_formula, timeout=20)
                    # the trace holds iff the encoded problem is NOT satisfiable
                    aalta_holds = (not aalta_result
                                   if aalta_result is not None else None)
                except RuntimeError:
                    aalta_holds = None
                tictoc.toc('aalta check')
            else:
                aalta_holds = None

            # spot trace check
            if use_spot:
                formula_spot = spot.formula(formula_obj.to_str('spot'))
                trace_spot = spot.parse_word(trace_obj.to_str('spot'))
                tictoc.tic()
                formula_automaton = formula_spot.translate()
                trace_automaton = trace_spot.as_automaton()
                tictoc.toc('spot translate')
                tictoc.tic()
                try:
                    # spot.contains checks whether language of its right
                    # argument is included in language of its left argument
                    spot_holds = spot.contains(formula_automaton,
                                               trace_automaton)
                except RuntimeError:
                    spot_holds = None
                tictoc.toc('spot contains')
            else:
                spot_holds = None

            # compare, evaluate trace checks
            # if both, same, else the one that is there or both None
            trace_holds = aalta_holds if aalta_holds is not None else spot_holds
            if validator == 'both' and aalta_holds != spot_holds:
                print('Formula ', formula_obj.to_str('spot'))
                print('Trace ', trace_obj.to_str('spot'))
                print('Sat form', sat_formula)
                print('MISMATCH aalta: {} -- spot: {}\n'.format(
                    aalta_holds, spot_holds))
                trace_holds = spot_holds  # trust spot more

            if trace_holds is None:
                if log:
                    log.write(
                        "UNKNOWN {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n".
                        format(
                            line_num, formula_obj.to_str('spot'),
                            trace_obj.to_str('spot'),
                            target_obj.to_str('spot') if target_obj else None))
                # NOTE(review): with sem_desp_syn set, a sample already counted
                # as 'syntactically correct' is counted here as well, which
                # would trip the consistency assert below -- confirm intent.
                increment('unknown', formula_obj)
            elif trace_holds:
                if not sem_desp_syn or (not syntactically_correct):
                    if log and log_level >= 3:
                        log.write(
                            "SEMANTICALLY CORRECT {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n"
                            .format(
                                line_num, formula_obj.to_str('spot'),
                                trace_obj.to_str('spot'),
                                target_obj.to_str('spot')
                                if target_obj else None))
                    increment('only semantically correct', formula_obj)
            else:  # doesn't hold
                increment('incorrect', formula_obj)
                if log and log_level >= 2:
                    log.write(
                        "INCORRECT {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n"
                        .format(
                            line_num, formula_obj.to_str('spot'),
                            trace_obj.to_str('spot'),
                            target_obj.to_str('spot') if target_obj else None))
                if sem_desp_syn and syntactically_correct:
                    raise RuntimeError(
                        'Trace is said to be syntactically correct, but does not fulfil formula!'
                    )

        tictoc.histogram(show=False)

        # evaluation
        if per_size:
            res = per_size_analysis(res, **kwargs)
        res['total'] = line_num
        res['correct'] = res['syntactically correct'] + res[
            'only semantically correct']
        assert res['total'] == res['correct'] + res['incorrect'] + res[
            'invalid'] + res['unknown']

        total = res['total']

        def percent(count):
            # No input lines means no meaningful ratio; report 0% rather than
            # failing with ZeroDivisionError.
            return count / total * 100 if total else 0.0

        res_str = (
            "Correct: {:f}%, {correct:d} out of {total:d}\nSyntactically correct: {:f}%, {syntactically correct:d} out of {total:d}\n"
            "Semantically correct, but not syntactically: {:f}%, {only semantically correct:d} out of {total:d}\n"
            "Incorrect: {:f}%, {incorrect:d} out of {total:d}\nInvalid: {:f}%, {invalid:d} out of {total:d}\n"
            "Unknown: {unknown:d} out of {total:d}\n").format(
                percent(res['correct']),
                percent(res['syntactically correct']),
                percent(res['only semantically correct']),
                percent(res['incorrect']),
                percent(res['invalid']),
                **res)
        if log and not (log is sys.stdout):
            log.write(res_str)
    return res, res_str
def generate_samples(num_aps, num_formulas, tree_size, seed, polish, simplify,
                     train_frac, val_frac, unsat_frac, trace_generator,
                     timeout, require_trace, alpha, **kwargs):
    """Generate random LTL formulas with traces and split them into datasets.

    Formulas come from ``spot.randltl`` (operators ``false``, ``true``,
    ``not``, ``X``, ``U`` and ``and`` enabled) and are filtered through a
    uniform ``DistributionGate`` on formula size. For each accepted formula
    ``get_sat_and_trace`` supplies satisfiability and a trace.

    Args:
        num_aps: Number of atomic propositions (letters from 'a'), at most 26.
        num_formulas: Number of samples requested.
        tree_size: ``(min, max)`` tuple, or an int meaning ``(1, tree_size)``.
        seed: Seed for formula generation and for shuffling the samples.
        polish: Emit 'network-polish' if truthy, otherwise 'network-infix'.
        simplify: Forwarded to ``get_sat_and_trace``.
        train_frac, val_frac: Fractions of the samples for the 'train' and
            'val' splits; the remainder becomes 'test'.
        unsat_frac: Maximum fraction of unsatisfiable samples (trace
            ``'{0}'``); ``0.0`` drops unsatisfiable formulas entirely.
        trace_generator: Forwarded to ``get_sat_and_trace``.
        timeout: Forwarded to ``get_sat_and_trace``; also shown in the
            timeout message.
        require_trace: If truthy, formulas whose trace generation returned no
            verdict are skipped; otherwise they are kept with trace ``'-'``.
        alpha: Forwarded to ``DistributionGate``.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with keys 'train', 'val' and 'test', each a list of
        ``(formula_str, trace_str)`` tuples.

    Raises:
        ValueError: If ``num_aps`` exceeds 26.
    """
    if num_aps > 26:
        raise ValueError("Cannot generate more than 26 APs")
    aps = list(map(chr, range(97, 97 + num_aps)))
    if isinstance(tree_size, int):
        tree_size = (1, tree_size)
    out_format = 'network-' + ('polish' if polish else 'infix')
    formula_generator = spot.randltl(
        aps,
        seed=seed,
        tree_size=tree_size,
        ltl_priorities=
        'false=1,true=1,not=1,F=0,G=0,X=1,equiv=0,implies=0,xor=0,R=0,U=1,W=0,M=0,and=1,or=0',
        simplify=0)

    tictoc = utils.TicToc()
    dist_gate = DistributionGate('formula size',
                                 'uniform',
                                 tree_size,
                                 num_formulas,
                                 start_calc_from=10,
                                 alpha=alpha)
    # Report progress roughly every 10%. The step is clamped to 1 because
    # `num_formulas // 10` is 0 for fewer than 10 formulas, which made the
    # modulo below raise ZeroDivisionError.
    progress_step = max(1, num_formulas // 10)

    global SPOT_WORKER
    SPOT_WORKER = utils.PersistentWorker()
    try:
        # generate samples
        print('Generating samples...')
        sat_only = unsat_frac == 0.0
        samples = []
        sat_samples = 0
        unsat_samples = 0
        total_samples = 0
        while total_samples < num_formulas and not dist_gate.full():
            tictoc.tic()
            formula_spot = next(formula_generator)
            tictoc.toc('formula generation')
            formula_str = formula_spot.to_str()
            formula_obj = ltl_parser.ltl_formula(formula_str, 'spot')
            if not dist_gate.gate(formula_obj):
                # formula doesn't fit distribution
                continue

            # add some spaces and parenthesis to be safe for aalta
            formula_spaced = formula_obj.to_str('spot',
                                                spacing='all ops',
                                                full_parens=True)
            tictoc.tic()
            is_sat, trace_str = get_sat_and_trace(formula_spaced,
                                                  trace_generator, simplify,
                                                  timeout)
            tictoc.toc('trace generation')

            if is_sat is None:  # due to timeout
                print('Trace generation timed out ({:d}s) for formula {}'.
                      format(int(timeout), formula_obj.to_str('spot')))
                if require_trace:
                    continue
                # no trace required
                trace_str = '-'
                dist_gate.update(formula_obj)
            elif not is_sat:
                if sat_only or unsat_samples >= unsat_frac * num_formulas:
                    continue
                # more unsat samples needed
                trace_str = '{0}'
                dist_gate.update(formula_obj)
                unsat_samples += 1
            else:  # is_sat
                if '0' in trace_str:
                    print(
                        'Bug in spot! (trace containing 0):\nFormula: {}\nTrace: {}\n'
                        .format(formula_obj.to_str('spot'), trace_str))
                    continue
                # The original asserted on "'0' in trace_str" here; after the
                # check above that assertion could never fail, so it is gone.
                if sat_samples >= (1 - unsat_frac) * num_formulas:
                    continue
                # more sat samples needed
                trace_str = ltl_parser.ltl_trace(trace_str,
                                                 'spot').to_str(out_format)
                dist_gate.update(formula_obj)
                sat_samples += 1

            formula_str = formula_obj.to_str(out_format)
            samples.append((formula_str, trace_str))
            if total_samples > 0 and total_samples % progress_step == 0:
                print("%d/%d" % (total_samples, num_formulas))
            total_samples += 1
            sys.stdout.flush()

        # dist_gate.histogram(show=False, save_to='dist_nf{}_ts{:d}-{:d}.png'.format(utils.abbrev_count(num_formulas), tree_size[0], tree_size[1]))  # For distribution analysis
        # tictoc.histogram(show=False, save_to='timing_nf{}_ts{:d}-{:d}.png'.format(utils.abbrev_count(num_formulas), tree_size[0], tree_size[1]))  # For timing analysis
        print('Generated {:d} samples, {:d} requested'.format(
            total_samples, num_formulas))
    finally:
        # Terminate the worker even if generation failed part-way.
        SPOT_WORKER.terminate()

    # shuffle and split samples
    random.Random(seed).shuffle(samples)
    train_end = int(train_frac * total_samples)
    val_end = int((train_frac + val_frac) * total_samples)
    res = {}
    res['train'] = samples[0:train_end]
    res['val'] = samples[train_end:val_end]
    res['test'] = samples[val_end:]
    return res
def test_and_analyze_sat(pred_model, dataset, in_vocab, out_vocab, log_name,
                         **kwargs):
    """Evaluate predicted assignments on a dataset and log a per-case report.

    Each prediction is counted as 'invalid' (``get_ass`` raises
    ``ValueError``), 'syn_correct' (decoded prediction equals the decoded
    label), 'sem_correct' (``is_model`` accepts it) or 'incorrect'. All cases
    except the syntactically correct ones are written to the log file; a
    summary is written to the log and printed.

    Args:
        pred_model: Callable model; called with ``training=False`` and
            expected to return a pair whose first element is the prediction.
        dataset: Iterable of batches, ``(data, label)`` or, when
            ``kwargs['tree_pos_enc']`` is truthy, ``(data, pe, label)``.
        in_vocab: Vocabulary used to decode the input formulas.
        out_vocab: Vocabulary used to decode predictions and labels.
        log_name: File name of the log inside
            ``<kwargs['job_dir']>/<kwargs['run_name']>``.
        **kwargs: Must contain 'job_dir', 'run_name' and 'tree_pos_enc'.

    Returns:
        None.
    """
    from deepltl.data.sat_generator import spot_to_pyaiger, is_model

    logdir = path.join(kwargs['job_dir'], kwargs['run_name'])
    tf.io.gfile.makedirs(logdir)
    # NOTE(review): the directory is created via tf.io.gfile but the log is
    # opened with the builtin open() -- confirm only local paths are used.
    with open(path.join(logdir, log_name), 'w') as log_file:
        res = {'invalid': 0, 'incorrect': 0, 'syn_correct': 0, 'sem_correct': 0}
        for x in dataset:
            if kwargs['tree_pos_enc']:
                data, pe, label_ = x
                prediction, _ = pred_model([data, pe], training=False)
            else:
                data, label_ = x
                prediction, _ = pred_model(data, training=False)

            for i in range(prediction.shape[0]):
                formula = in_vocab.decode(list(data[i, :]), as_list=True)
                pred = out_vocab.decode(list(prediction[i, :]), as_list=True)
                label = out_vocab.decode(list(label_[i, :]), as_list=True)
                formula_obj = ltl_parser.ltl_formula(''.join(formula),
                                                     'network-polish')
                formula_str = formula_obj.to_str('spot')
                _, pretty_label_ass = get_ass(label)
                try:
                    _, pretty_ass = get_ass(pred)
                except ValueError as e:
                    res['invalid'] += 1
                    msg = f"INVALID ({str(e)})\nFormula: {formula_str}\nPred: {' '.join(pred)}\nLabel: {pretty_label_ass}\n"
                    log_file.write(msg)
                    continue

                if pred == label:
                    # Syntactically correct predictions are counted but
                    # deliberately not written to the log.
                    res['syn_correct'] += 1
                    continue

                # semantic checking
                formula_pyaiger = spot_to_pyaiger(formula)
                ass_pyaiger = spot_to_pyaiger(pred)
                pyaiger_ass_dict, _ = get_ass(ass_pyaiger)
                try:
                    holds = is_model(formula_pyaiger, pyaiger_ass_dict)
                except KeyError as e:
                    res['incorrect'] += 1
                    msg = f"INCORRECT (var {str(e)} not in formula)\nFormula: {formula_str}\nPred: {pretty_ass}\nLabel: {pretty_label_ass}\n"
                    log_file.write(msg)
                    continue

                if holds:
                    res['sem_correct'] += 1
                    msg = f"SEMANTICALLY CORRECT\nFormula: {formula_str}\nPred: {pretty_ass}\nLabel: {pretty_label_ass}\n"
                else:
                    res['incorrect'] += 1
                    msg = f"INCORRECT\nFormula: {formula_str}\nPred: {pretty_ass}\nLabel: {pretty_label_ass}\n"
                log_file.write(msg)

        total = sum(res.values())
        correct = res['syn_correct'] + res['sem_correct']

        def pct(count):
            # An empty dataset gives total == 0; report 0% instead of raising
            # ZeroDivisionError.
            return count / total * 100 if total else 0.0

        msg = (
            f"Correct: {pct(correct):.1f}%, {correct} out of {total}\nSyntactically correct: {pct(res['syn_correct']):.1f}%\nSemantically correct: {pct(res['sem_correct']):.1f}%\n"
            f"Incorrect: {pct(res['incorrect']):.1f}%\nInvalid: {pct(res['invalid']):.1f}%\n"
        )
        log_file.write(msg)
        print(msg, end='')