Пример #1
0
 def _generator(self, split_id, max_length_formula, max_length_trace,
                prepend_start_token, tree_pos_enc):
     """Yield encoded (formula, trace) pairs read from one dataset split.

     The split file ``<self.dataset_dir>/<self.targets[split_id]>.txt`` is
     read as alternating lines: a formula in ``network-polish`` notation,
     then its trace. Reading stops at the first blank line.

     Args:
         split_id: Key into ``self.targets`` naming the split file.
         max_length_formula: Pairs whose raw formula line (trailing newline
             included) is longer than this are skipped; a negative value
             disables the check.
         max_length_trace: Same as above, for the raw trace line.
         prepend_start_token: Forwarded to ``self.trace_vocab.encode``.
         tree_pos_enc: If true, additionally yield the formula's binary tree
             position lists, zero-padded to equal length, as float32.

     Yields:
         ``(encoded_formula, encoded_trace)``, or
         ``(encoded_formula, padded_positions, encoded_trace)`` when
         ``tree_pos_enc`` is set; every element is a ``tf.constant``.

     Raises:
         ValueError: If a formula line has no trace line after it.
     """
     target_file = path.join(self.dataset_dir,
                             self.targets[split_id] + '.txt')
     with tf.io.gfile.GFile(target_file,
                            'r') as file:  # expect formula\ntrace\n format
         for line_in in file:
             if line_in == '\n':  # a blank line terminates the data
                 return
             # Take the paired trace line from the same iterator. A bare
             # next() here would leak StopIteration out of this generator on
             # a truncated file, which Python 3.7+ (PEP 479) re-raises as an
             # unhelpful "generator raised StopIteration" RuntimeError.
             line_out = next(file, None)
             if line_out is None:
                 raise ValueError(
                     'Formula line without a following trace line in {}: '
                     '{!r}'.format(target_file, line_in))
             if 0 <= max_length_formula < len(line_in):
                 continue
             if 0 <= max_length_trace < len(line_out):
                 continue
             formula = ltl_parser.ltl_formula(line_in.strip(),
                                              'network-polish')
             encoded_in = self.ltl_vocab.encode(
                 formula.to_str('network-polish',
                                spacing='all ops').split(' '))
             encoded_out = self.trace_vocab.encode(
                 line_out.strip(), prepend_start_token=prepend_start_token)
             if tree_pos_enc:
                 position_list = formula.binary_position_list(
                     format='lbt', add_first=True)
                 # Zero-pad every position list to the longest one so the
                 # lists form a rectangular tensor.
                 max_length = max(len(pos) for pos in position_list)
                 padded_position_list = [
                     pos + [0] * (max_length - len(pos))
                     for pos in position_list
                 ]
                 yield (tf.constant(encoded_in),
                        tf.constant(padded_position_list, dtype=tf.float32),
                        tf.constant(encoded_out))
             else:
                 yield (tf.constant(encoded_in), tf.constant(encoded_out))
Пример #2
0
def main():
    """Generate propositional formulas paired with minimal satisfying models.

    Random formulas are drawn with ``spot.randltl``. Only true/false/not/
    and/or have non-zero priority in ``ltl_priorities``, so the formulas are
    purely propositional. They are filtered through a uniform formula-size
    ``DistributionGate``. For each accepted formula, a persistent worker runs
    ``generate_model`` with a 60 s timeout. Formulas for which no model is
    returned are skipped. The collected ``(formula, assignment)`` pairs are
    handed to ``split_and_write``.

    Raises:
        ValueError: If more than 26 atomic propositions are requested.
        RuntimeError: If a model computation reports it did not finish.
    """
    args = parse_args()
    # Atomic propositions are the single letters 'a', 'b', ..., so at most
    # 26 can be produced.
    if args.num_aps > 26:
        raise ValueError("Cannot generate more than 26 APs")
    aps = list(map(chr, range(97, 97 + args.num_aps)))
    seed = 42
    formula_generator = spot.randltl(
        aps,
        seed=seed,
        tree_size=(1, args.max_size + 2),
        ltl_priorities=
        'false=1,true=1,not=1,F=0,G=0,X=0,equiv=0,implies=0,xor=0,R=0,U=0,W=0,M=0,and=1,or=1',
        simplify=0)

    gate = DistributionGate('formula size',
                            'uniform', (1, args.max_size),
                            args.num_examples,
                            start_calc_from=12,
                            alpha=args.alpha)
    worker = utils.PersistentWorker()
    # The whole generation loop is inside try/finally so the worker is shut
    # down even if generation fails part-way, not only while writing.
    try:
        worker_calls = 0
        samples = []
        total_samples = 0
        while total_samples < args.num_examples and not gate.full():
            formula_str_spot = next(formula_generator).to_str('lbt')
            formula_obj = ltl_parser.ltl_formula(formula_str_spot, 'lbt')
            if not gate.gate(formula_obj):
                continue
            # Convert only formulas the gate accepted.
            polish_pyaiger = spot_to_pyaiger(
                formula_obj.to_str('network-polish',
                                   spacing='all ops').split(' '))
            # Terminate the worker every 10000 calls and reset the counter;
            # the next worker.call is presumably expected to start a fresh
            # process (PersistentWorker behaviour - confirm).
            if worker_calls >= 10000:
                worker.terminate()
                worker_calls = 0
            finished, min_model = worker.call(generate_model,
                                              (polish_pyaiger, None), 60)
            worker_calls += 1
            # Explicit raise instead of assert so the check survives `-O`.
            if not finished:
                raise RuntimeError(
                    'Model generation did not finish for formula {}'.format(
                        formula_str_spot))
            if min_model is None:
                continue  # no model returned (presumably unsatisfiable)
            gate.update(formula_obj)

            assignments_spot = ''.join(
                pyaiger_to_spot(min_model.split(' ')))
            formula_spot = formula_obj.to_str('network-polish')
            samples.append((formula_spot, assignments_spot))
            total_samples += 1
            if total_samples % 10000 == 0:
                print(
                    f'{total_samples/args.num_examples*100:5.1f}% complete')
                sys.stdout.flush()
        split_and_write(samples, args, seed, gate)
    finally:
        worker.terminate()
Пример #3
0
def calculate_accuracy(formulas_file, traces_file, targets_file, log_file,
                       sat_prob_file, polish, sem_desp_syn, per_size,
                       validator, log_level, **kwargs):
    """Validate predicted traces against their LTL formulas and tally results.

    Line i of ``formulas_file`` is paired with line i of ``traces_file``, and
    with line i of ``targets_file`` when one is open. A target of ``-`` means
    "no target". Every line ends up in exactly one category:
    ``syntactically correct`` (trace equals target), ``only semantically
    correct`` (trace satisfies the formula but differs from the target),
    ``incorrect``, ``invalid`` (trace does not parse) or ``unknown`` (the
    validator could not decide).

    Args:
        formulas_file, traces_file, targets_file, log_file, sat_prob_file:
            Paths opened via ``nice_open``. The targets, log and sat_prob
            handles are truth-tested before use, presumably because
            ``nice_open`` yields ``None`` for a missing path (confirm).
        polish: Parse with ``network-polish`` instead of ``network-infix``.
        sem_desp_syn: Also run the semantic check on syntactically correct
            traces. Such a trace failing the check raises RuntimeError.
        per_size: Count per formula size, then post-process the counts with
            ``per_size_analysis(res, **kwargs)``.
        validator: ``'aalta'``, ``'spot'`` or ``'both'``. With ``'both'``,
            spot wins on disagreement.
        log_level: Logging threshold. >=1 logs invalid, >=2 incorrect,
            >=3 only semantically correct, >=4 syntactically correct.
            Unknown results are logged whenever a log is open.
        **kwargs: Forwarded to ``per_size_analysis``.

    Returns:
        ``(res, res_str)``: the counters (plus ``total`` and ``correct``)
        and a human-readable summary.

    Raises:
        RuntimeError: If ``sem_desp_syn`` is set and a trace equal to its
            target does not satisfy the formula.
    """
    categories = ('syntactically correct', 'only semantically correct',
                  'incorrect', 'invalid', 'unknown')
    formula_format = 'network-' + ('polish' if polish else 'infix')
    with nice_open(formulas_file, 'r') as formulas, \
            nice_open(traces_file, 'r') as traces, \
            nice_open(targets_file, 'r') as targets, \
            nice_open(log_file, 'w') as log, \
            nice_open(sat_prob_file, 'w') as sat_prob:
        line_num = 0
        tictoc = TicToc()
        if per_size:
            res = {key: {} for key in categories}

            def increment(key, formula_obj):
                size = formula_obj.size()
                res[key][size] = res[key].get(size, 0) + 1
        else:
            res = {key: 0 for key in categories}

            def increment(key, formula_obj):
                res[key] += 1

        if validator == 'spot' or validator == 'both':
            import spot

        for formula_str, trace_str in zip(formulas, traces):
            formula_str, trace_str = formula_str.strip(), trace_str.strip()
            line_num += 1
            target_str = next(targets).strip() if targets else None
            if target_str == '-':  # no trace
                target_str = None
            formula_obj = ltl_parser.ltl_formula(formula_str,
                                                 format=formula_format)

            # trace valid syntactically?
            try:
                trace_obj = ltl_parser.ltl_trace(trace_str,
                                                 format=formula_format)
            except ltl_parser.ParseError as e:
                increment('invalid', formula_obj)
                if log and log_level >= 1:
                    log.write(
                        "INVALID {:d}\ninput  (raw): {}\noutput (raw): {}\ntarget (raw): {}\nerror: {}\n\n"
                        .format(line_num, formula_str, trace_str, target_str,
                                e))
                continue

            # trace equal to target (if available)?
            if target_str:  # target available
                target_obj = ltl_parser.ltl_trace(target_str,
                                                  format=formula_format)
                if trace_obj.equal_to(target_obj, extended_eq=True):
                    increment('syntactically correct', formula_obj)
                    syntactically_correct = True
                    if log and log_level >= 4:
                        log.write(
                            "SYNTACTICALLY CORRECT {:d}\ninput : {}\noutput: {}\n\n"
                            .format(line_num, formula_obj.to_str('spot'),
                                    trace_obj.to_str('spot')))
                    if not sem_desp_syn:
                        continue
                else:
                    syntactically_correct = False
            else:
                target_obj = None
                syntactically_correct = None

            # A line already counted as syntactically correct (possible here
            # only with sem_desp_syn) must not be counted a second time
            # below. Otherwise the category totals exceed the line count.
            already_counted = bool(syntactically_correct)

            # sat problem
            sat_obj = encode_for_satisfiability(trace_obj, formula_obj)
            sat_formula = sat_obj.to_str('spot',
                                         spacing='all ops',
                                         full_parens=True)
            if sat_prob:
                sat_formula_conv = sat_formula.replace('1', 'True').replace(
                    '0', 'False').replace('!', '~')
                # One problem per line; without the newline consecutive
                # problems would run together in the output file.
                sat_prob.write(sat_formula_conv + '\n')

            # aalta trace check
            if validator == 'aalta' or validator == 'both':
                tictoc.tic()
                try:
                    aalta_result = aalta_wrapper.sat(sat_formula, timeout=20)
                    # The trace is taken to hold iff the encoded problem is
                    # unsatisfiable (None = undecided).
                    aalta_holds = not aalta_result if aalta_result is not None else None
                except RuntimeError:
                    aalta_holds = None
                tictoc.toc('aalta check')
            else:
                aalta_holds = None

            # spot trace check
            if validator == 'spot' or validator == 'both':
                formula_spot = spot.formula(formula_obj.to_str('spot'))
                trace_spot = spot.parse_word(trace_obj.to_str('spot'))
                tictoc.tic()
                formula_automaton = formula_spot.translate()
                trace_automaton = trace_spot.as_automaton()
                tictoc.toc('spot translate')
                tictoc.tic()
                try:
                    # spot.contains checks whether the language of its right
                    # argument is included in the language of its left one.
                    spot_holds = spot.contains(formula_automaton,
                                               trace_automaton)
                except RuntimeError:
                    spot_holds = None
                tictoc.toc('spot contains')
            else:
                spot_holds = None

            # compare, evaluate trace checks: use aalta's verdict if it has
            # one, otherwise spot's (possibly None)
            trace_holds = aalta_holds if aalta_holds is not None else spot_holds
            if validator == 'both' and aalta_holds != spot_holds:
                print('Formula ', formula_obj.to_str('spot'))
                print('Trace   ', trace_obj.to_str('spot'))
                print('Sat form', sat_formula)
                print('MISMATCH aalta: {} -- spot: {}\n'.format(
                    aalta_holds, spot_holds))
                trace_holds = spot_holds  # trust spot more
            if trace_holds is None:
                if log:
                    log.write(
                        "UNKNOWN {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n".
                        format(
                            line_num, formula_obj.to_str('spot'),
                            trace_obj.to_str('spot'),
                            target_obj.to_str('spot') if target_obj else None))
                if not already_counted:
                    increment('unknown', formula_obj)
            elif trace_holds:
                if not already_counted:
                    if log and log_level >= 3:
                        log.write(
                            "SEMANTICALLY CORRECT {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n"
                            .format(
                                line_num, formula_obj.to_str('spot'),
                                trace_obj.to_str('spot'),
                                target_obj.to_str('spot')
                                if target_obj else None))
                    increment('only semantically correct', formula_obj)
            else:  # doesn't hold
                increment('incorrect', formula_obj)
                if log and log_level >= 2:
                    log.write(
                        "INCORRECT {:d}\ninput : {}\noutput: {}\ntarget: {}\n\n"
                        .format(
                            line_num, formula_obj.to_str('spot'),
                            trace_obj.to_str('spot'),
                            target_obj.to_str('spot') if target_obj else None))
                if sem_desp_syn and syntactically_correct:
                    raise RuntimeError(
                        'Trace is said to be syntactically correct, but does not fulfil formula!'
                    )

        tictoc.histogram(show=False)
        # evaluation
        if per_size:
            res = per_size_analysis(res, **kwargs)
        res['total'] = line_num
        res['correct'] = res['syntactically correct'] + res[
            'only semantically correct']
        assert res['total'] == res['correct'] + res['incorrect'] + res[
            'invalid'] + res['unknown']

        def percent(key):
            # Empty input would otherwise raise ZeroDivisionError.
            return res[key] / res['total'] * 100 if res['total'] else 0.0

        res_str = (
            "Correct: {:f}%, {correct:d} out of {total:d}\n"
            "Syntactically correct: {:f}%, {syntactically correct:d} out of {total:d}\n"
            "Semantically correct, but not syntactically: {:f}%, {only semantically correct:d} out of {total:d}\n"
            "Incorrect: {:f}%, {incorrect:d} out of {total:d}\n"
            "Invalid: {:f}%, {invalid:d} out of {total:d}\n"
            "Unknown: {unknown:d} out of {total:d}\n"
        ).format(percent('correct'), percent('syntactically correct'),
                 percent('only semantically correct'), percent('incorrect'),
                 percent('invalid'), **res)
        if log and log is not sys.stdout:
            log.write(res_str)
    return res, res_str
Пример #4
0
def generate_samples(num_aps, num_formulas, tree_size, seed, polish, simplify,
                     train_frac, val_frac, unsat_frac, trace_generator,
                     timeout, require_trace, alpha, **kwargs):
    """Generate random LTL formulas with traces and split them into datasets.

    Formulas come from ``spot.randltl``, with only true/false/not/X/U/and
    enabled in ``ltl_priorities``. They are filtered through a uniform
    formula-size ``DistributionGate`` and checked with
    ``get_sat_and_trace``. About ``unsat_frac * num_formulas`` samples are
    unsatisfiable formulas with trace ``{0}``. On a trace-generation timeout,
    the formula is skipped if ``require_trace``, otherwise it is kept with
    trace ``-``.

    Args:
        num_aps: Number of atomic propositions ('a', 'b', ...); at most 26.
        num_formulas: Number of samples requested.
        tree_size: ``(min, max)`` formula tree size, or an int meaning
            ``(1, tree_size)``.
        seed: Seed for formula generation and for the final shuffle.
        polish: Output ``network-polish`` instead of ``network-infix``.
        simplify, trace_generator, timeout: Forwarded to
            ``get_sat_and_trace``.
        train_frac, val_frac: Fractions for the train/val splits; the
            remainder becomes the test split.
        unsat_frac: Target fraction of unsatisfiable samples (0 = sat only).
        require_trace: Drop formulas whose trace generation timed out.
        alpha: Forwarded to ``DistributionGate``.
        **kwargs: Ignored.

    Returns:
        Dict with ``'train'``, ``'val'`` and ``'test'`` lists of
        ``(formula_str, trace_str)`` tuples.

    Raises:
        ValueError: If ``num_aps`` exceeds 26.
    """
    if num_aps > 26:
        raise ValueError("Cannot generate more than 26 APs")
    aps = list(map(chr, range(97, 97 + num_aps)))

    if isinstance(tree_size, int):
        tree_size = (1, tree_size)
    formula_generator = spot.randltl(
        aps,
        seed=seed,
        tree_size=tree_size,
        ltl_priorities=
        'false=1,true=1,not=1,F=0,G=0,X=1,equiv=0,implies=0,xor=0,R=0,U=1,W=0,M=0,and=1,or=0',
        simplify=0)

    tictoc = utils.TicToc()
    dist_gate = DistributionGate('formula size',
                                 'uniform',
                                 tree_size,
                                 num_formulas,
                                 start_calc_from=10,
                                 alpha=alpha)
    network_format = 'network-' + ('polish' if polish else 'infix')
    # Report progress about every 10%; max() avoids a modulo by zero when
    # fewer than 10 formulas are requested.
    progress_every = max(1, num_formulas // 10)

    global SPOT_WORKER
    SPOT_WORKER = utils.PersistentWorker()
    # try/finally so the worker is terminated even if generation fails.
    try:
        # generate samples
        print('Generating samples...')
        sat_only = unsat_frac == 0.0
        samples = []
        sat_samples = 0
        unsat_samples = 0
        total_samples = 0
        while total_samples < num_formulas and not dist_gate.full():
            tictoc.tic()
            formula_spot = next(formula_generator)
            tictoc.toc('formula generation')
            formula_str = formula_spot.to_str()
            formula_obj = ltl_parser.ltl_formula(formula_str, 'spot')
            if not dist_gate.gate(formula_obj):  # formula doesn't fit distribution
                continue
            # add some spaces and parenthesis to be safe for aalta
            formula_spaced = formula_obj.to_str('spot',
                                                spacing='all ops',
                                                full_parens=True)
            tictoc.tic()
            is_sat, trace_str = get_sat_and_trace(formula_spaced,
                                                  trace_generator, simplify,
                                                  timeout)
            tictoc.toc('trace generation')

            if is_sat is None:  # due to timeout
                print('Trace generation timed out ({:d}s) for formula {}'.
                      format(int(timeout), formula_obj.to_str('spot')))
                if require_trace:
                    continue
                # no trace required: keep the formula with a placeholder
                trace_str = '-'
                dist_gate.update(formula_obj)
            elif not is_sat:
                # unsatisfiable: keep only while more unsat samples are needed
                if sat_only or unsat_samples >= unsat_frac * num_formulas:
                    continue
                trace_str = '{0}'
                dist_gate.update(formula_obj)
                unsat_samples += 1
            else:  # is_sat
                if '0' in trace_str:
                    print(
                        'Bug in spot! (trace containing 0):\nFormula: {}\nTrace: {}\n'
                        .format(formula_obj.to_str('spot'), trace_str))
                    continue
                if sat_samples >= (1 - unsat_frac) * num_formulas:
                    continue
                # more sat samples needed
                trace_str = ltl_parser.ltl_trace(
                    trace_str, 'spot').to_str(network_format)
                dist_gate.update(formula_obj)
                sat_samples += 1

            samples.append((formula_obj.to_str(network_format), trace_str))
            if total_samples % progress_every == 0 and total_samples > 0:
                print("%d/%d" % (total_samples, num_formulas))
            total_samples += 1
            sys.stdout.flush()
        # dist_gate.histogram(show=False, save_to='dist_nf{}_ts{:d}-{:d}.png'.format(utils.abbrev_count(num_formulas), tree_size[0], tree_size[1]))      # For distribution analysis
        # tictoc.histogram(show=False, save_to='timing_nf{}_ts{:d}-{:d}.png'.format(utils.abbrev_count(num_formulas), tree_size[0], tree_size[1]))         # For timing analysis
        print('Generated {:d} samples, {:d} requested'.format(
            total_samples, num_formulas))
    finally:
        SPOT_WORKER.terminate()

    # shuffle and split samples
    random.Random(seed).shuffle(samples)
    train_end = int(train_frac * total_samples)
    val_end = int((train_frac + val_frac) * total_samples)
    res = {}
    res['train'] = samples[0:train_end]
    res['val'] = samples[train_end:val_end]
    res['test'] = samples[val_end:]
    return res
Пример #5
0
def test_and_analyze_sat(pred_model, dataset, in_vocab, out_vocab, log_name, **kwargs):
    """Evaluate predicted propositional assignments and log a summary.

    For every batch in ``dataset``, the model's predictions are decoded and
    each one is classified as:

    - ``invalid``: ``get_ass`` raises ValueError on the prediction.
    - ``syn_correct``: the prediction is token-identical to the label.
    - ``sem_correct``: ``is_model`` reports that the assignment satisfies
      the formula.
    - ``incorrect``: otherwise, including when the assignment names a
      variable that ``is_model`` rejects with KeyError.

    Every category except ``syn_correct`` is logged per example. A summary
    is written to the log and printed.

    Args:
        pred_model: Callable returning ``(prediction, _)``. It is given
            ``[data, pe]`` when ``kwargs['tree_pos_enc']`` is set, else
            ``data``.
        dataset: Iterable of ``(data, pe, label)`` or ``(data, label)``
            batches, matching ``tree_pos_enc``.
        in_vocab, out_vocab: Vocabularies used to decode formulas and
            assignments.
        log_name: File name of the log, created in
            ``<kwargs['job_dir']>/<kwargs['run_name']>``.
        **kwargs: Must contain ``job_dir``, ``run_name`` and
            ``tree_pos_enc``.
    """
    from deepltl.data.sat_generator import spot_to_pyaiger, is_model

    logdir = path.join(kwargs['job_dir'], kwargs['run_name'])
    tf.io.gfile.makedirs(logdir)
    # Open through tf.io.gfile, like makedirs above, so that logdir may also
    # be a non-local path that the builtin open() cannot handle.
    with tf.io.gfile.GFile(path.join(logdir, log_name), 'w') as log_file:
        res = {'invalid': 0, 'incorrect': 0, 'syn_correct': 0, 'sem_correct': 0}
        for x in dataset:
            if kwargs['tree_pos_enc']:
                data, pe, label_ = x
                prediction, _ = pred_model([data, pe], training=False)
            else:
                data, label_ = x
                prediction, _ = pred_model(data, training=False)
            for i in range(prediction.shape[0]):
                formula = in_vocab.decode(list(data[i, :]), as_list=True)
                pred = out_vocab.decode(list(prediction[i, :]), as_list=True)
                label = out_vocab.decode(list(label_[i, :]), as_list=True)
                formula_obj = ltl_parser.ltl_formula(''.join(formula), 'network-polish')
                formula_str = formula_obj.to_str('spot')
                _, pretty_label_ass = get_ass(label)
                try:
                    _, pretty_ass = get_ass(pred)
                except ValueError as e:
                    res['invalid'] += 1
                    msg = f"INVALID ({str(e)})\nFormula: {formula_str}\nPred:     {' '.join(pred)}\nLabel:    {pretty_label_ass}\n"
                    log_file.write(msg)
                    continue
                if pred == label:
                    # Counted, but intentionally not logged.
                    res['syn_correct'] += 1
                    continue

                # semantic checking
                formula_pyaiger = spot_to_pyaiger(formula)
                ass_pyaiger = spot_to_pyaiger(pred)
                pyaiger_ass_dict, _ = get_ass(ass_pyaiger)
                try:
                    holds = is_model(formula_pyaiger, pyaiger_ass_dict)
                except KeyError as e:
                    res['incorrect'] += 1
                    msg = f"INCORRECT (var {str(e)} not in formula)\nFormula: {formula_str}\nPred:    {pretty_ass}\nLabel:  {pretty_label_ass}\n"
                    log_file.write(msg)
                    continue
                if holds:
                    res['sem_correct'] += 1
                    msg = f"SEMANTICALLY CORRECT\nFormula: {formula_str}\nPred:    {pretty_ass}\nLabel:  {pretty_label_ass}\n"
                    log_file.write(msg)
                else:
                    res['incorrect'] += 1
                    msg = f"INCORRECT\nFormula: {formula_str}\nPred:    {pretty_ass}\nLabel:   {pretty_label_ass}\n"
                    log_file.write(msg)

        total = sum(res.values())
        correct = res['syn_correct'] + res['sem_correct']

        def pct(count):
            # An empty dataset would otherwise raise ZeroDivisionError.
            return count / total * 100 if total else 0.0

        msg = (f"Correct: {pct(correct):.1f}%, {correct} out of {total}\n"
               f"Syntactically correct: {pct(res['syn_correct']):.1f}%\n"
               f"Semantically correct: {pct(res['sem_correct']):.1f}%\n"
               f"Incorrect: {pct(res['incorrect']):.1f}%\n"
               f"Invalid: {pct(res['invalid']):.1f}%\n")
        log_file.write(msg)
        print(msg, end='')