Example #1
File: test_timer.py Project: tlammi/kratos
def test_invalid_stop():
    t = Timer()
    
    t.stop()

    assert t.running == False
    assert t.remaining_time == 0
Example #2
File: test_timer.py Project: tlammi/kratos
def test_remaining_time():
    t = Timer()

    t.start(timeout=1)
    time.sleep(0.1)
    t.stop()

    assert t.remaining_time == 0
Example #3
File: test_timer.py Project: tlammi/kratos
def test_multiple_starts():
    t = Timer()

    t.start(timeout=2)

    assert t.running == True
    
    with pytest.raises(RuntimeError, match="Timer is already running"):
        t.start()

    assert t.running == True
    time.sleep(0.1)

    t.stop()
Example #4
class FieldProtector:
    PROTECTED = 'protected'
    NOT_PROTECTED = 'not_protected'
    BLINKING = 'blinking'

    def __init__(self, field: Field):
        self.field = field
        self._blink_animator = Animator(delay=1, max_states=2)
        self._protected_timer = Timer(delay=15)
        self._blink_timer = Timer(delay=6)
        self._state = self.NOT_PROTECTED

    def update(self):
        if self._state == self.PROTECTED:
            if self._protected_timer.tick():
                self._state = self.BLINKING
                self._blink_timer.start()
        elif self._state == self.BLINKING:
            if self._blink_timer.tick():
                self._change_base_border_tye(CellType.BRICK)
                self._state = self.NOT_PROTECTED
            else:
                state = self._blink_animator()
                self._change_base_border_tye(
                    CellType.BRICK if state else CellType.CONCRETE)

    @property
    def cells_around_base(self):
        return [(11, 25), (11, 24), (11, 23), (12, 23), (13, 23), (14, 23),
                (14, 24), (14, 25)]

    def _change_base_border_tye(self, ct: CellType):
        for x, y in self.cells_around_base:
            self.field.map.set_cell_col_row(x, y, ct)

    def activate(self):
        self._state = self.PROTECTED

        self._blink_timer.stop()
        self._protected_timer.start()

        self._change_base_border_tye(CellType.CONCRETE)

        # 1. protect the base with concrete
        # 2. start a 20-second timer
        # 3. when the timer expires, start the animator and a 10-second blink timer
        # 4. while the blink timer runs, switch the shield between concrete and brick every second!
        ...
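The numbered comments in activate() spell out the intended behaviour: shield the base with concrete, run a protection timer, then blink the border between concrete and brick before reverting to brick. Note that this Timer is a different, tick-driven helper from the kratos one in the test_timer.py examples: update() calls tick() once per frame and expects it to return True once the configured delay has elapsed, and Animator is a callable that cycles through its states. The sketch below shows one plausible shape for these helpers, assuming tick-counting semantics; the original project's versions are not shown here and may instead measure wall-clock time.

class Timer:
    """Tick-counting timer sketch: tick() returns True once `delay` ticks have passed since start()."""

    def __init__(self, delay):
        self.delay = delay
        self._ticks = None  # None means the timer has not been started

    def start(self):
        self._ticks = 0

    def stop(self):
        self._ticks = None

    def tick(self):
        if self._ticks is None:
            return False
        self._ticks += 1
        return self._ticks >= self.delay


class Animator:
    """Callable that advances one step per call and cycles through `max_states` states every `delay` calls."""

    def __init__(self, delay, max_states):
        self.delay = delay
        self.max_states = max_states
        self._calls = 0

    def __call__(self):
        state = (self._calls // self.delay) % self.max_states
        self._calls += 1
        return state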
Example #5
File: test_timer.py Project: tlammi/kratos
def test_start_after_stop():
    t = Timer()

    t.start(timeout=2)

    time.sleep(0.1)

    t.stop()

    t.start()

    assert t.running == True
    assert t.remaining_time > 0

    time.sleep(0.1)

    t.stop()
Example #6
File: test_timer.py Project: tlammi/kratos
def test_clearing():
    t = Timer()

    t.start(timeout=2)
    time.sleep(0.1)
    t.stop()

    t.clear()

    assert t.running == False
    assert t.remaining_time == 0

    with pytest.raises(ValueError, match="No timeout value stored, please provide one"):
        t.start()

    t.start(timeout=2)
    time.sleep(0.1)
    t.stop()
Example #7
File: test_timer.py Project: tlammi/kratos
def test_multiple_stops():
    t = Timer()

    t.start(timeout=2)

    time.sleep(0.1)

    t.stop()

    status = t.running
    remaining = t.remaining_time

    assert status == False
    assert remaining > 0

    t.stop()

    assert t.running == status
    assert t.remaining_time == remaining
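Taken together, the test_timer.py examples above pin down the Timer API in tlammi/kratos: start(timeout=...) stores the timeout and raises RuntimeError("Timer is already running") on a running timer, stop() freezes remaining_time and is a no-op on an already-stopped timer, clear() forgets the stored timeout so a bare start() raises ValueError, and remaining_time appears to report whole seconds (0.9 s left reads as 0, while 1.9 s reads as 1). The sketch below is a minimal reconstruction that would satisfy these tests; it is for illustration only, not the actual kratos implementation, and the integer-second remaining_time and the use of time.monotonic() are assumptions inferred from the assertions.

import time


class Timer:
    """Minimal countdown timer sketch matching the behaviour exercised above."""

    def __init__(self):
        self._timeout = None   # last timeout passed to start()
        self._deadline = None  # absolute monotonic time when the countdown ends
        self._frozen = 0.0     # remaining seconds captured by stop()
        self.running = False

    def start(self, timeout=None):
        if self.running:
            raise RuntimeError("Timer is already running")
        if timeout is None:
            if self._timeout is None:
                raise ValueError("No timeout value stored, please provide one")
            timeout = self._timeout
        self._timeout = timeout
        self._deadline = time.monotonic() + timeout
        self.running = True

    def stop(self):
        if not self.running:
            return  # stopping a stopped timer changes nothing
        self._frozen = max(0.0, self._deadline - time.monotonic())
        self.running = False

    def clear(self):
        self.running = False
        self._timeout = None
        self._deadline = None
        self._frozen = 0.0

    @property
    def remaining_time(self):
        # Whole seconds only, so 0.9 s left reads as 0 (see test_remaining_time).
        if self.running:
            return int(max(0.0, self._deadline - time.monotonic()))
        return int(self._frozen)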
Example #8
File: server.py Project: twareproj/tware
    def do_POST(self):
        t = Timer()
        t.start()
        response = 200
        result = {}
        try:
            content_length = int(self.headers.getheader('content-length'))
            req = json.loads(self.rfile.read(content_length))
            print req

            req_type = req['type']
            result = None
            if req_type == 'catalog':
                result = json.dumps(self.server.catalog)
            elif req_type == 'execute':
                task = req['args']['task']
                json.dumps(BasicExecutor(self.server.cache, task).execute())
            elif req_type == 'lookup':
                uuid = req['args']['uuid']
                result = self.server.cache[uuid]
                if type(result) is pd.DataFrame:
                    page_size = int(req['args']['page_size'])
                    page_num = int(req['args']['page_num'])
                    i = page_size * page_num
                    j = i + page_size
                    result = result[i:j]
                result = result.to_json()
        except:
            print traceback.format_exc()
            response = 500
            result = '{}'
        t.stop()

        self.send_response(response)
        self.send_header('Content-type','application/json')
        self.end_headers()
        self.wfile.write(result)

        print 'Run Time:', t.time()
Example #10
    def generate_sudoku(self, target = 25):
        search = AStarSearch()

        base_sudoku = self.generate_full_sudoku()
        timer = Timer()
        
        if self.__kind == 'reverse':
            problem = ReverseSudokuGenerationProblem(Sudoku(), target, self.solver)
        else:
            problem = SudokuGenerationProblem(base_sudoku, target, self.solver)

        timer.start()
        node, cnt_explored = search.search(problem, h = lambda n: problem.value(n.state))
        time = timer.stop()
        return node.state, len(node.state), cnt_explored, time
Example #11
class DocTestController(SageObject):
    """
    This class controls doctesting of files.

    After creating it with appropriate options, call the :meth:`run` method to run the doctests.
    """
    def __init__(self, options, args):
        """
        Initialization.

        INPUT:

        - options -- either options generated from the command line by SAGE_LOCAL/bin/sage-runtests
                     or a DocTestDefaults object (possibly with some entries modified)
        - args -- a list of filenames to doctest

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC
            DocTest Controller
        """
        # First we modify options to take environment variables into
        # account and check compatibility of the user's specified
        # options.
        if options.timeout < 0:
            if options.gdb or options.debug:
                # Interactive debuggers: "infinite" timeout
                options.timeout = 0
            elif options.valgrind or options.massif or options.cachegrind or options.omega:
                # Non-interactive debuggers: 48 hours
                options.timeout = int(
                    os.getenv('SAGE_TIMEOUT_VALGRIND', 48 * 60 * 60))
            elif options.long:
                options.timeout = int(os.getenv('SAGE_TIMEOUT_LONG', 30 * 60))
            else:
                options.timeout = int(os.getenv('SAGE_TIMEOUT', 5 * 60))
        if options.nthreads == 0:
            options.nthreads = int(os.getenv('SAGE_NUM_THREADS_PARALLEL', 1))
        if options.failed and not (args or options.new or options.sagenb):
            # If the user doesn't specify any files then we rerun all failed files.
            options.all = True
        if options.global_iterations == 0:
            options.global_iterations = int(
                os.environ.get('SAGE_TEST_GLOBAL_ITER', 1))
        if options.file_iterations == 0:
            options.file_iterations = int(os.environ.get('SAGE_TEST_ITER', 1))
        if options.debug and options.nthreads > 1:
            print(
                "Debugging requires single-threaded operation, setting number of threads to 1."
            )
            options.nthreads = 1
        if options.serial:
            options.nthreads = 1

        self.options = options
        self.files = args
        if options.all and options.logfile is None:
            options.logfile = os.path.join(os.environ['SAGE_TESTDIR'],
                                           'test.log')
        if options.logfile:
            try:
                self.logfile = open(options.logfile, 'a')
            except IOError:
                print "Unable to open logfile at %s\nProceeding without logging." % (
                    options.logfile)
                self.logfile = None
        else:
            self.logfile = None
        self.stats = {}
        self.load_stats(options.stats_path)

    def _repr_(self):
        """
        String representation.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: repr(DC) # indirect doctest
            'DocTest Controller'
        """
        return "DocTest Controller"

    def load_stats(self, filename):
        """
        Load stats from the most recent run(s).

        Stats are stored as a JSON file, and include information on
        which files failed tests and the walltime used for execution
        of the doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: import json
            sage: filename = tmp_filename()
            sage: with open(filename, 'w') as stats_file:
            ...       json.dump({'sage.doctest.control':{u'walltime':1.0r}}, stats_file)
            sage: DC.load_stats(filename)
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}

        If the file doesn't exist, nothing happens. If there is an
        error, print a message. In any case, leave the stats alone::

            sage: d = tmp_dir()
            sage: DC.load_stats(os.path.join(d))  # Cannot read a directory
            Error loading stats from ...
            sage: DC.load_stats(os.path.join(d, "no_such_file"))
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}
        """
        # Simply ignore non-existing files
        if not os.path.exists(filename):
            return

        try:
            with open(filename) as stats_file:
                self.stats.update(json.load(stats_file))
        except StandardError:
            self.log("Error loading stats from %s" % filename)

    def save_stats(self, filename):
        """
        Save stats from the most recent run as a JSON file.

        WARNING: This function overwrites the file.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.stats['sage.doctest.control'] = {u'walltime':1.0r}
            sage: filename = tmp_filename()
            sage: DC.save_stats(filename)
            sage: import json
            sage: D = json.load(open(filename))
            sage: D['sage.doctest.control']
            {u'walltime': 1.0}
        """
        with open(filename, 'w') as stats_file:
            json.dump(self.stats, stats_file)

    def log(self, s, end="\n"):
        """
        Logs the string ``s + end`` (where ``end`` is a newline by default)
        to the logfile and prints it to the standard output.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(logfile=tmp_filename())
            sage: DC = DocTestController(DD, [])
            sage: DC.log("hello world")
            hello world
            sage: DC.logfile.close()
            sage: with open(DD.logfile) as logger: print logger.read()
            hello world

        """
        s += end
        if self.logfile is not None:
            self.logfile.write(s)
        sys.stdout.write(s)

    def test_safe_directory(self, dir=None):
        """
        Test that the given directory is safe to run Python code from.

        We use the check added to Python for this, which gives a
        warning when the current directory is considered unsafe.  We promote
        this warning to an error with ``-Werror``.  See
        ``sage/tests/cmdline.py`` for a doctest that this works, see
        also :trac:`13579`.

        TESTS::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [])
            sage: DC.test_safe_directory()
            sage: d = os.path.join(tmp_dir(), "test")
            sage: os.mkdir(d)
            sage: os.chmod(d, 0o777)
            sage: DC.test_safe_directory(d)
            Traceback (most recent call last):
            ...
            RuntimeError: refusing to run doctests...
        """
        import subprocess
        with open(os.devnull, 'w') as dev_null:
            if subprocess.call(['python', '-Werror', '-c', ''],
                               stdout=dev_null,
                               stderr=dev_null,
                               cwd=dir) != 0:
                raise RuntimeError(
                    "refusing to run doctests from the current "
                    "directory '{}' since untrusted users could put files in "
                    "this directory, making it unsafe to run Sage code from".
                    format(os.getcwd()))

    def create_run_id(self):
        """
        Creates the run id.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.create_run_id()
            Running doctests with ID ...
        """
        self.run_id = time.strftime(
            '%Y-%m-%d-%H-%M-%S-') + "%08x" % random.getrandbits(32)
        from sage.version import version
        self.log("Running doctests with ID %s." % self.run_id)

    def add_files(self):
        """
        Checks for the flags '--all', '--new' and '--sagenb'.

        For each one present, this function adds the appropriate directories and files to the todo list.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: log_location = os.path.join(SAGE_TMP, 'control_dt_log.log')
            sage: DD = DocTestDefaults(all=True, logfile=log_location)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting entire Sage library.
            sage: os.path.join(SAGE_SRC, 'sage') in DC.files
            True

        ::

            sage: DD = DocTestDefaults(new = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting files changed since last HG commit.
            sage: len(DC.files) == len([L for L in hg_sage('status', interactive=False, debug=False)[0].split('\n') if len(L.split()) ==2 and L.split()[0] in ['M','A']])
            True

        ::

            sage: DD = DocTestDefaults(sagenb = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting the Sage notebook.
            sage: DC.files[0][-6:]
            'sagenb'
        """
        opj = os.path.join
        from sage.env import SAGE_SRC as base
        if self.options.all:
            self.log("Doctesting entire Sage library.")
            from glob import glob
            self.files.append(opj(base, 'sage'))
            self.files.append(opj(base, 'doc', 'common'))
            self.files.extend(glob(opj(base, 'doc', '[a-z][a-z]')))
            self.options.sagenb = True
        elif self.options.new:
            self.log("Doctesting files changed since last HG commit.")
            import sage.all_cmdline
            from sage.misc.hg import hg_sage
            for X in hg_sage('status', interactive=False,
                             debug=False)[0].split('\n'):
                tup = X.split()
                if len(tup) != 2: continue
                c, filename = tup
                if c in ['M', 'A']:
                    filename = opj(base, filename)
                    self.files.append(filename)
        if self.options.sagenb:
            if not self.options.all:
                self.log("Doctesting the Sage notebook.")
            from pkg_resources import Requirement, working_set
            sagenb_loc = working_set.find(Requirement.parse('sagenb')).location
            self.files.append(opj(sagenb_loc, 'sagenb'))

    def expand_files_into_sources(self):
        """
        Expands ``self.files``, which may include directories, into a
        list of :class:`sage.doctest.FileDocTestSource`

        This function also handles the ``--optional`` command line option.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(optional='all')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: len(DC.sources)
            9
            sage: DC.sources[0].optional
            True

        ::

            sage: DD = DocTestDefaults(optional='magma,guava')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: sorted(list(DC.sources[0].optional))
            ['guava', 'magma']
        """
        def skipdir(dirname):
            if os.path.exists(os.path.join(dirname, "nodoctest.py")):
                return True
            # Workaround for https://github.com/sagemath/sagenb/pull/84
            if dirname.endswith(os.path.join(os.sep, 'sagenb', 'data')):
                return True
            return False

        def skipfile(filename):
            base, ext = os.path.splitext(filename)
            if ext not in ('.py', '.pyx', '.pxi', '.sage', '.spyx', '.rst',
                           '.tex'):
                return True
            with open(filename) as F:
                return 'nodoctest' in F.read(50)

        def expand():
            for path in self.files:
                if os.path.isdir(path):
                    for root, dirs, files in os.walk(path):
                        for dir in list(dirs):
                            if dir[0] == "." or skipdir(os.path.join(
                                    root, dir)):
                                dirs.remove(dir)
                        for file in files:
                            if not skipfile(os.path.join(root, file)):
                                yield os.path.join(root, file)
                else:
                    # the user input this file explicitly, so we don't skip it
                    yield path

        if self.options.optional == 'all':
            optionals = True
        else:
            optionals = set(self.options.optional.lower().split(','))
        self.sources = [
            FileDocTestSource(path,
                              self.options.force_lib,
                              long=self.options.long,
                              optional=optionals,
                              randorder=self.options.randorder,
                              useabspath=self.options.abspath)
            for path in expand()
        ]

    def filter_sources(self):
        """

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(failed=True)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.stats['sage.doctest.control'] = {'failed':True,'walltime':1.0}
            sage: DC.filter_sources()
            Only doctesting files that failed last test.
            sage: len(DC.sources)
            1
        """
        # Filter the sources to only include those with failing doctests if the --failed option is passed
        if self.options.failed:
            self.log("Only doctesting files that failed last test.")

            def is_failure(source):
                basename = source.basename
                return basename not in self.stats or self.stats[basename].get(
                    'failed')

            self.sources = filter(is_failure, self.sources)

    def sort_sources(self):
        """
        This function sorts the sources so that slower doctests are run first.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(nthreads=2)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources.sort(key=lambda s:s.basename)
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.sort_sources()
            Sorting sources by runtime so that slower doctests are run first....
            sage: print "\n".join([source.basename for source in DC.sources])
            sage.doctest.util
            sage.doctest.test
            sage.doctest.sources
            sage.doctest.reporting
            sage.doctest.parsing
            sage.doctest.forker
            sage.doctest.control
            sage.doctest.all
            sage.doctest
        """
        if self.options.nthreads > 1 and len(
                self.sources) > self.options.nthreads:
            self.log(
                "Sorting sources by runtime so that slower doctests are run first...."
            )
            default = dict(walltime=0)

            def sort_key(source):
                basename = source.basename
                return -self.stats.get(basename,
                                       default).get('walltime'), basename

            self.sources = [
                x[1] for x in sorted((sort_key(source), source)
                                     for source in self.sources)
            ]

    def run_doctests(self):
        """
        Actually runs the doctests.

        This function is called by :meth:`run`.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'rings', 'homset.py')
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.run_doctests()
            Doctesting 1 file.
            sage -t .../sage/rings/homset.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
        """
        nfiles = 0
        nother = 0
        for F in self.sources:
            if isinstance(F, FileDocTestSource):
                nfiles += 1
            else:
                nother += 1
        if self.sources:
            filestr = ", ".join(
                ([count_noun(nfiles, "file")] if nfiles else []) +
                ([count_noun(nother, "other source")] if nother else []))
            threads = " using %s threads" % (
                self.options.nthreads) if self.options.nthreads > 1 else ""
            iterations = []
            if self.options.global_iterations > 1:
                iterations.append("%s global iterations" %
                                  (self.options.global_iterations))
            if self.options.file_iterations > 1:
                iterations.append("%s file iterations" %
                                  (self.options.file_iterations))
            iterations = ", ".join(iterations)
            if iterations:
                iterations = " (%s)" % (iterations)
            self.log("Doctesting %s%s%s." % (filestr, threads, iterations))
            self.reporter = DocTestReporter(self)
            self.dispatcher = DocTestDispatcher(self)
            N = self.options.global_iterations
            for it in range(N):
                try:
                    self.timer = Timer().start()
                    self.dispatcher.dispatch()
                except KeyboardInterrupt:
                    it = N - 1
                    break
                finally:
                    self.timer.stop()
                    self.reporter.finalize()
                    self.cleanup(it == N - 1)
        else:
            self.log("No files to doctest")
            self.reporter = DictAsObject(dict(error_status=0))

    def cleanup(self, final=True):
        """
        Runs cleanup activities after actually running doctests.

        In particular, saves the stats to disk and closes the logfile.

        INPUT:

        - ``final`` -- whether to close the logfile

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'rings', 'infinity.py')
            sage: DD = DocTestDefaults()

            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources.sort(key=lambda s:s.basename)

            sage: for i, source in enumerate(DC.sources):
            ....:     DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            ....:

            sage: DC.run()
            Running doctests with ID ...
            Doctesting 1 file.
            sage -t .../rings/infinity.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
            0
            sage: DC.cleanup()
        """
        self.stats.update(self.reporter.stats)
        self.save_stats(self.options.stats_path)
        # Close the logfile
        if final and self.logfile is not None:
            self.logfile.close()
            self.logfile = None

    def _assemble_cmd(self):
        """
        Assembles a shell command used in running tests under gdb or valgrind.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(timeout=123), ["hello_world.py"])
            sage: print DC._assemble_cmd()
            python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=123 hello_world.py
        """
        cmd = '''python "%s" --serial ''' % (os.path.join(
            "$SAGE_LOCAL", "bin", "sage-runtests"))
        opt = dict_difference(self.options.__dict__,
                              DocTestDefaults().__dict__)
        for o in ("all", "sagenb"):
            if o in opt:
                raise ValueError(
                    "You cannot run gdb/valgrind on the whole sage%s library" %
                    ("" if o == "all" else "nb"))
        for o in ("all", "sagenb", "long", "force_lib", "verbose", "failed",
                  "new"):
            if o in opt:
                cmd += "--%s " % o
        for o in ("timeout", "optional", "randorder", "stats_path"):
            if o in opt:
                cmd += "--%s=%s " % (o, opt[o])
        return cmd + " ".join(self.files)

    def run_val_gdb(self, testing=False):
        """
        Spawns a subprocess to run tests under the control of gdb or valgrind.

        INPUT:

        - ``testing`` -- boolean; if True then the command to be run
          will be printed rather than a subprocess started.

        EXAMPLES:

        Note that the command lines include unexpanded environment
        variables. It is safer to let the shell expand them than to
        expand them here and risk insufficient quoting. ::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(gdb=True)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=0 hello_world.py

        ::

            sage: DD = DocTestDefaults(valgrind=True, optional="all", timeout=172800)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec valgrind --tool=memcheck --leak-resolution=high --leak-check=full --num-callers=25 --suppressions="$SAGE_LOCAL/lib/valgrind/sage.supp"  --log-file=".../valgrind/sage-memcheck.%p" python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=172800 --optional=all hello_world.py
        """
        try:
            sage_cmd = self._assemble_cmd()
        except ValueError:
            self.log(sys.exc_info()[1])
            return 2
        opt = self.options
        if opt.gdb:
            cmd = '''exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args '''
            flags = ""
            if opt.logfile:
                sage_cmd += " --logfile %s" % (opt.logfile)
        else:
            if opt.logfile is None:
                default_log = os.path.join(DOT_SAGE, "valgrind")
                if os.path.exists(default_log):
                    if not os.path.isdir(default_log):
                        self.log("%s must be a directory" % default_log)
                        return 2
                else:
                    os.makedirs(default_log)
                logfile = os.path.join(default_log, "sage-%s")
            else:
                logfile = opt.logfile
            if opt.valgrind:
                toolname = "memcheck"
                flags = os.getenv("SAGE_MEMCHECK_FLAGS")
                if flags is None:
                    flags = "--leak-resolution=high --leak-check=full --num-callers=25 "
                    flags += '''--suppressions="%s" ''' % (os.path.join(
                        "$SAGE_LOCAL", "lib", "valgrind", "sage.supp"))
            elif opt.massif:
                toolname = "massif"
                flags = os.getenv("SAGE_MASSIF_FLAGS", "--depth=6 ")
            elif opt.cachegrind:
                toolname = "cachegrind"
                flags = os.getenv("SAGE_CACHEGRIND_FLAGS", "")
            elif opt.omega:
                toolname = "exp-omega"
                flags = os.getenv("SAGE_OMEGA_FLAGS", "")
            cmd = "exec valgrind --tool=%s " % (toolname)
            flags += ''' --log-file="%s" ''' % logfile
            if opt.omega:
                toolname = "omega"
            if "%s" in flags:
                flags %= toolname + ".%p"  # replace %s with toolname
        cmd += flags + sage_cmd

        self.log(cmd)
        sys.stdout.flush()
        sys.stderr.flush()
        if self.logfile is not None:
            self.logfile.flush()

        if testing:
            return

        import signal, subprocess

        def handle_alrm(sig, frame):
            raise RuntimeError

        signal.signal(signal.SIGALRM, handle_alrm)
        p = subprocess.Popen(cmd, shell=True)
        if opt.timeout > 0:
            signal.alarm(opt.timeout)
        try:
            return p.wait()
        except RuntimeError:
            self.log("    Time out")
            return 4
        except KeyboardInterrupt:
            self.log("    Interrupted")
            return 128
        finally:
            signal.signal(signal.SIGALRM, signal.SIG_IGN)
            if p.returncode is None:
                p.terminate()

    def run(self):
        """
        This function is called after initialization to set up and run all doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: DD = DocTestDefaults()
            sage: filename = os.path.join(SAGE_SRC, "sage", "sets", "non_negative_integers.py")
            sage: DC = DocTestController(DD, [filename])
            sage: DC.run()
            Running doctests with ID ...
            Doctesting 1 file.
            sage -t .../sage/sets/non_negative_integers.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
            0
        """
        opt = self.options
        L = (opt.gdb, opt.valgrind, opt.massif, opt.cachegrind, opt.omega)
        if any(L):
            if L.count(True) > 1:
                self.log(
                    "You may only specify one of gdb, valgrind/memcheck, massif, cachegrind, omega"
                )
                return 2
            return self.run_val_gdb()
        else:
            self.test_safe_directory()
            self.create_run_id()
            self.add_files()
            self.expand_files_into_sources()
            self.filter_sources()
            self.sort_sources()
            self.run_doctests()
            return self.reporter.error_status
Example #12
class DocTestController(SageObject):
    """
    This class controls doctesting of files.

    After creating it with appropriate options, call the :meth:`run` method to run the doctests.
    """
    def __init__(self, options, args):
        """
        Initialization.

        INPUT:

        - options -- either options generated from the command line by SAGE_LOCAL/bin/sage-runtests
                     or a DocTestDefaults object (possibly with some entries modified)
        - args -- a list of filenames to doctest

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC
            DocTest Controller
        """
        # First we modify options to take environment variables into
        # account and check compatibility of the user's specified
        # options.
        if options.timeout < 0:
            if options.gdb or options.debug:
                # Interactive debuggers: "infinite" timeout
                options.timeout = 0
            elif options.valgrind or options.massif or options.cachegrind or options.omega:
                # Non-interactive debuggers: 48 hours
                options.timeout = int(os.getenv('SAGE_TIMEOUT_VALGRIND', 48 * 60 * 60))
            elif options.long:
                options.timeout = int(os.getenv('SAGE_TIMEOUT_LONG', 30 * 60))
            else:
                options.timeout = int(os.getenv('SAGE_TIMEOUT', 5 * 60))
        if options.nthreads == 0:
            options.nthreads = int(os.getenv('SAGE_NUM_THREADS_PARALLEL',1))
        if options.failed and not (args or options.new or options.sagenb):
            # If the user doesn't specify any files then we rerun all failed files.
            options.all = True
        if options.global_iterations == 0:
            options.global_iterations = int(os.environ.get('SAGE_TEST_GLOBAL_ITER', 1))
        if options.file_iterations == 0:
            options.file_iterations = int(os.environ.get('SAGE_TEST_ITER', 1))
        if options.debug and options.nthreads > 1:
            print("Debugging requires single-threaded operation, setting number of threads to 1.")
            options.nthreads = 1
        if options.serial:
            options.nthreads = 1
        if options.verbose:
            options.show_skipped = True

        if isinstance(options.optional, basestring):
            s = options.optional.lower()
            if s in ['all', 'true']:
                options.optional = True
            else:
                options.optional = set(s.split(','))
                # Check that all tags are valid
                for o in options.optional:
                    if not optionaltag_regex.search(o):
                        raise ValueError('invalid optional tag %s'%repr(o))

        self.options = options
        self.files = args
        if options.logfile:
            try:
                self.logfile = open(options.logfile, 'a')
            except IOError:
                print "Unable to open logfile at %s\nProceeding without logging."%(options.logfile)
                self.logfile = None
        else:
            self.logfile = None
        self.stats = {}
        self.load_stats(options.stats_path)
        self._init_warn_long()

    def _init_warn_long(self):
        """
        Pick a suitable default for the ``--warn-long`` option if not specified.

        It is desirable to have all tests (even ``# long`` ones)
        finish in less than about 5 seconds. Longer tests typically
        don't add coverage, they just make testing slow.

        The default used here is 60 seconds on a modern computer. It
        should eventually be lowered to 5 seconds, but it's best to
        boil the frog slowly.

        The stored timings are used to adjust this limit according to
        the machine running the tests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.options.warn_long = 5.0
            sage: DC._init_warn_long()
            sage: DC.options.warn_long    # existing command-line options are not changed
            5.00000000000000
        """
        if self.options.warn_long is not None:     # Specified on the command line
            return
        try:
            self.options.warn_long = 60.0 * self.second_on_modern_computer()
        except RuntimeError as err:
            if not sage.doctest.DOCTEST_MODE:
                print(err)   # No usable timing information

    def second_on_modern_computer(self):
        """
        Return the wall time equivalent of a second on a modern computer.

        OUTPUT:

        Float. The wall time on your computer that would be equivalent
        to one second on a modern computer. Unless you have kick-ass
        hardware this should always be >= 1.0. Raises a
        ``RuntimeError`` if there are no stored timings to use as
        benchmark.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.second_on_modern_computer()   # not tested
        """
        if len(self.stats) == 0:
            raise RuntimeError('no stored timings available')
        success = []
        failed = []
        for mod in self.stats.values():
            if mod.get('failed', False):
                failed.append(mod['walltime'])
            else:
                success.append(mod['walltime'])
        if len(success) < 2500:
            raise RuntimeError('too few successful tests, not using stored timings')
        if len(failed) > 20:
            raise RuntimeError('too many failed tests, not using stored timings')
        expected = 12800.0       # Core i7 Quad-Core 2014
        return sum(success) / expected

    def _repr_(self):
        """
        String representation.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: repr(DC) # indirect doctest
            'DocTest Controller'
        """
        return "DocTest Controller"

    def load_stats(self, filename):
        """
        Load stats from the most recent run(s).

        Stats are stored as a JSON file, and include information on
        which files failed tests and the walltime used for execution
        of the doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: import json
            sage: filename = tmp_filename()
            sage: with open(filename, 'w') as stats_file:
            ...       json.dump({'sage.doctest.control':{u'walltime':1.0r}}, stats_file)
            sage: DC.load_stats(filename)
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}

        If the file doesn't exist, nothing happens. If there is an
        error, print a message. In any case, leave the stats alone::

            sage: d = tmp_dir()
            sage: DC.load_stats(os.path.join(d))  # Cannot read a directory
            Error loading stats from ...
            sage: DC.load_stats(os.path.join(d, "no_such_file"))
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}
        """
        # Simply ignore non-existing files
        if not os.path.exists(filename):
            return

        try:
            with open(filename) as stats_file:
                self.stats.update(json.load(stats_file))
        except Exception:
            self.log("Error loading stats from %s"%filename)

    def save_stats(self, filename):
        """
        Save stats from the most recent run as a JSON file.

        WARNING: This function overwrites the file.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.stats['sage.doctest.control'] = {u'walltime':1.0r}
            sage: filename = tmp_filename()
            sage: DC.save_stats(filename)
            sage: import json
            sage: D = json.load(open(filename))
            sage: D['sage.doctest.control']
            {u'walltime': 1.0}
        """
        from sage.misc.temporary_file import atomic_write
        with atomic_write(filename) as stats_file:
            json.dump(self.stats, stats_file)


    def log(self, s, end="\n"):
        """
        Logs the string ``s + end`` (where ``end`` is a newline by default)
        to the logfile and prints it to the standard output.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(logfile=tmp_filename())
            sage: DC = DocTestController(DD, [])
            sage: DC.log("hello world")
            hello world
            sage: DC.logfile.close()
            sage: print open(DD.logfile).read()
            hello world

        Check that no duplicate logs appear, even when forking (:trac:`15244`)::

            sage: DD = DocTestDefaults(logfile=tmp_filename())
            sage: DC = DocTestController(DD, [])
            sage: DC.log("hello world")
            hello world
            sage: if os.fork() == 0:
            ....:     DC.logfile.close()
            ....:     os._exit(0)
            sage: DC.logfile.close()
            sage: print open(DD.logfile).read()
            hello world

        """
        s += end
        if self.logfile is not None:
            self.logfile.write(s)
            self.logfile.flush()
        sys.stdout.write(s)
        sys.stdout.flush()

    def test_safe_directory(self, dir=None):
        """
        Test that the given directory is safe to run Python code from.

        We use the check added to Python for this, which gives a
        warning when the current directory is considered unsafe.  We promote
        this warning to an error with ``-Werror``.  See
        ``sage/tests/cmdline.py`` for a doctest that this works, see
        also :trac:`13579`.

        TESTS::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [])
            sage: DC.test_safe_directory()
            sage: d = os.path.join(tmp_dir(), "test")
            sage: os.mkdir(d)
            sage: os.chmod(d, 0o777)
            sage: DC.test_safe_directory(d)
            Traceback (most recent call last):
            ...
            RuntimeError: refusing to run doctests...
        """
        import subprocess
        with open(os.devnull, 'w') as dev_null:
            if subprocess.call(['python', '-Werror', '-c', ''],
                    stdout=dev_null, stderr=dev_null, cwd=dir) != 0:
                raise RuntimeError(
                      "refusing to run doctests from the current "
                      "directory '{}' since untrusted users could put files in "
                      "this directory, making it unsafe to run Sage code from"
                      .format(os.getcwd()))

    def create_run_id(self):
        """
        Creates the run id.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.create_run_id()
            Running doctests with ID ...
        """
        self.run_id = time.strftime('%Y-%m-%d-%H-%M-%S-') + "%08x" % random.getrandbits(32)
        from sage.version import version
        self.log("Running doctests with ID %s."%self.run_id)

    def add_files(self):
        r"""
        Checks for the flags '--all', '--new' and '--sagenb'.

        For each one present, this function adds the appropriate directories and files to the todo list.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: log_location = os.path.join(SAGE_TMP, 'control_dt_log.log')
            sage: DD = DocTestDefaults(all=True, logfile=log_location)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting entire Sage library.
            sage: os.path.join(SAGE_SRC, 'sage') in DC.files
            True

        ::

            sage: DD = DocTestDefaults(new = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting ...

        ::

            sage: DD = DocTestDefaults(sagenb = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting the Sage notebook.
            sage: DC.files[0][-6:]
            'sagenb'
        """
        opj = os.path.join
        from sage.env import SAGE_SRC, SAGE_ROOT
        def all_files():
            from glob import glob
            self.files.append(opj(SAGE_SRC, 'sage'))
            self.files.append(opj(SAGE_SRC, 'sage_setup'))
            self.files.append(opj(SAGE_SRC, 'doc', 'common'))
            self.files.extend(glob(opj(SAGE_SRC, 'doc', '[a-z][a-z]')))
            self.options.sagenb = True
        DOT_GIT= opj(SAGE_ROOT, '.git')
        have_git = os.path.exists(DOT_GIT)
        if self.options.all or (self.options.new and not have_git):
            self.log("Doctesting entire Sage library.")
            all_files()
        elif self.options.new and have_git:
            # Get all files changed in the working repo.
            self.log("Doctesting files changed since last git commit")
            import subprocess
            change = subprocess.check_output(["git",
                                              "--git-dir=" + DOT_GIT,
                                              "--work-tree=" + SAGE_ROOT,
                                              "status",
                                              "--porcelain"])
            for line in change.split("\n"):
                if not line:
                    continue
                data = line.strip().split(' ')
                status, filename = data[0], data[-1]
                if (set(status).issubset("MARCU")
                    and filename.startswith("src/sage")
                    and (filename.endswith(".py") or filename.endswith(".pyx"))):
                    self.files.append(os.path.relpath(opj(SAGE_ROOT,filename)))
        if self.options.sagenb:
            if not self.options.all:
                self.log("Doctesting the Sage notebook.")
            from pkg_resources import Requirement, working_set
            sagenb_loc = working_set.find(Requirement.parse('sagenb')).location
            self.files.append(opj(sagenb_loc, 'sagenb'))

    def expand_files_into_sources(self):
        r"""
        Expands ``self.files``, which may include directories, into a
        list of :class:`sage.doctest.FileDocTestSource`

        This function also handles the ``--optional`` command line option.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(optional='all')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: len(DC.sources)
            10
            sage: DC.sources[0].options.optional
            True

        ::

            sage: DD = DocTestDefaults(optional='magma,guava')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: sorted(list(DC.sources[0].options.optional))
            ['guava', 'magma']

        We check that files are skipped appropriately::

            sage: dirname = tmp_dir()
            sage: filename = os.path.join(dirname, 'not_tested.py')
            sage: with open(filename, 'w') as F:
            ....:     F.write("#"*80 + "\n\n\n\n## nodoctest\n    sage: 1+1\n    4")
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources
            []

        The directory ``sage/doctest/tests`` contains ``nodoctest.py``
        but the files should still be tested when that directory is
        explicitly given (as opposed to being recursed into)::

            sage: DC = DocTestController(DD, [os.path.join(SAGE_SRC, 'sage', 'doctest', 'tests')])
            sage: DC.expand_files_into_sources()
            sage: len(DC.sources) >= 10
            True
        """
        def expand():
            for path in self.files:
                if os.path.isdir(path):
                    for root, dirs, files in os.walk(path):
                        for dir in list(dirs):
                            if dir[0] == "." or skipdir(os.path.join(root,dir)):
                                dirs.remove(dir)
                        for file in files:
                            if not skipfile(os.path.join(root,file)):
                                yield os.path.join(root, file)
                else:
                    # the user input this file explicitly, so we don't skip it
                    yield path
        self.sources = [FileDocTestSource(path, self.options) for path in expand()]

    def filter_sources(self):
        """

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(failed=True)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.stats['sage.doctest.control'] = {'failed':True,'walltime':1.0}
            sage: DC.filter_sources()
            Only doctesting files that failed last test.
            sage: len(DC.sources)
            1
        """
        # Filter the sources to only include those with failing doctests if the --failed option is passed
        if self.options.failed:
            self.log("Only doctesting files that failed last test.")
            def is_failure(source):
                basename = source.basename
                return basename not in self.stats or self.stats[basename].get('failed')
            self.sources = [x for x in self.sources if is_failure(x)]

    def sort_sources(self):
        r"""
        This function sorts the sources so that slower doctests are run first.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'doctest')
            sage: DD = DocTestDefaults(nthreads=2)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources.sort(key=lambda s:s.basename)
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.sort_sources()
            Sorting sources by runtime so that slower doctests are run first....
            sage: print "\n".join([source.basename for source in DC.sources])
            sage.doctest.util
            sage.doctest.test
            sage.doctest.sources
            sage.doctest.reporting
            sage.doctest.parsing
            sage.doctest.forker
            sage.doctest.fixtures
            sage.doctest.control
            sage.doctest.all
            sage.doctest
        """
        if self.options.nthreads > 1 and len(self.sources) > self.options.nthreads:
            self.log("Sorting sources by runtime so that slower doctests are run first....")
            default = dict(walltime=0)
            def sort_key(source):
                basename = source.basename
                return -self.stats.get(basename, default).get('walltime'), basename
            self.sources = [x[1] for x in sorted((sort_key(source), source) for source in self.sources)]

    def run_doctests(self):
        """
        Actually runs the doctests.

        This function is called by :meth:`run`.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'rings', 'homset.py')
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.run_doctests()
            Doctesting 1 file.
            sage -t .../sage/rings/homset.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
        """
        nfiles = 0
        nother = 0
        for F in self.sources:
            if isinstance(F, FileDocTestSource):
                nfiles += 1
            else:
                nother += 1
        if self.sources:
            filestr = ", ".join(([count_noun(nfiles, "file")] if nfiles else []) +
                                ([count_noun(nother, "other source")] if nother else []))
            threads = " using %s threads"%(self.options.nthreads) if self.options.nthreads > 1 else ""
            iterations = []
            if self.options.global_iterations > 1:
                iterations.append("%s global iterations"%(self.options.global_iterations))
            if self.options.file_iterations > 1:
                iterations.append("%s file iterations"%(self.options.file_iterations))
            iterations = ", ".join(iterations)
            if iterations:
                iterations = " (%s)"%(iterations)
            self.log("Doctesting %s%s%s."%(filestr, threads, iterations))
            self.reporter = DocTestReporter(self)
            self.dispatcher = DocTestDispatcher(self)
            N = self.options.global_iterations
            for it in range(N):
                try:
                    self.timer = Timer().start()
                    self.dispatcher.dispatch()
                except KeyboardInterrupt:
                    it = N - 1
                    break
                finally:
                    self.timer.stop()
                    self.reporter.finalize()
                    self.cleanup(it == N - 1)
        else:
            self.log("No files to doctest")
            self.reporter = DictAsObject(dict(error_status=0))

    def cleanup(self, final=True):
        """
        Runs cleanup activities after actually running doctests.

        In particular, saves the stats to disk and closes the logfile.

        INPUT:

        - ``final`` -- whether to close the logfile

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: dirname = os.path.join(SAGE_SRC, 'sage', 'rings', 'infinity.py')
            sage: DD = DocTestDefaults()

            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources.sort(key=lambda s:s.basename)

            sage: for i, source in enumerate(DC.sources):
            ....:     DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            ....:

            sage: DC.run()
            Running doctests with ID ...
            Doctesting 1 file.
            sage -t .../rings/infinity.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
            0
            sage: DC.cleanup()
        """
        self.stats.update(self.reporter.stats)
        self.save_stats(self.options.stats_path)
        # Close the logfile
        if final and self.logfile is not None:
            self.logfile.close()
            self.logfile = None

    def _assemble_cmd(self):
        """
        Assembles a shell command used in running tests under gdb or valgrind.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(timeout=123), ["hello_world.py"])
            sage: print DC._assemble_cmd()
            python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=123 hello_world.py
        """
        cmd = '''python "%s" --serial '''%(os.path.join("$SAGE_LOCAL","bin","sage-runtests"))
        opt = dict_difference(self.options.__dict__, DocTestDefaults().__dict__)
        for o in ("all", "sagenb"):
            if o in opt:
                raise ValueError("You cannot run gdb/valgrind on the whole sage%s library"%("" if o == "all" else "nb"))
        for o in ("all", "sagenb", "long", "force_lib", "verbose", "failed", "new"):
            if o in opt:
                cmd += "--%s "%o
        for o in ("timeout", "optional", "randorder", "stats_path"):
            if o in opt:
                cmd += "--%s=%s "%(o, opt[o])
        return cmd + " ".join(self.files)

    def run_val_gdb(self, testing=False):
        """
        Spawns a subprocess to run tests under the control of gdb or valgrind.

        INPUT:

        - ``testing`` -- boolean; if True then the command to be run
          will be printed rather than a subprocess started.

        EXAMPLES:

        Note that the command lines include unexpanded environment
        variables. It is safer to let the shell expand them than to
        expand them here and risk insufficient quoting. ::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(gdb=True)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=0 hello_world.py

        ::

            sage: DD = DocTestDefaults(valgrind=True, optional="all", timeout=172800)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec valgrind --tool=memcheck --leak-resolution=high --leak-check=full --num-callers=25 --suppressions="$SAGE_LOCAL/lib/valgrind/sage.supp"  --log-file=".../valgrind/sage-memcheck.%p" python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=172800 --optional=True hello_world.py
        """
        try:
            sage_cmd = self._assemble_cmd()
        except ValueError:
            self.log(sys.exc_info()[1])
            return 2
        opt = self.options
        if opt.gdb:
            cmd = '''exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args '''
            flags = ""
            if opt.logfile:
                sage_cmd += " --logfile %s"%(opt.logfile)
        else:
            if opt.logfile is None:
                default_log = os.path.join(DOT_SAGE, "valgrind")
                if os.path.exists(default_log):
                    if not os.path.isdir(default_log):
                        self.log("%s must be a directory"%default_log)
                        return 2
                else:
                    os.makedirs(default_log)
                logfile = os.path.join(default_log, "sage-%s")
            else:
                logfile = opt.logfile
            if opt.valgrind:
                toolname = "memcheck"
                flags = os.getenv("SAGE_MEMCHECK_FLAGS")
                if flags is None:
                    flags = "--leak-resolution=high --leak-check=full --num-callers=25 "
                    flags += '''--suppressions="%s" '''%(os.path.join("$SAGE_LOCAL","lib","valgrind","sage.supp"))
            elif opt.massif:
                toolname = "massif"
                flags = os.getenv("SAGE_MASSIF_FLAGS", "--depth=6 ")
            elif opt.cachegrind:
                toolname = "cachegrind"
                flags = os.getenv("SAGE_CACHEGRIND_FLAGS", "")
            elif opt.omega:
                toolname = "exp-omega"
                flags = os.getenv("SAGE_OMEGA_FLAGS", "")
            cmd = "exec valgrind --tool=%s "%(toolname)
            flags += ''' --log-file="%s" ''' % logfile
            if opt.omega:
                toolname = "omega"
            if "%s" in flags:
                flags %= toolname + ".%p" # replace %s with toolname
        cmd += flags + sage_cmd

        self.log(cmd)
        sys.stdout.flush()
        sys.stderr.flush()
        if self.logfile is not None:
            self.logfile.flush()

        if testing:
            return

        # Setup Sage signal handler
        init_interrupts()

        import signal, subprocess
        p = subprocess.Popen(cmd, shell=True)
        if opt.timeout > 0:
            signal.alarm(opt.timeout)
        try:
            return p.wait()
        except AlarmInterrupt:
            self.log("    Timed out")
            return 4
        except KeyboardInterrupt:
            self.log("    Interrupted")
            return 128
        finally:
            signal.alarm(0)
            if p.returncode is None:
                p.terminate()

    def run(self):
        """
        This function is called after initialization to set up and run all doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: from sage.env import SAGE_SRC
            sage: import os
            sage: DD = DocTestDefaults()
            sage: filename = os.path.join(SAGE_SRC, "sage", "sets", "non_negative_integers.py")
            sage: DC = DocTestController(DD, [filename])
            sage: DC.run()
            Running doctests with ID ...
            Doctesting 1 file.
            sage -t .../sage/sets/non_negative_integers.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
            0
        """
        opt = self.options
        L = (opt.gdb, opt.valgrind, opt.massif, opt.cachegrind, opt.omega)
        if any(L):
            if L.count(True) > 1:
                self.log("You may only specify one of gdb, valgrind/memcheck, massif, cachegrind, omega")
                return 2
            return self.run_val_gdb()
        else:
            self.test_safe_directory()
            self.create_run_id()
            from sage.env import SAGE_ROOT
            DOT_GIT= os.path.join(SAGE_ROOT, '.git')
            if os.path.isdir(DOT_GIT):
                import subprocess
                try:
                    branch = subprocess.check_output(["git",
                                                      "--git-dir=" + DOT_GIT,
                                                      "rev-parse",
                                                      "--abbrev-ref",
                                                      "HEAD"])
                    self.log("Git branch: " + branch, end="")
                except subprocess.CalledProcessError:
                    pass
            self.add_files()
            self.expand_files_into_sources()
            self.filter_sources()
            self.sort_sources()
            self.run_doctests()
            return self.reporter.error_status
Example #13
0
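# full_timer wraps the entire optimize_design call; separate build_timer and
# solve_timer objects are passed in for finer-grained timing of its phases.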
full_timer.start()
result = pnr.optimize_design(
    optimizer,
    tester.init,
    tester.funcs,
    verbose=False,
    cutoff=cutoff,
    build_timer=build_timer,
    solve_timer=solve_timer,
    return_bounds=True,
    optimize_final=optimize_final,
    #        attest_func=modeler.model_checker,
)

full_timer.stop()

print(
    json.dumps({
        'benchmark': {
            'fabric': fabric_file,
            'contexts': contexts,
            'design': design_file,
        },
        'params': {
            'incremental': incremental,
            'cutoff': cutoff,
            'optimize_final': optimize_final,
            'optimizer': optimizer_name,
            'duplicate_const': duplicate_const,
            'duplicate_all': duplicate_all,
Example #14
0
class DigsbyConnect(TimeoutSocketOne):
    _SERVERTIMEOUT = 8

    def stale_connection(self):

        if getattr(self, '_triumphant', False):
            log.info('stale_connection was called but i already won! yayayay')
        else:
            log.info(
                '%r had a stale connection. Calling do_fail (%r) with a connlost error',
                self, self.do_fail)
            self.do_fail(DigsbyLoginError('connlost'))

    def succ(self):
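        # Arm a watchdog: if the login sequence does not finish within
        # _SERVERTIMEOUT seconds, the Timer fires stale_connection(), which
        # fails the connection with a 'connlost' error.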
        generator = self.do_login()

        self._timeouttimer = Timer(self._SERVERTIMEOUT, self.stale_connection)
        self._timeouttimer.start()
        self.run_sequence(generator)

    @lock
    def handle_error(self, e=None):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_error(self)

    @lock
    def handle_expt(self):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_expt(self)

    @lock
    def handle_close(self):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_close(self)

    def do_login(self):
        login_str = make_pstring(self.cid) + make_pstring(
            self.un) + make_pstring(self.password)
        codelen = yield (4, login_str)
        codelen = unpack('!I', codelen)[0]
        if codelen <= 0:
            raise DigsbyLoginError('client')
        code = yield (codelen, '')

        try:
            if code == 'success':
                cookielen = unpack('!I', (yield (4, '')))[0]
                cookie = yield (cookielen, '')
                log.debug('Got cookie: %r', cookie)
                serverslen = unpack('!I', (yield (4, '')))[0]
                servers = yield (serverslen, '')
                log.debug('Got servers: %r', servers)
                servers = servers.split(' ')
                self.cookie = cookie
                self.servers = servers
                self._triumphant = True
                return
            elif code == 'error':
                log.debug('Got error!')
                reasonlen = unpack('!I', (yield (4, '')))[0]
                reason = yield (reasonlen, '')
                log.debug('Got error reason: %r', reason)
                raise DigsbyLoginError(reason)
            else:
                log.debug('Unknown error occurred! blaming the client!')
                raise DigsbyLoginError('client')
        except DigsbyLoginError, e:
            if e.reason == 'server':
                log.debug('Got "upgrading digsby" error code. Sleeping.')
                import time
                time.sleep(POLL_SLEEP_TIME)
            raise e
        except Exception, e:
            print_exc()
            raise DigsbyLoginError('client')
Example #15
0
class DigsbyConnect(TimeoutSocketOne):
    _SERVERTIMEOUT = 8

    def stale_connection(self):

        if getattr(self, '_triumphant', False):
            log.info('stale_connection was called but i already won! yayayay')
        else:
            log.info('%r had a stale connection. Calling do_fail (%r) with a connlost error', self, self.do_fail)
            self.do_fail(DigsbyLoginError('connlost'))

    def succ(self):
        generator = self.do_login()

        self._timeouttimer = Timer(self._SERVERTIMEOUT, self.stale_connection)
        self._timeouttimer.start()
        self.run_sequence( generator )

    @lock
    def handle_error(self, e=None):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_error(self)

    @lock
    def handle_expt(self):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_expt(self)

    @lock
    def handle_close(self):
        if hasattr(self, '_timeouttimer'):
            self._timeouttimer.stop()
        TimeoutSocketOne.handle_close(self)

    def do_login(self):
        login_str = make_pstring(self.cid) + make_pstring(self.un) + make_pstring(self.password)
        codelen = yield (4, login_str)
        codelen = unpack('!I', codelen)[0]
        if codelen <= 0:
            raise DigsbyLoginError('client')
        code = yield (codelen, '')

        try:
            if code == 'success':
                cookielen = unpack('!I', (yield (4, '')))[0]
                cookie = yield (cookielen, '')
                log.debug('Got cookie: %r', cookie)
                serverslen = unpack('!I', (yield (4, '')))[0]
                servers = yield (serverslen, '')
                log.debug('Got servers: %r', servers)
                servers = servers.split(' ')
                self.cookie  = cookie
                self.servers = servers
                self._triumphant = True
                return
            elif code == 'error':
                log.debug('Got error!')
                reasonlen = unpack('!I', (yield (4, '')))[0]
                reason = yield (reasonlen, '')
                log.debug('Got error reason: %r', reason)
                raise DigsbyLoginError(reason)
            else:
                log.debug('Unknown error occurred! blaming the client!')
                raise DigsbyLoginError('client')
        except DigsbyLoginError, e:
            if e.reason == 'server':
                log.debug('Got "upgrading digsby" error code. Sleeping.')
                import time; time.sleep(POLL_SLEEP_TIME)
            raise e
        except Exception, e:
            print_exc()
            raise DigsbyLoginError('client')
Example #16
0
def main():
    timer = Timer()
    timer.start()

    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    tf.set_random_seed(0)

    MAX_SENT_LENGTH = 20
    MAX_SENTS = 100
    EMBEDDING_DIM = 50
    POST_DIM = 10
    TEXT_DIM = 50
    VALIDATION_SPLIT = 0.2
    MIXTURES = 5
    Graph_DIM = 10
    TRAINING_EPOCHS = 50

    flags = tf.app.flags
    FLAGS = flags.FLAGS
    flags.DEFINE_float('learning_rate', 0.0001, 'Initial learning rate.')
    flags.DEFINE_integer('hidden1', 32, 'Number of units in hidden layer 1.')
    flags.DEFINE_integer('hidden2', Graph_DIM,
                         'Number of units in hidden layer 2.')
    flags.DEFINE_integer('batch_size', 32, 'Size of a mini-batch')
    flags.DEFINE_float('dropout', 0., 'Dropout rate (1 - keep probability).')
    flags.DEFINE_float('lambda1', 1e-4, 'Parameter of energy.')
    flags.DEFINE_float('lambda2', 1e-9, 'lossSigma.')
    flags.DEFINE_float('lambda3', 0.01, 'GAE.')
    flags.DEFINE_string('model', 'gcn_ae', 'Model string.')
    model_str = FLAGS.model

    # variable to store evaluation results
    precision_list = []
    recall_list = []
    f1_list = []
    auc_list = []

    for t in range(10):
        with open('./data/instagram.pickle', 'rb') as handle:
            store_data = pickle.load(handle)

        labels = store_data['labels']
        df = store_data['df']
        data = store_data['data']
        postInfo = store_data['postInfo']
        timeInfo = store_data['timeInfo']
        embedding_matrix = store_data['embedding_matrix']
        word_index = store_data['word_index']

        num_session = data.shape[0]
        nb_validation_samples = int(VALIDATION_SPLIT * num_session)
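        # The last VALIDATION_SPLIT fraction of sessions is held out for
        # evaluation (see the x_val / y_val slices below).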
        '''For Evaluation'''
        single_label = np.asarray(labels)
        labels = to_categorical(np.asarray(labels))
        print('Shape of data tensor:', data.shape)
        print('Shape of label tensor:', labels.shape)

        zeros = np.zeros(num_session)
        zeros = zeros.reshape((num_session, 1, 1))
        # FLAGS.learning_rate = lr
        '''Hierarchical Attention Network for text and other info'''
        placeholders = {
            'zero_input':
            tf.placeholder(tf.float32, shape=[None, 1, 1]),
            'review_input':
            tf.placeholder(tf.float32,
                           shape=[None, MAX_SENTS, MAX_SENT_LENGTH + 1]),
            'post_input':
            tf.placeholder(tf.float32, shape=[
                None,
                4,
            ]),
            'time_label':
            tf.placeholder(tf.float32, shape=[None, MAX_SENTS])
        }

        g = nx.Graph()
        edgelist = pd.read_csv('./data/source_target.csv')
        for i, elrow in edgelist.iterrows():
            g.add_edge(elrow[0].strip('\n'), elrow[1].strip('\n'))
        adj = nx.adjacency_matrix(g)
        user_attributes = pd.read_csv('./data/user_friend_follower.csv')
        user_attributes = user_attributes.set_index('user').T.to_dict('list')
        nodelist = list(g.nodes())
        features = []
        User_post = np.zeros(
            (len(nodelist), num_session))  # 2218 number of posts
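        # For each user node, record which sessions they own (User_post is a
        # user-by-session incidence matrix) and collect their friend/follower
        # attributes as node features.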

        for id, node in enumerate(nodelist):
            posts_ID = df.loc[df['owner_id'] == node].index.values.tolist()
            for p_id in posts_ID:
                User_post[id][p_id] = 1
            features.append(user_attributes[node])

        # only keep the posts that are in the training data
        User_post_train = User_post[:, :-nb_validation_samples]
        User_post_test = User_post[:, -nb_validation_samples:]
        features = sparse.csr_matrix(features)
        features = normalize(features, norm='max', axis=0)
        adj_orig = adj
        adj_orig = adj_orig - sparse.dia_matrix(
            (adj_orig.diagonal()[np.newaxis, :], [0]), shape=adj_orig.shape)
        adj_orig.eliminate_zeros()
        adj_norm = preprocess_graph(adj)
        adj_label = adj + sparse.eye(adj.shape[0])
        adj_label = sparse_to_tuple(adj_label)

        # Define placeholders
        placeholders.setdefault('features', tf.sparse_placeholder(tf.float32))
        placeholders.setdefault('adj', tf.sparse_placeholder(tf.float32))
        placeholders.setdefault('adj_orig', tf.sparse_placeholder(tf.float32))
        placeholders.setdefault('dropout',
                                tf.placeholder_with_default(0., shape=()))
        placeholders.setdefault(
            'user_post', tf.placeholder(tf.int32, [len(nodelist), None]))
        d = {placeholders['dropout']: FLAGS.dropout}
        placeholders.update(d)
        num_nodes = adj.shape[0]

        features = sparse_to_tuple(features.tocoo())
        num_features = features[2][1]
        features_nonzero = features[1].shape[0]
        '''Graph AutoEncoder'''
        if model_str == 'gcn_ae':
            Graph_model = GCNModelAE(placeholders, num_features,
                                     features_nonzero)
        elif model_str == 'gcn_vae':
            Graph_model = GCNModelVAE(placeholders, num_features, num_nodes,
                                      features_nonzero)

        embedding_layer = Embedding(len(word_index) + 1,
                                    EMBEDDING_DIM,
                                    weights=[embedding_matrix],
                                    input_length=MAX_SENT_LENGTH,
                                    trainable=True,
                                    mask_zero=True)

        all_input = Input(shape=(MAX_SENT_LENGTH + 1, ))
        sentence_input = crop(1, 0, MAX_SENT_LENGTH)(all_input)  # slice
        time_input = crop(1, MAX_SENT_LENGTH,
                          MAX_SENT_LENGTH + 1)(all_input)  # slice
        embedded_sequences = embedding_layer(sentence_input)
        # embedded_sequences=BatchNormalization()(embedded_sequences)
        l_lstm = Bidirectional(GRU(TEXT_DIM,
                                   return_sequences=True))(embedded_sequences)
        l_att = AttLayer(TEXT_DIM)(l_lstm)  # (?,200)
        # time_embedding=Dense(TIME_DIM,activation='sigmoid')(time_input)
        merged_output = Concatenate()([l_att,
                                       time_input])  # text+time information
        sentEncoder = Model(all_input, merged_output)
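        # sentEncoder maps one (sentence, time) row to a single vector:
        # BiGRU over word embeddings plus attention, concatenated with the
        # time input for that sentence.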

        review_input = placeholders['review_input']
        review_encoder = TimeDistributed(sentEncoder)(review_input)
        l_lstm_sent = Bidirectional(GRU(TEXT_DIM,
                                        return_sequences=True))(review_encoder)
        fully_sent = Dense(1, use_bias=False)(l_lstm_sent)
        pred_time = Activation(activation='linear')(fully_sent)
        zero_input = placeholders['zero_input']
        shift_predtime = Concatenate(axis=1)([zero_input, pred_time])
        shift_predtime = crop(1, 0, MAX_SENTS)(shift_predtime)
        l_att_sent = AttLayer(TEXT_DIM)(l_lstm_sent)

        # embed the #likes, shares
        post_input = placeholders['post_input']
        fully_post = Dense(POST_DIM, use_bias=False)(post_input)
        # norm_fullypost=BatchNormalization()(fully_post)
        post_embedding = Activation(activation='relu')(fully_post)
        fully_review = concatenate(
            [l_att_sent, post_embedding]
        )  # merge the document level vectro with the additional embedded features such as #likes

        pos_weight = float(adj.shape[0] * adj.shape[0] - adj.sum()) / adj.sum()
        norm = adj.shape[0] * adj.shape[0] / float(
            (adj.shape[0] * adj.shape[0] - adj.sum()) * 2)
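        # pos_weight and norm re-balance the graph reconstruction loss for the
        # sparse adjacency matrix (non-edges vastly outnumber edges).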
        with tf.name_scope('graph_cost'):
            preds_sub = Graph_model.reconstructions
            labels_sub = tf.reshape(
                tf.sparse_tensor_to_dense(placeholders['adj_orig'],
                                          validate_indices=False), [-1])
            if model_str == 'gcn_ae':
                opt = CostAE(preds=preds_sub,
                             labels=labels_sub,
                             pos_weight=pos_weight,
                             norm=norm)
            elif model_str == 'gcn_vae':
                opt = CostVAE(preds=preds_sub,
                              labels=labels_sub,
                              model=Graph_model,
                              num_nodes=num_nodes,
                              pos_weight=pos_weight,
                              norm=norm)
        User_latent = Graph_model.z_mean  # (n_user, G_embeddim)
        Post_latent = fully_review  # (batch size, text_embed_dim+post_dim)
        max_indices = tf.argmax(placeholders['user_post'], axis=0)
        add_latent = tf.gather(User_latent, max_indices)
        session_latent = tf.concat(
            [Post_latent, add_latent],
            axis=1)  # the representation of text + graph
        '''DAGMM'''
        h1_size = 2 * TEXT_DIM + Graph_DIM + POST_DIM
        gmm = GMM(MIXTURES)
        est_net = EstimationNet([h1_size, MIXTURES], tf.nn.tanh)
        gamma = est_net.inference(session_latent, FLAGS.dropout)
        gmm.fit(session_latent, gamma)
        individual_energy = gmm.energy(session_latent)
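        # The per-session GMM energy is the anomaly score; sessions with high
        # energy are flagged as bullying via the percentile threshold computed
        # after training.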

        Time_label = placeholders['time_label']
        Time_label = tf.reshape(Time_label,
                                [tf.shape(Time_label)[0], MAX_SENTS, 1])

        with tf.name_scope('loss'):
            GAE_error = opt.cost
            energy = tf.reduce_mean(individual_energy)
            lossSigma = gmm.cov_diag_loss()
            prediction_error = tf.losses.mean_squared_error(
                shift_predtime, Time_label)
            loss = prediction_error + FLAGS.lambda1 * energy + FLAGS.lambda2 * lossSigma + FLAGS.lambda3 * GAE_error

        x_train = data[:-nb_validation_samples]
        time_train = timeInfo[:-nb_validation_samples]
        zeros_train = zeros[:-nb_validation_samples]
        y_train = labels[:-nb_validation_samples]
        post_train = postInfo[:-nb_validation_samples]
        x_val = data[-nb_validation_samples:]
        zeros_test = zeros[-nb_validation_samples:]
        time_test = timeInfo[-nb_validation_samples:]
        y_val = labels[-nb_validation_samples:]
        post_test = postInfo[-nb_validation_samples:]
        y_single = single_label[-nb_validation_samples:]

        print(
            'Number of positive and negative posts in training and validation set'
        )
        print(y_train.sum(axis=0))
        print(y_val.sum(axis=0))
        print("model fitting - Unsupervised cyberbullying detection")

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
        train_step = optimizer.minimize(loss)
        GAEcorrect_prediction = tf.equal(
            tf.cast(tf.greater_equal(tf.sigmoid(preds_sub), 0.5), tf.int32),
            tf.cast(labels_sub, tf.int32))
        feed_dict_train = construct_feed_dict(zeros_train, x_train, post_train,
                                              time_train, FLAGS.dropout,
                                              adj_norm, adj_label, features,
                                              User_post_train, placeholders)
        feed_dict_train.update({placeholders['dropout']: FLAGS.dropout})

        sess = tf.Session()
        sess.run(tf.global_variables_initializer())
        total_batch = int(num_session / FLAGS.batch_size)
        zero_batches = np.array_split(zeros_train, total_batch)
        x_batches = np.array_split(x_train, total_batch)
        p_batches = np.array_split(post_train, total_batch)
        t_batches = np.array_split(time_train, total_batch)
        UP_batches = np.array_split(User_post_train, total_batch, axis=1)
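        # The training arrays (and the user-post matrix along its session axis)
        # are split into total_batch chunks for mini-batch updates.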

        for epoch in range(TRAINING_EPOCHS):
            ave_cost = 0
            ave_energy = 0
            ave_recon = 0
            ave_sigma = 0
            ave_GAE = 0
            for i in range(total_batch):
                batch_x = x_batches[i]
                batch_p = p_batches[i]
                batch_t = t_batches[i]
                batch_z = zero_batches[i]
                user_post = UP_batches[i]
                feed_dict = construct_feed_dict(batch_z, batch_x, batch_p,
                                                batch_t, FLAGS.dropout,
                                                adj_norm, adj_label, features,
                                                user_post, placeholders)
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                _, total_loss, loss_sigma, GAE_loss, Energy_error, recon_error = sess.run(
                    [
                        train_step, loss, lossSigma, GAE_error, energy,
                        prediction_error
                    ], feed_dict)
                ave_cost += total_loss / total_batch
                ave_energy += Energy_error / total_batch
                ave_GAE += GAE_loss / total_batch
                ave_sigma += loss_sigma / total_batch
                ave_recon += recon_error / total_batch
            # if epoch % 10 == 0 or epoch == TRAINING_EPOCHS - 1:
            # print("This is epoch %d, the total loss is %f, energy error is %f, GAE error is %f, sigma error is %f,prediction error is %f") \
            #      % (epoch + 1, ave_cost, ave_energy, ave_GAE, ave_sigma, ave_recon)

        fix = gmm.fix_op()
        sess.run(fix, feed_dict=feed_dict_train)

        feed_dict_test = construct_feed_dict(zeros_test, x_val, post_test,
                                             time_test, FLAGS.dropout,
                                             adj_norm, adj_label, features,
                                             User_post_test, placeholders)
        pred_energy, representations = sess.run(
            [individual_energy, session_latent], feed_dict=feed_dict_test)
        bully_energy_threshold = np.percentile(pred_energy, 65)
        print('the bully energy threshold is : %f' % bully_energy_threshold)
        label_pred = np.where(pred_energy >= bully_energy_threshold, 1, 0)
        print(precision_recall_fscore_support(y_single, label_pred))
        print(accuracy_score(y_single, label_pred))
        print(roc_auc_score(y_single, label_pred))
        tf.reset_default_graph()
        K.clear_session()

        precision_list.append(
            precision_recall_fscore_support(y_single, label_pred)[0][1])
        recall_list.append(
            precision_recall_fscore_support(y_single, label_pred)[1][1])
        f1_list.append(
            precision_recall_fscore_support(y_single, label_pred)[2][1])
        auc_list.append(roc_auc_score(y_single, label_pred))

    print('>>> Evaluation metrics')
    print('>>> precision mean: {0:.4f}; precision std: {1:.4f}'.format(
        np.mean(precision_list), np.std(precision_list)))
    print('>>> recall mean: {0:.4f}; recall std: {1:.4f}'.format(
        np.mean(recall_list), np.std(recall_list)))
    print('>>> f1 mean: {0:.4f}; f1 std: {1:.4f}'.format(
        np.mean(f1_list), np.std(f1_list)))
    print('>>> auc mean: {0:.4f}; auc std: {1:.4f}'.format(
        np.mean(auc_list), np.std(auc_list)))

    timer.stop()
Example #17
0
class DocTestController(SageObject):
    """
    This class controls doctesting of files.

    After creating it with appropriate options, call the :meth:`run` method to run the doctests.
    """
    def __init__(self, options, args):
        """
        Initialization.

        INPUT:

        - options -- either options generated from the command line by SAGE_ROOT/local/bin/sage-runtests
                     or a DocTestDefaults object (possibly with some entries modified)
        - args -- a list of filenames to doctest

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC
            DocTest Controller
        """
        # First we modify options to take environment variables into
        # account and check compatibility of the user's specified
        # options.
        if options.timeout < 0:
            if options.gdb or options.debug:
                # Interactive debuggers: "infinite" timeout
                options.timeout = 0
            elif options.valgrind or options.massif or options.cachegrind or options.omega:
                # Non-interactive debuggers: 48 hours
                options.timeout = int(os.getenv('SAGE_TIMEOUT_VALGRIND', 48 * 60 * 60))
            elif options.long:
                options.timeout = int(os.getenv('SAGE_TIMEOUT_LONG', 30 * 60))
            else:
                options.timeout = int(os.getenv('SAGE_TIMEOUT', 5 * 60))
        if options.nthreads == 0:
            options.nthreads = int(os.getenv('SAGE_NUM_THREADS_PARALLEL',1))
        if options.failed and not (args or options.new or options.sagenb):
            # If the user doesn't specify any files then we rerun all failed files.
            options.all = True
        if options.global_iterations == 0:
            options.global_iterations = int(os.environ.get('SAGE_TEST_GLOBAL_ITER', 1))
        if options.file_iterations == 0:
            options.file_iterations = int(os.environ.get('SAGE_TEST_ITER', 1))
        if options.debug and options.nthreads > 1:
            print("Debugging requires single-threaded operation, setting number of threads to 1.")
            options.nthreads = 1
        if options.serial:
            options.nthreads = 1

        self.options = options
        self.files = args
        if options.all and options.logfile is None:
            options.logfile = os.path.join(os.environ['SAGE_TESTDIR'], 'test.log')
        if options.logfile:
            try:
                self.logfile = open(options.logfile, 'a')
            except IOError:
                print "Unable to open logfile at %s\nProceeding without logging."%(options.logfile)
                self.logfile = None
        else:
            self.logfile = None
        self.stats = {}
        self.load_stats(options.stats_path)

    def _repr_(self):
        """
        String representation.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: repr(DC) # indirect doctest
            'DocTest Controller'
        """
        return "DocTest Controller"

    def load_stats(self, filename):
        """
        Load stats from the most recent run(s).

        Stats are stored as a JSON file, and include information on
        which files failed tests and the walltime used for execution
        of the doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: import json
            sage: filename = tmp_filename()
            sage: with open(filename, 'w') as stats_file:
            ...       json.dump({'sage.doctest.control':{u'walltime':1.0r}}, stats_file)
            sage: DC.load_stats(filename)
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}

        If the file doesn't exist, nothing happens. If there is an
        error, print a message. In any case, leave the stats alone::

            sage: d = tmp_dir()
            sage: DC.load_stats(os.path.join(d))  # Cannot read a directory
            Error loading stats from ...
            sage: DC.load_stats(os.path.join(d, "no_such_file"))
            sage: DC.stats['sage.doctest.control']
            {u'walltime': 1.0}
        """
        # Simply ignore non-existing files
        if not os.path.exists(filename):
            return

        try:
            with open(filename) as stats_file:
                self.stats.update(json.load(stats_file))
        except StandardError:
            self.log("Error loading stats from %s"%filename)

    def save_stats(self, filename):
        """
        Save stats from the most recent run as a JSON file.

        WARNING: This function overwrites the file.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.stats['sage.doctest.control'] = {u'walltime':1.0r}
            sage: filename = tmp_filename()
            sage: DC.save_stats(filename)
            sage: import json
            sage: D = json.load(open(filename))
            sage: D['sage.doctest.control']
            {u'walltime': 1.0}
        """
        with open(filename, 'w') as stats_file:
            json.dump(self.stats, stats_file)

    def log(self, s, end="\n"):
        """
        Logs the string ``s + end`` (where ``end`` is a newline by default)
        to the logfile and prints it to the standard output.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(logfile=tmp_filename())
            sage: DC = DocTestController(DD, [])
            sage: DC.log("hello world")
            hello world
            sage: DC.logfile.close()
            sage: with open(DD.logfile) as logger: print logger.read()
            hello world

        """
        s += end
        if self.logfile is not None:
            self.logfile.write(s)
        sys.stdout.write(s)

    def test_safe_directory(self, dir=None):
        """
        Test that the given directory is safe to run Python code from.

        We use the check added to Python for this, which gives a
        warning when the current directory is considered unsafe.  We promote
        this warning to an error with ``-Werror``.  See
        ``sage/tests/cmdline.py`` for a doctest that this works, see
        also :trac:`13579`.

        TESTS::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [])
            sage: DC.test_safe_directory()
            sage: d = os.path.join(tmp_dir(), "test")
            sage: os.mkdir(d)
            sage: os.chmod(d, 0o777)
            sage: DC.test_safe_directory(d)
            Traceback (most recent call last):
            ...
            RuntimeError: refusing to run doctests...
        """
        import subprocess
        with open(os.devnull, 'w') as dev_null:
            if subprocess.call(['python', '-Werror', '-c', ''],
                    stdout=dev_null, stderr=dev_null, cwd=dir) != 0:
                raise RuntimeError(
                      "refusing to run doctests from the current "
                      "directory '{}' since untrusted users could put files in "
                      "this directory, making it unsafe to run Sage code from"
                      .format(os.getcwd()))

    def create_run_id(self):
        """
        Creates the run id.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(), [])
            sage: DC.create_run_id()
            Running doctests with ID ...
        """
        self.run_id = time.strftime('%Y-%m-%d-%H-%M-%S-') + "%08x" % random.getrandbits(32)
        from sage.version import version
        self.log("Running doctests with ID %s."%self.run_id)

    def add_files(self):
        """
        Checks for the flags '--all', '--new' and '--sagenb'.

        For each one present, this function adds the appropriate directories and files to the todo list.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: log_location = os.path.join(SAGE_TMP, 'control_dt_log.log')
            sage: DD = DocTestDefaults(all=True, logfile=log_location)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting entire Sage library.
            sage: os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage') in DC.files
            True

        ::

            sage: DD = DocTestDefaults(new = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting files changed since last HG commit.
            sage: len(DC.files) == len([L for L in hg_sage('status', interactive=False, debug=False)[0].split('\n') if len(L.split()) ==2 and L.split()[0] in ['M','A']])
            True

        ::

            sage: DD = DocTestDefaults(sagenb = True)
            sage: DC = DocTestController(DD, [])
            sage: DC.add_files()
            Doctesting the Sage notebook.
            sage: DC.files[0][-6:]
            'sagenb'
        """
        opj = os.path.join
        SAGE_ROOT = os.environ['SAGE_ROOT']
        base = opj(SAGE_ROOT, 'devel', 'sage')
        if self.options.all:
            self.log("Doctesting entire Sage library.")
            from glob import glob
            self.files.append(opj(base, 'sage'))
            self.files.append(opj(base, 'doc', 'common'))
            self.files.extend(glob(opj(base, 'doc', '[a-z][a-z]')))
            self.options.sagenb = True
        elif self.options.new:
            self.log("Doctesting files changed since last HG commit.")
            import sage.all_cmdline
            from sage.misc.hg import hg_sage
            for X in hg_sage('status', interactive=False, debug=False)[0].split('\n'):
                tup = X.split()
                if len(tup) != 2: continue
                c, filename = tup
                if c in ['M','A']:
                    filename = opj(os.environ['SAGE_ROOT'], 'devel', 'sage', filename)
                    self.files.append(filename)
        if self.options.sagenb:
            if not self.options.all:
                self.log("Doctesting the Sage notebook.")
            from pkg_resources import Requirement, working_set
            sagenb_loc = working_set.find(Requirement.parse('sagenb')).location
            self.files.append(opj(sagenb_loc, 'sagenb'))

    def expand_files_into_sources(self):
        """
        Expands ``self.files``, which may include directories, into a
        list of :class:`sage.doctest.FileDocTestSource`

        This function also handles the optional command line option.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: dirname = os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage', 'doctest')
            sage: DD = DocTestDefaults(optional='all')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: len(DC.sources)
            9
            sage: DC.sources[0].optional
            True

        ::

            sage: DD = DocTestDefaults(optional='magma,guava')
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: sorted(list(DC.sources[0].optional))
            ['guava', 'magma']
        """
        def skipdir(dirname):
            if os.path.exists(os.path.join(dirname, "nodoctest.py")):
                return True
            # Workaround for https://github.com/sagemath/sagenb/pull/84
            if dirname.endswith(os.path.join(os.sep, 'sagenb', 'data')):
                return True
            return False
        def skipfile(filename):
            base, ext = os.path.splitext(filename)
            if ext not in ('.py', '.pyx', '.pxi', '.sage', '.spyx', '.rst', '.tex'):
                return True
            with open(filename) as F:
                return 'nodoctest' in F.read(50)
        def expand():
            for path in self.files:
                if os.path.isdir(path):
                    for root, dirs, files in os.walk(path):
                        for dir in list(dirs):
                            if dir[0] == "." or skipdir(os.path.join(root,dir)):
                                dirs.remove(dir)
                        for file in files:
                            if not skipfile(os.path.join(root,file)):
                                yield os.path.join(root, file)
                else:
                    # the user input this file explicitly, so we don't skip it
                    yield path
        if self.options.optional == 'all':
            optionals = True
        else:
            optionals = set(self.options.optional.lower().split(','))
        self.sources = [FileDocTestSource(path, self.options.force_lib, long=self.options.long, optional=optionals, randorder=self.options.randorder, useabspath=self.options.abspath) for path in expand()]

    def filter_sources(self):
        """
        
        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: dirname = os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage', 'doctest')
            sage: DD = DocTestDefaults(failed=True)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.stats['sage.doctest.control'] = {'failed':True,'walltime':1.0}
            sage: DC.filter_sources()
            Only doctesting files that failed last test.
            sage: len(DC.sources)
            1
        """
        # Filter the sources to only include those with failing doctests if the --failed option is passed
        if self.options.failed:
            self.log("Only doctesting files that failed last test.")
            def is_failure(source):
                basename = source.basename
                return basename not in self.stats or self.stats[basename].get('failed')
            self.sources = filter(is_failure, self.sources)

    def sort_sources(self):
        """
        This function sorts the sources so that slower doctests are run first.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: dirname = os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage', 'doctest')
            sage: DD = DocTestDefaults(nthreads=2)
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.sources.sort(key=lambda s:s.basename)
            sage: for i, source in enumerate(DC.sources):
            ...       DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
            sage: DC.sort_sources()
            Sorting sources by runtime so that slower doctests are run first....
            sage: print "\n".join([source.basename for source in DC.sources])
            sage.doctest.util
            sage.doctest.test
            sage.doctest.sources
            sage.doctest.reporting
            sage.doctest.parsing
            sage.doctest.forker
            sage.doctest.control
            sage.doctest.all
            sage.doctest
        """
        if self.options.nthreads > 1 and len(self.sources) > self.options.nthreads:
            self.log("Sorting sources by runtime so that slower doctests are run first....")
            default = dict(walltime=0)
            def sort_key(source):
                basename = source.basename
                return -self.stats.get(basename, default).get('walltime'), basename
            self.sources = [x[1] for x in sorted((sort_key(source), source) for source in self.sources)]

    def run_doctests(self):
        """
        Actually runs the doctests.

        This function is called by :meth:`run`.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: dirname = os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage', 'rings', 'homset.py')
            sage: DD = DocTestDefaults()
            sage: DC = DocTestController(DD, [dirname])
            sage: DC.expand_files_into_sources()
            sage: DC.run_doctests()
            Doctesting 1 file.
            sage -t .../sage/rings/homset.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
        """
        nfiles = 0
        nother = 0
        for F in self.sources:
            if isinstance(F, FileDocTestSource):
                nfiles += 1
            else:
                nother += 1
        if self.sources:
            filestr = ", ".join(([count_noun(nfiles, "file")] if nfiles else []) +
                                ([count_noun(nother, "other source")] if nother else []))
            threads = " using %s threads"%(self.options.nthreads) if self.options.nthreads > 1 else ""
            iterations = []
            if self.options.global_iterations > 1:
                iterations.append("%s global iterations"%(self.options.global_iterations))
            if self.options.file_iterations > 1:
                iterations.append("%s file iterations"%(self.options.file_iterations))
            iterations = ", ".join(iterations)
            if iterations:
                iterations = " (%s)"%(iterations)
            self.log("Doctesting %s%s%s."%(filestr, threads, iterations))
            self.reporter = DocTestReporter(self)
            self.dispatcher = DocTestDispatcher(self)
            N = self.options.global_iterations
            for it in range(N):
                try:
                    self.timer = Timer().start()
                    self.dispatcher.dispatch()
                except KeyboardInterrupt:
                    it = N - 1
                    break
                finally:
                    self.timer.stop()
                    self.reporter.finalize()
                    self.cleanup(it == N - 1)
        else:
            self.log("No files to doctest")
            self.reporter = DictAsObject(dict(error_status=0))

    def cleanup(self, final=True):
        """
        Runs cleanup activities after actually running doctests.

        In particular, saves the stats to disk and closes the logfile.

        INPUT:

        - ``final`` -- whether to close the logfile

        EXAMPLES::

             sage: from sage.doctest.control import DocTestDefaults, DocTestController
             sage: import os
             sage: dirname = os.path.join(os.environ['SAGE_ROOT'], 'devel', 'sage', 'sage', 'rings', 'infinity.py')
             sage: DD = DocTestDefaults()

             sage: DC = DocTestController(DD, [dirname])
             sage: DC.expand_files_into_sources()
             sage: DC.sources.sort(key=lambda s:s.basename)

             sage: for i, source in enumerate(DC.sources):
             ....:     DC.stats[source.basename] = {'walltime': 0.1*(i+1)}
             ....:

             sage: DC.run()
             Running doctests with ID ...
             Doctesting 1 file.
             sage -t .../rings/infinity.py
                 [... tests, ... s]
             ----------------------------------------------------------------------
             All tests passed!
             ----------------------------------------------------------------------
             Total time for all tests: ... seconds
                 cpu time: ... seconds
                 cumulative wall time: ... seconds
             0
             sage: DC.cleanup()
        """
        self.stats.update(self.reporter.stats)
        self.save_stats(self.options.stats_path)
        # Close the logfile
        if final and self.logfile is not None:
            self.logfile.close()
            self.logfile = None

    def _assemble_cmd(self):
        """
        Assembles a shell command used in running tests under gdb or valgrind.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DC = DocTestController(DocTestDefaults(timeout=123), ["hello_world.py"])
            sage: print DC._assemble_cmd()
            python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=123 hello_world.py
        """
        cmd = '''python "%s" --serial '''%(os.path.join("$SAGE_LOCAL","bin","sage-runtests"))
        opt = dict_difference(self.options.__dict__, DocTestDefaults().__dict__)
        for o in ("all", "sagenb"):
            if o in opt:
                raise ValueError("You cannot run gdb/valgrind on the whole sage%s library"%("" if o == "all" else "nb"))
        for o in ("all", "sagenb", "long", "force_lib", "verbose", "failed", "new"):
            if o in opt:
                cmd += "--%s "%o
        for o in ("timeout", "optional", "randorder", "stats_path"):
            if o in opt:
                cmd += "--%s=%s "%(o, opt[o])
        return cmd + " ".join(self.files)

    def run_val_gdb(self, testing=False):
        """
        Spawns a subprocess to run tests under the control of gdb or valgrind.

        INPUT:

        - ``testing`` -- boolean; if True then the command to be run
          will be printed rather than a subprocess started.

        EXAMPLES:

        Note that the command lines include unexpanded environment
        variables. It is safer to let the shell expand them than to
        expand them here and risk insufficient quoting. ::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: DD = DocTestDefaults(gdb=True)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=0 hello_world.py

        ::

            sage: DD = DocTestDefaults(valgrind=True, optional="all", timeout=172800)
            sage: DC = DocTestController(DD, ["hello_world.py"])
            sage: DC.run_val_gdb(testing=True)
            exec valgrind --tool=memcheck --leak-resolution=high --leak-check=full --num-callers=25 --suppressions="$SAGE_LOCAL/lib/valgrind/sage.supp"  --log-file=".../valgrind/sage-memcheck.%p" python "$SAGE_LOCAL/bin/sage-runtests" --serial --timeout=172800 --optional=all hello_world.py
        """
        try:
            sage_cmd = self._assemble_cmd()
        except ValueError:
            self.log(sys.exc_info()[1])
            return 2
        opt = self.options
        if opt.gdb:
            cmd = '''exec gdb -x "$SAGE_LOCAL/bin/sage-gdb-commands" --args '''
            flags = ""
            if opt.logfile:
                sage_cmd += " --logfile %s"%(opt.logfile)
        else:
            if opt.logfile is None:
                default_log = os.path.join(DOT_SAGE, "valgrind")
                if os.path.exists(default_log):
                    if not os.path.isdir(default_log):
                        self.log("%s must be a directory"%default_log)
                        return 2
                else:
                    os.makedirs(default_log)
                logfile = os.path.join(default_log, "sage-%s")
            else:
                logfile = opt.logfile
            if opt.valgrind:
                toolname = "memcheck"
                flags = os.getenv("SAGE_MEMCHECK_FLAGS")
                if flags is None:
                    flags = "--leak-resolution=high --leak-check=full --num-callers=25 "
                    flags += '''--suppressions="%s" '''%(os.path.join("$SAGE_LOCAL","lib","valgrind","sage.supp"))
            elif opt.massif:
                toolname = "massif"
                flags = os.getenv("SAGE_MASSIF_FLAGS", "--depth=6 ")
            elif opt.cachegrind:
                toolname = "cachegrind"
                flags = os.getenv("SAGE_CACHEGRIND_FLAGS", "")
            elif opt.omega:
                toolname = "exp-omega"
                flags = os.getenv("SAGE_OMEGA_FLAGS", "")
            cmd = "exec valgrind --tool=%s "%(toolname)
            flags += ''' --log-file="%s" ''' % logfile
            if opt.omega:
                toolname = "omega"
            if "%s" in flags:
                flags %= toolname + ".%p" # replace %s with toolname
        cmd += flags + sage_cmd

        self.log(cmd)
        sys.stdout.flush()
        sys.stderr.flush()
        if self.logfile is not None:
            self.logfile.flush()

        if testing:
            return

        import signal, subprocess
        def handle_alrm(sig, frame):
            raise RuntimeError
        signal.signal(signal.SIGALRM, handle_alrm)
        p = subprocess.Popen(cmd, shell=True)
        if opt.timeout > 0:
            signal.alarm(opt.timeout)
        try:
            return p.wait()
        except RuntimeError:
            self.log("    Time out")
            return 4
        except KeyboardInterrupt:
            self.log("    Interrupted")
            return 128
        finally:
            signal.signal(signal.SIGALRM, signal.SIG_IGN)
            if p.returncode is None:
                p.terminate()
            
    def run(self):
        """
        This function is called after initialization to set up and run all doctests.

        EXAMPLES::

            sage: from sage.doctest.control import DocTestDefaults, DocTestController
            sage: import os
            sage: DD = DocTestDefaults()
            sage: filename = os.path.join(os.environ["SAGE_ROOT"], "devel", "sage", "sage", "sets", "non_negative_integers.py")
            sage: DC = DocTestController(DD, [filename])
            sage: DC.run()
            Running doctests with ID ...
            Doctesting 1 file.
            sage -t .../sage/sets/non_negative_integers.py
                [... tests, ... s]
            ----------------------------------------------------------------------
            All tests passed!
            ----------------------------------------------------------------------
            Total time for all tests: ... seconds
                cpu time: ... seconds
                cumulative wall time: ... seconds
            0
        """
        opt = self.options
        L = (opt.gdb, opt.valgrind, opt.massif, opt.cachegrind, opt.omega)
        if any(L):
            if L.count(True) > 1:
                self.log("You may only specify one of gdb, valgrind/memcheck, massif, cachegrind, omega")
                return 2
            return self.run_val_gdb()
        else:
            self.test_safe_directory()
            self.create_run_id()
            self.add_files()
            self.expand_files_into_sources()
            self.filter_sources()
            self.sort_sources()
            self.run_doctests()
            return self.reporter.error_status
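
The run_val_gdb method above enforces its timeout with SIGALRM: it installs a handler that raises RuntimeError, arms signal.alarm(opt.timeout), and lets the resulting exception interrupt p.wait(). The standalone sketch below shows the same pattern; the helper name run_with_timeout and its return codes are illustrative only and not part of the Sage source.

import signal
import subprocess

def run_with_timeout(cmd, timeout):
    """Run cmd in a shell and return its exit code, or 4 on timeout (sketch)."""
    def handle_alrm(sig, frame):
        # Turn the alarm signal into an exception so p.wait() is interrupted.
        raise RuntimeError("timed out")

    signal.signal(signal.SIGALRM, handle_alrm)
    p = subprocess.Popen(cmd, shell=True)
    if timeout > 0:
        signal.alarm(timeout)  # deliver SIGALRM after `timeout` seconds
    try:
        return p.wait()
    except RuntimeError:
        return 4  # timed out
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, signal.SIG_IGN)
        if p.returncode is None:
            p.terminate()  # child is still running; stop it
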
Example #18
0
    def generate_sudoku(self, target=25):
        '''
        Generates a sudoku by randomly removing values until the difficulty
        level given by the target parameter is reached.
        Every 1000 backtracks, half of the removed values, chosen at random,
        are restored.

        returns (current_sudoku, len(current_sudoku), cnt_step, time)
        '''
        base_sudoku = self.generate_full_sudoku()
        current_sudoku = Sudoku(base_sudoku.get_dict())

        cache = []  # cache of removed values, used for backtracking
        cnt_backtrack = 0
        cnt_step = 0
        single_solution = True
        timer = Timer()

        timer.start()
        while True:
            cnt_step += 1
            #print '----------------------------'
            #print 'Cache size', len(cache)
            # Exit test
            if len(current_sudoku) == target and single_solution:
                break
            #print 'Current values count: ', len(current_sudoku)
            #print 'Single solution: ', single_solution
            #print 'Backtrack', cnt_backtrack

            # How many values to remove
            n = len(current_sudoku) // 20

            #print 'Try to remove %d values' % n
            assert n != 0
            # Remove the numbers
            for i in range(n):
                key = random.choice(current_sudoku.filled_cell())
                cache.append(key)
                current_sudoku.clear_cell(key)

            #print 'Cache size', len(cache)

            # Check that the solution is unique
            (sols, b, t) = self.solver.solve(current_sudoku, max_sol=2)
            # If it is unique, continue
            if len(sols) == 1:
                single_solution = True
                #print "Successfully removed %d elements" % n
                continue
            # If there is more than one solution, backtrack
            else:
                #print "Backtrack, sols: %d" % len(sols)
                single_solution = False
                cnt_backtrack += 1

                # Restore the last n removed values
                #print 'Restored cache size', len(cache)
                for i in range(n):
                    # Put back the most recently removed elements
                    k = cache[-1]
                    current_sudoku.set_cell(k, base_sudoku.cell(k))
                    cache.pop(-1)

                if cnt_backtrack % 1000 == 0:
                    #print 'Randomly restore half of the cache'
                    for i in range(len(cache) // 2):
                        # Put back randomly chosen removed elements
                        idx = random.randint(0, len(cache) - 1)
                        k = cache[idx]
                        current_sudoku.set_cell(k, base_sudoku.cell(k))
                        cache.pop(idx)

        #print '----------------------------'
        #print 'Backtracks needed: ', cnt_backtrack
        time = timer.stop()
        return current_sudoku, len(current_sudoku), cnt_step, time
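
In generate_sudoku the Timer is used purely as a stopwatch: timer.start() before the removal loop, and the elapsed time is taken from the return value of timer.stop(). A minimal timer with that interface could look like the sketch below; this is an assumption about the interface used here, not the project's actual implementation.

import time

class StopwatchTimer(object):
    """Minimal stand-in for the Timer used above (hypothetical)."""

    def __init__(self):
        self._started = None

    def start(self):
        # Record the starting instant.
        self._started = time.time()

    def stop(self):
        # Return the elapsed wall-clock time in seconds.
        elapsed = time.time() - self._started
        self._started = None
        return elapsed

timer = StopwatchTimer()
timer.start()
# ... remove cells, check uniqueness, backtrack ...
elapsed = timer.stop()
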
Example #20
0
class HPETrainBaseRun(TrainBaseRun):
    def setup(self):
        super().setup()

        self.img_size = self.options.hpe.img_size
        self.speed_diagnose = self.options.general.speed_diagnose

        self.model = self.make_model()
        self.heatmap_max = 1

        self.last_results = None
        self.timer = Timer()

    @abstractmethod
    def make_model(self):
        pass

    def iterate(self, data):
        if self.speed_diagnose:
            self.timer.start('preprocess')

        data = self.arrange_data(data)

        if self.speed_diagnose:
            self.timer.stop('preprocess')
            self.timer.start('setting input')

        self.model.set_input(data)

        if self.speed_diagnose:
            self.timer.stop('setting input')
            self.timer.start('optimize')

        self.model.optimize()
        if self.speed_diagnose:
            self.timer.stop('optimize')
            self.timer.print_elapsed_times()

        self.avg_dict.add(self.model.get_current_losses())

        # save the result for visualization
        self.last_results = self.model.get_detached_current_results()
        self.last_data = data

    def save_checkpoint(self, epoch):
        checkpoint = self.model.pack_as_checkpoint()
        self.logger.save_checkpoint(checkpoint, epoch)

    def end_epoch(self):
        pass

    @abstractmethod
    def arrange_data(self, data):
        """ reshape the data for the model. """

    def _visualize_results_as_image(self, results, cur_iter):

        if results is None:
            return

        results = self._select_first_in_batch(results)
        img = results['img']
        joint_out, heatmap_out, heatmap_true, heatmap_reprojected = hpe_util.unpack_data(
            results)

        out_heatmap_img = convert_to_colormap(heatmap_out, 1.0)
        true_heatmap_img = convert_to_colormap(heatmap_true, 1.0)
        reprojected_heatmap_img = convert_to_colormap(heatmap_reprojected, 1.0)
        img = expand_channel(img)

        stacked_img = torch.cat(
            (img, out_heatmap_img, reprojected_heatmap_img, true_heatmap_img),
            3)  # horizontal_stack
        self.visualizer.add_image('train sample', stacked_img, cur_iter)

    def _visualize_network_grad(self, epoch, current_iter):
        grads = self.model.get_grads()
        for tag, val in grads.items():
            self.visualizer.add_histogram(tag, val, epoch)
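
HPETrainBaseRun.iterate times each phase by name (timer.start('preprocess'), timer.stop('preprocess'), and so on) and then reports everything with timer.print_elapsed_times(). A minimal named-section timer matching that call pattern is sketched below; it is an assumption made for illustration, not the project's Timer class.

from collections import OrderedDict
import time

class SectionTimer(object):
    """Minimal named-section timer matching the calls above (hypothetical)."""

    def __init__(self):
        self._starts = {}
        self._elapsed = OrderedDict()

    def start(self, name):
        # Mark the start of a named section.
        self._starts[name] = time.time()

    def stop(self, name):
        # Record the elapsed time for the named section.
        self._elapsed[name] = time.time() - self._starts.pop(name)

    def print_elapsed_times(self):
        for name, seconds in self._elapsed.items():
            print('%s: %.4f s' % (name, seconds))
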
Example #21
0
File: executor.py Project: twareproj/tware
class Executor(Process):
    def __init__(self, catalog, results, task):
        Process.__init__(self)
        self.catalog = catalog
        self.results = results
        self.task = task
        self.timer = Timer()

    def get_result(self, uuid):
        # Poll the shared results mapping until the task reports completion.
        result = self.results[uuid]
        while result.complete == 0.0:
            time.sleep(0.0005)
            result = self.results[uuid]
        return result

    def wait(self, uuid):
        # Block until the result for `uuid` is marked complete.
        while self.results[uuid].complete == 0.0:
            time.sleep(0.0005)

    def run(self):
        self.timer.start()
        try:
            if isinstance(self.task, ClassifyTask):
                self.classify()
            elif isinstance(self.task, CorrelateTask):
                self.correlate()
            elif isinstance(self.task, DifferenceTask):
                self.difference()
            elif isinstance(self.task, FeatureSelectTask):
                self.feature_select()
            elif isinstance(self.task, FrequentItemsetsTask):
                self.frequent_itemsets()
            elif isinstance(self.task, IntersectTask):
                self.intersect()
            elif isinstance(self.task, LoadTask):
                self.load()
            elif isinstance(self.task, MergeTask):
                self.merge()
            elif isinstance(self.task, ProjectTask):
                self.project()
            elif isinstance(self.task, SelectTask):
                self.select()
            elif isinstance(self.task, UnionTask):
                self.union()
            else:
                raise NotImplementedError()
        except Exception as e:
            print str(e)
            result = ErrorResult(self.task, str(e))
            self.results[self.task.uuid] = result
        self.timer.stop()
        print 'task' + str(self.task.uuid) + ': ' + str(self.timer.time()) + 's'

    def classify(self):
        raise NotImplementedError()

    def correlate(self):
        raise NotImplementedError()

    def difference(self):
        raise NotImplementedError()

    def feature_select(self):
        raise NotImplementedError()

    def frequent_itemsets(self):
        raise NotImplementedError()

    def intersect(self):
        raise NotImplementedError()

    def load(self):
        raise NotImplementedError()

    def merge(self):
        raise NotImplementedError()

    def project(self):
        raise NotImplementedError()

    def select(self):
        raise NotImplementedError()

    def union(self):
        raise NotImplementedError()
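
Executor.get_result and Executor.wait poll results[uuid].complete in a tight loop, so results has to be a mapping that is visible across process boundaries; a multiprocessing.Manager dict is one way to provide that. The sketch below reproduces the polling pattern with placeholder dictionaries standing in for the Result objects; it is an illustration, not the twareproj/tware API.

import time
from multiprocessing import Manager, Process

def worker(results, uuid):
    # Pretend to do some work, then publish a completed result.
    time.sleep(0.5)
    results[uuid] = {'complete': 1.0, 'value': 42}

if __name__ == '__main__':
    manager = Manager()
    results = manager.dict()  # shared between processes
    uuid = 'task-1'
    results[uuid] = {'complete': 0.0}  # seed an incomplete entry

    p = Process(target=worker, args=(results, uuid))
    p.start()

    # Same busy-wait used by Executor.get_result / Executor.wait above.
    while results[uuid]['complete'] == 0.0:
        time.sleep(0.0005)

    p.join()
    print(results[uuid]['value'])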