def improve_implementation(impl: Implementation,
                           timeout: datetime.timedelta = datetime.timedelta(
                               seconds=60),
                           progress_callback=None) -> Implementation:
    """Improve an implementation by re-synthesizing its queries.

    One `ImproveQueryJob` is started per query in the implementation.  Each
    job reports candidate solutions through a shared queue; this function
    polls that queue, applies the newest solution for each query to a private
    copy of `impl`, and keeps the set of running jobs in sync with the set of
    queries that remain after `impl.cleanup()`.

    Args:
        impl: the implementation to improve.  It is copied first; the
            argument itself is not modified by this function.
        timeout: time budget handed to `Timeout`; the polling loop ends once
            it reports a time-out (or once every job reports `done`).
        progress_callback: if not None, called after each applied solution
            with the tuple `(impl, impl.code, impl.concretization_functions)`.

    Returns:
        The modified copy of `impl`.

    All jobs still registered are stopped before this function exits, whether
    it returns normally or an exception propagates.
    """

    start_time = datetime.datetime.now()

    # we statefully modify `impl`, so let's make a defensive copy
    impl = Implementation(impl.spec, list(impl.concrete_state),
                          list(impl.query_specs),
                          OrderedDict(impl.query_impls),
                          defaultdict(SNoOp, impl.updates),
                          defaultdict(SNoOp, impl.handle_updates))

    # gather root types
    types = list(all_types(impl.spec))
    basic_types = {t for t in types if is_scalar(t)}
    basic_types |= {BOOL, INT}
    print("basic types:")
    for t in basic_types:
        print("  --> {}".format(pprint(t)))
    basic_types = list(basic_types)
    ctx = SynthCtx(all_types=types, basic_types=basic_types)

    # the actual worker jobs
    improvement_jobs = []

    with jobs.SafeQueue() as solutions_q:

        def stop_jobs(js):
            """Stop the given jobs and remove them from `improvement_jobs`."""
            js = list(js)
            jobs.stop_jobs(js)
            for j in js:
                improvement_jobs.remove(j)

        def solution_reporter(query):
            """Return a callback that enqueues solutions tagged with `query`.

            Going through a factory binds `query` at creation time instead of
            late-binding the loop variable of the caller.
            """
            return lambda new_rep, new_ret: solutions_q.put(
                (query, new_rep, new_ret))

        def reconcile_jobs():
            """Start jobs for queries lacking one; stop jobs for dead queries."""
            # figure out what new jobs we need
            job_query_names = {j.q.name for j in improvement_jobs}
            new = []
            for q in impl.query_specs:
                if q.name not in job_query_names:
                    new.append(
                        ImproveQueryJob(
                            ctx,
                            impl.abstract_state,
                            list(impl.spec.assumptions) + list(q.assumptions),
                            q,
                            k=solution_reporter(q),
                            hints=[
                                EStateVar(c).with_type(c.type) for c in
                                impl.concretization_functions.values()
                            ]))

            # figure out what old jobs we can stop
            impl_query_names = {q.name for q in impl.query_specs}
            old = [
                j for j in improvement_jobs if j.q.name not in impl_query_names
            ]

            # make it so
            stop_jobs(old)
            for j in new:
                j.start()
                # Register each job as soon as it is running so that it is
                # still cleaned up if a later start() raises.
                improvement_jobs.append(j)

        try:
            # start jobs
            reconcile_jobs()

            # wait for results
            deadline = Timeout(timeout)
            done = False
            while not done and not deadline.is_timed_out():
                for j in improvement_jobs:
                    if j.done:
                        if j.successful:
                            j.join()
                        else:
                            print("failed job: {}".format(j), file=sys.stderr)
                            # raise Exception("failed job: {}".format(j))

                # NOTE(review): `done` is computed before draining, so jobs
                # spawned by reconcile_jobs() later in this same iteration are
                # not taken into account -- confirm this is intended.
                done = all(j.done for j in improvement_jobs)

                try:
                    # list of (Query, new_rep, new_ret) objects
                    results = solutions_q.drain(block=True, timeout=0.5)
                except Empty:
                    continue

                # group by query name, favoring later (i.e. better) solutions
                print("updating with {} new solutions".format(len(results)))
                improved_queries_by_name = OrderedDict()
                killed = 0
                for r in results:
                    q, _, _ = r
                    if q.name in improved_queries_by_name:
                        killed += 1
                    improved_queries_by_name[q.name] = r
                if killed:
                    print(" --> dropped {} worse solutions".format(killed))

                improvements = list(improved_queries_by_name.values())

                # Apply improvements in the order their queries appear in the
                # implementation.  Queries that are no longer present get key
                # -1 and therefore sort first.
                first_index = {}
                for idx, spec in enumerate(impl.query_specs):
                    first_index.setdefault(spec.name, idx)
                improvements.sort(
                    key=lambda r: first_index.get(r[0].name, -1))
                print("update order:")
                for (q, _, _) in improvements:
                    print("  --> {}".format(q.name))

                # update query implementations
                for i, (q, new_rep, new_ret) in enumerate(improvements,
                                                          start=1):
                    print("considering update {}/{}...".format(
                        i, len(improvements)))
                    # this guard might fail if a better solution was
                    # enqueued but the job has already been cleaned up
                    if not any(qq.name == q.name for qq in impl.query_specs):
                        continue

                    elapsed = datetime.datetime.now() - start_time
                    print("SOLUTION FOR {} AT {} [size={}]".format(
                        q.name, elapsed,
                        new_ret.size() + sum(proj.size()
                                             for (v, proj) in new_rep)))
                    print("-" * 40)
                    for (sv, proj) in new_rep:
                        print("  {} : {} = {}".format(sv.id, pprint(sv.type),
                                                      pprint(proj)))
                    print("  return {}".format(pprint(new_ret)))
                    print("-" * 40)
                    impl.set_impl(q, new_rep, new_ret)

                    # clean up
                    impl.cleanup()
                    if progress_callback is not None:
                        progress_callback(
                            (impl, impl.code, impl.concretization_functions))
                    # cleanup() may have changed the set of queries
                    reconcile_jobs()
        finally:
            # stop jobs (also on error, so that no worker is left running)
            print("Stopping jobs")
            stop_jobs(list(improvement_jobs))
        return impl
# Example #2
# 0
def improve_implementation(
        impl: Implementation,
        timeout: datetime.timedelta = datetime.timedelta(seconds=60),
        progress_callback: Callable[[Implementation], Any] = None,
        improve_count: Value = None,
        dump_synthesized_in_file: str = None) -> Implementation:
    """Improve an implementation.

    This function tries to synthesize a better version of the given
    implementation. It returns the best version found within the given timeout.

    If provided, progress_callback will be called whenever a better
    implementation is found.  It will be given the better implementation, which
    it should not modify or cache.

    If provided, improve_count is passed through unchanged to every
    ImproveQueryJob that gets started.

    If provided, the synthesized implementation will be pickled to
    dump_synthesized_in_file once the loop terminates (because of time-outs
    etc.) and before the running jobs are cleaned up.  This is useful when a
    thread that invokes Z3 is not responsive to cleanup.

    All jobs still registered are stopped before this function exits, whether
    it returns normally or an exception propagates.
    """

    start_time = datetime.datetime.now()

    # we statefully modify `impl`, so let's make a defensive copy which we will modify instead
    impl = impl.safe_copy()

    # worker threads ("jobs"), one per query
    improvement_jobs = []

    with jobs.SafeQueue() as solutions_q:

        def stop_jobs(js):
            """Stop the given jobs and remove them from `improvement_jobs`."""
            js = list(js)
            jobs.stop_jobs(js)
            for j in js:
                improvement_jobs.remove(j)

        def reconcile_jobs():
            """Sync up the current set of jobs and the set of queries.

            This function spawns new jobs for new queries and cleans up old
            jobs whose queries have been dead-code-eliminated."""

            # figure out what new jobs we need
            job_query_names = {j.q.name for j in improvement_jobs}
            new = []
            for q in impl.query_specs:
                if q.name not in job_query_names:
                    states_maintained_by_q = impl.states_maintained_by(q)
                    print("STARTING IMPROVEMENT JOB {}".format(q.name))
                    new.append(
                        ImproveQueryJob(
                            impl.abstract_state,
                            list(impl.spec.assumptions) + list(q.assumptions),
                            q,
                            context=impl.context_for_method(q),
                            solutions_q=solutions_q.handle_for_subjobs(),
                            hints=[
                                EStateVar(c).with_type(c.type) for c in
                                impl.concretization_functions.values()
                            ],
                            freebies=[
                                e for (v, e) in
                                impl.concretization_functions.items()
                                if EVar(v) in states_maintained_by_q
                            ],
                            ops=impl.op_specs,
                            improve_count=improve_count))

            # figure out what old jobs we can stop
            impl_query_names = {q.name for q in impl.query_specs}
            old = [
                j for j in improvement_jobs if j.q.name not in impl_query_names
            ]

            # make it so
            stop_jobs(old)
            for j in new:
                j.start()
                # Register each job as soon as it is running so that it is
                # still cleaned up if a later start() raises.
                improvement_jobs.append(j)

        try:
            # start jobs
            reconcile_jobs()

            # wait for results
            deadline = Timeout(timeout)
            done = False
            while (not done
                   and not deadline.is_timed_out()
                   and not jobs.was_interrupted()):
                for j in improvement_jobs:
                    if j.done:
                        if j.successful:
                            j.join()
                        else:
                            print("failed job: {}".format(j), file=sys.stderr)
                            # raise Exception("failed job: {}".format(j))

                # NOTE(review): `done` is computed before draining, so jobs
                # spawned by reconcile_jobs() later in this same iteration are
                # not taken into account -- confirm this is intended.
                done = all(j.done for j in improvement_jobs)

                try:
                    # list of (Query, packed_expr) objects
                    results = solutions_q.drain(block=True, timeout=0.5)
                except Empty:
                    continue

                # group by query name, favoring later (i.e. better) solutions
                print("updating with {} new solutions".format(len(results)))
                improved_queries_by_name = OrderedDict()
                killed = 0
                for r in results:
                    q, _ = r
                    if q.name in improved_queries_by_name:
                        killed += 1
                    improved_queries_by_name[q.name] = r
                if killed:
                    print(" --> dropped {} worse solutions".format(killed))

                improvements = list(improved_queries_by_name.values())

                # Apply improvements in the order their queries appear in the
                # implementation.  Queries that are no longer present get key
                # -1 and therefore sort first.
                first_index = {}
                for idx, spec in enumerate(impl.query_specs):
                    first_index.setdefault(spec.name, idx)
                improvements.sort(
                    key=lambda r: first_index.get(r[0].name, -1))
                print("update order:")
                for (q, _) in improvements:
                    print("  --> {}".format(q.name))

                # update query implementations
                for i, (q, packed_expr) in enumerate(improvements, start=1):
                    if deadline.is_timed_out():
                        break

                    print("considering update {}/{}...".format(
                        i, len(improvements)))
                    # The guard on the next line might fail!
                    # It might so happen that:
                    #   - a job found a better version for q
                    #   - a different job found a better version of some other query X
                    #   - both improvements were in the `results` list pulled from the queue
                    #   - we visited the improvement for X first
                    #   - after cleanup, q is no longer needed and was removed
                    if not any(qq.name == q.name for qq in impl.query_specs):
                        print("  (skipped; {} was aleady cleaned up)".format(
                            q.name))
                        continue

                    new_rep, new_ret = unpack_representation(packed_expr)
                    elapsed = datetime.datetime.now() - start_time
                    print("SOLUTION FOR {} AT {} [size={}]".format(
                        q.name, elapsed,
                        new_ret.size() + sum(proj.size()
                                             for (v, proj) in new_rep)))
                    print("-" * 40)
                    for (sv, proj) in new_rep:
                        print("  {} : {} = {}".format(sv.id, pprint(sv.type),
                                                      pprint(proj)))
                    print("  return {}".format(pprint(new_ret)))
                    print("-" * 40)
                    impl.set_impl(q, new_rep, new_ret)

                    # clean up
                    impl.cleanup()
                    if progress_callback is not None:
                        progress_callback(impl)
                    # cleanup() may have changed the set of queries
                    reconcile_jobs()

            # Dump before stopping the jobs; only reached when the loop ended
            # without an exception.
            if dump_synthesized_in_file is not None:
                with open(dump_synthesized_in_file, "wb") as f:
                    pickle.dump(impl, f)
                    print("Dumped implementation to file {}".format(
                        dump_synthesized_in_file))
        finally:
            # stop jobs (also on error, so that no worker is left running)
            print("Stopping jobs")
            stop_jobs(list(improvement_jobs))
        return impl
# Example #3
# 0
def improve_implementation(
        impl              : Implementation,
        timeout           : datetime.timedelta = datetime.timedelta(seconds=60),
        progress_callback : Callable[[Implementation], Any] = None,
        improve_count     : Value = None) -> Implementation:
    """Improve an implementation.

    This function tries to synthesize a better version of the given
    implementation. It returns the best version found within the given timeout.

    If provided, progress_callback will be called whenever a better
    implementation is found.  It will be given the better implementation, which
    it should not modify or cache.

    If provided, improve_count is passed through unchanged to every
    ImproveQueryJob that gets started.

    All jobs still registered are stopped before this function exits, whether
    it returns normally or an exception propagates.
    """

    start_time = datetime.datetime.now()

    # we statefully modify `impl`, so let's make a defensive copy which we will modify instead
    impl = impl.safe_copy()

    # worker threads ("jobs"), one per query
    improvement_jobs = []

    with jobs.SafeQueue() as solutions_q:

        def stop_jobs(js):
            """Stop the given jobs and remove them from `improvement_jobs`."""
            js = list(js)
            jobs.stop_jobs(js)
            for j in js:
                improvement_jobs.remove(j)

        def solution_reporter(query):
            """Return a callback that enqueues solutions tagged with `query`.

            Going through a factory binds `query` at creation time instead of
            late-binding the loop variable of the caller.
            """
            return lambda new_rep, new_ret: solutions_q.put((query, new_rep, new_ret))

        def reconcile_jobs():
            """Sync up the current set of jobs and the set of queries.

            This function spawns new jobs for new queries and cleans up old
            jobs whose queries have been dead-code-eliminated."""

            # figure out what new jobs we need
            job_query_names = {j.q.name for j in improvement_jobs}
            new = []
            for q in impl.query_specs:
                if q.name not in job_query_names:
                    states_maintained_by_q = impl.states_maintained_by(q)
                    new.append(ImproveQueryJob(
                        impl.abstract_state,
                        list(impl.spec.assumptions) + list(q.assumptions),
                        q,
                        context=impl.context_for_method(q),
                        k=solution_reporter(q),
                        hints=[EStateVar(c).with_type(c.type) for c in impl.concretization_functions.values()],
                        freebies=[e for (v, e) in impl.concretization_functions.items() if EVar(v) in states_maintained_by_q],
                        ops=impl.op_specs,
                        improve_count=improve_count))

            # figure out what old jobs we can stop
            impl_query_names = {q.name for q in impl.query_specs}
            old = [j for j in improvement_jobs if j.q.name not in impl_query_names]

            # make it so
            stop_jobs(old)
            for j in new:
                j.start()
                # Register each job as soon as it is running so that it is
                # still cleaned up if a later start() raises.
                improvement_jobs.append(j)

        try:
            # start jobs
            reconcile_jobs()

            # wait for results
            deadline = Timeout(timeout)
            done = False
            while not done and not deadline.is_timed_out():
                for j in improvement_jobs:
                    if j.done:
                        if j.successful:
                            j.join()
                        else:
                            print("failed job: {}".format(j), file=sys.stderr)
                            # raise Exception("failed job: {}".format(j))

                # NOTE(review): `done` is computed before draining, so jobs
                # spawned by reconcile_jobs() later in this same iteration are
                # not taken into account -- confirm this is intended.
                done = all(j.done for j in improvement_jobs)

                try:
                    # list of (Query, new_rep, new_ret) objects
                    results = solutions_q.drain(block=True, timeout=0.5)
                except Empty:
                    continue

                # group by query name, favoring later (i.e. better) solutions
                print("updating with {} new solutions".format(len(results)))
                improved_queries_by_name = OrderedDict()
                killed = 0
                for r in results:
                    q, _, _ = r
                    if q.name in improved_queries_by_name:
                        killed += 1
                    improved_queries_by_name[q.name] = r
                if killed:
                    print(" --> dropped {} worse solutions".format(killed))

                improvements = list(improved_queries_by_name.values())

                # Apply improvements in the order their queries appear in the
                # implementation.  Queries that are no longer present get key
                # -1 and therefore sort first.
                first_index = {}
                for idx, spec in enumerate(impl.query_specs):
                    first_index.setdefault(spec.name, idx)
                improvements.sort(key=lambda r: first_index.get(r[0].name, -1))
                print("update order:")
                for (q, _, _) in improvements:
                    print("  --> {}".format(q.name))

                # update query implementations
                for i, (q, new_rep, new_ret) in enumerate(improvements, start=1):
                    if deadline.is_timed_out():
                        break

                    print("considering update {}/{}...".format(i, len(improvements)))
                    # The guard on the next line might fail!
                    # It might so happen that:
                    #   - a job found a better version for q
                    #   - a different job found a better version of some other query X
                    #   - both improvements were in the `results` list pulled from the queue
                    #   - we visited the improvement for X first
                    #   - after cleanup, q is no longer needed and was removed
                    if not any(qq.name == q.name for qq in impl.query_specs):
                        print("  (skipped; {} was aleady cleaned up)".format(q.name))
                        continue

                    elapsed = datetime.datetime.now() - start_time
                    print("SOLUTION FOR {} AT {} [size={}]".format(q.name, elapsed, new_ret.size() + sum(proj.size() for (v, proj) in new_rep)))
                    print("-" * 40)
                    for (sv, proj) in new_rep:
                        print("  {} : {} = {}".format(sv.id, pprint(sv.type), pprint(proj)))
                    print("  return {}".format(pprint(new_ret)))
                    print("-" * 40)
                    impl.set_impl(q, new_rep, new_ret)

                    # clean up
                    impl.cleanup()
                    if progress_callback is not None:
                        progress_callback(impl)
                    # cleanup() may have changed the set of queries
                    reconcile_jobs()
        finally:
            # stop jobs (also on error, so that no worker is left running)
            print("Stopping jobs")
            stop_jobs(list(improvement_jobs))
        return impl