Example #1
    def __enter_class_context_managers(self, fixture_methods, callback):
        """Transform each fixture_method into a context manager with contextlib.contextmanager, enter them recursively, and call callback"""
        if fixture_methods:
            fixture_method = fixture_methods[0]
            ctm = contextmanager(fixture_method)()

            enter_result = TestResult(fixture_method)
            enter_result.start()
            self.fire_event(self.EVENT_ON_RUN_CLASS_SETUP_METHOD, enter_result)
            if self.__execute_block_recording_exceptions(ctm.__enter__, enter_result, is_class_level=True):
                enter_result.end_in_success()
            self.fire_event(self.EVENT_ON_COMPLETE_CLASS_SETUP_METHOD, enter_result)

            self.__enter_class_context_managers(fixture_methods[1:], callback)

            exit_result = TestResult(fixture_method)
            exit_result.start()
            self.fire_event(self.EVENT_ON_RUN_CLASS_TEARDOWN_METHOD, exit_result)
            if self.__execute_block_recording_exceptions(
                lambda: ctm.__exit__(None, None, None), exit_result, is_class_level=True
            ):
                exit_result.end_in_success()
            self.fire_event(self.EVENT_ON_COMPLETE_CLASS_TEARDOWN_METHOD, exit_result)
        else:
            callback()
Example #2
    def test_request_context():
        with contextmanager(appfactory)() as app:
            templates = []

            def capture(sender, template, context):
                templates.append((template, context))

            @jinja_rendered.connect_via(app)
            def signal_jinja(sender, template, context):
                template_rendered.send(None, template=template.name,
                                       context=context)

            try:
                from flaskext.genshi import template_generated
            except ImportError:
                pass
            else:
                @template_generated.connect_via(app)
                def signal_genshi(sender, template, context):
                    template_rendered.send(None, template=template.filename,
                                           context=context)

            with app_context(app) as client:
                with template_rendered.connected_to(capture):
                    yield client, templates
Example #3
 def __enter_context_managers(self, fixture_methods, callback):
     """Transform each fixture_method into a context manager with contextlib.contextmanager, enter them recursively, and call callback"""
     if fixture_methods:
         with contextmanager(fixture_methods[0])():
             self.__enter_context_managers(fixture_methods[1:], callback)
     else:
         callback()
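
A minimal sketch of how this recursive helper behaves, restated with free functions (the fixture names and driver call are illustrative, not from the original project):

from contextlib import contextmanager

def enter_context_managers(fixture_methods, callback):
    # Enter the first fixture, recurse into the rest, run the callback innermost.
    if fixture_methods:
        with contextmanager(fixture_methods[0])():
            enter_context_managers(fixture_methods[1:], callback)
    else:
        callback()

def setup_db():
    print("db up")
    yield                 # everything deeper in the onion runs while suspended here
    print("db down")

def setup_cache():
    print("cache up")
    yield
    print("cache down")

enter_context_managers([setup_db, setup_cache], lambda: print("test body"))
# -> db up, cache up, test body, cache down, db down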
Example #4
def get_evaluation_context_getter():
    if K.backend() == 'tensorflow':
        import tensorflow as tf
        return tf.get_default_graph().as_default

    if K.backend() == 'theano':
        return contextmanager(lambda: (yield))
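
Either branch returns a zero-argument callable that produces a context manager, so callers invoke it twice; a usage sketch (the body is a placeholder, not part of the snippet):

get_context = get_evaluation_context_getter()
with get_context():
    # evaluation code goes here; on TensorFlow this runs inside the default
    # graph's context, on Theano inside a do-nothing context manager
    pass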
Example #5
def control_handler(f):
    def g(*args):
        with manage('allocator', 'var_manager', 'code_maker'):
            yield
            output = get_context().code_maker.output
        f(output, *args)
    return contextmanager(g)
Example #6
 def _context(self, *contexts, **kwargs):
     with ExitStack() as stack:
         res = []
         for ctxname in contexts:
             ctx = contextmanager(getattr(self, '_context_%s' % ctxname))
             res.append(stack.enter_context(ctx(**kwargs)))
         yield res if len(res) > 1 else res[0]
Example #7
def assert_redirects():
    manager = contextmanager(TestContext())
    with manager() as client:
        response = client.get("/bouncer/")

    # First, Django's API of passing the URL in by itself should work.
    redirects(response, "https://example.com:1234/foo/?a=b#bar")

    valids = [
        {"port": 1234},
        {"scheme": "https"},
        {"domain": "example.com"},
        {"query": "a=b"},
        {"path": "/foo/"},
        {"fragment": "bar"},
        {"url": "https://example.com:1234/foo/?a=b#bar"},
    ]

    for valid in valids:
        redirects(response, **valid)

    invalids = [
        {"port": 5678},
        {"scheme": "http"},
        {"domain": "example.net"},
        {"query": "c=d"},
        {"path": "/baz/"},
        {"fragment": "ben"},
        {"url": "http://example.net:5678/baz/?c=d#ben"},
    ]

    for invalid in invalids:
        with attest.raises(AssertionError):
            redirects(response, **invalid)
Example #8
def g_decorator(generator_expr):
    """
    Converts a generator function into a decorator

    Takes in a generator function, such as one accepted by
    contextlib.contextmanager, converts it to a context manager,
    and returns a decorator equivalent to being within that
    context manager.

    TODO do something with yielded value

    Example:

    @g_decorator
    def foo():
        print("Hello")
        yield
        print("World")

    @foo()
    def bar():
        print("Something")
    """
    cm = contextmanager(generator_expr)

    @to_decorator
    def wrapped_func(func, args, kwargs, *outer_args, **outer_kwargs):
        with cm(*outer_args, **outer_kwargs) as yielded:
            return func(*args, **kwargs)

    return wrapped_func
Example #9
def test_context_supports_fixtures():
    with attest.raises(Thing.DoesNotExist):
        Thing.objects.get(name="loaded from fixture")

    manager = contextmanager(TestContext(fixtures=["tests"]))
    with manager():
        Thing.objects.get(name="loaded from fixture")
Example #10
def view(f):
    """
    Decorate a view function.
    The decorated function has to receive a Page object
    as its first parameter.
    The resulting function will receive all parameters *except*
    the Page object, and will always return the rendered Page.

    A generator function can be decorated for a different purpose:
    this results in a custom element, registered on all Page objects,
    that can be used with `with` as normal.
    This usage returns None.

    Example usage:

    @view
    def main(p, title):
        with p.div(klass='container'):
            with p.h1():
                p.text(title)

    >>> main('Foo')
    '<div class="container"><h1>Foo</h1></div>'
    """
    if is_gen(f):
        Page.register(ctx.contextmanager(f))
        return

    @wraps(f)
    def _inner(*args, **kwargs):
        p = Page()
        f(p, *args, **kwargs)
        return str(p)
    return _inner
Example #11
    def test_contextmanager(self):
        """Test unwrapping a context manager."""

        self.assertEqual(
            unwrap_function(contextmanager(my_function)),
            my_function
        )
Example #12
    def context(self, func):
        """Decorate a function to act as a context.

        :param function func: the function that describes the context
        """
        func = contextmanager(func)
        self._contexts.append(func)
Example #13
def make_connect_cmd(con):
    def go(opts=None):
        mc = con(opts)
        try:
            yield mc
        finally:
            mc.disconnect_all()
    return contextmanager(go)
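
The result is a factory for context managers; a usage sketch with a dummy connection class standing in for `con` (the class is invented for illustration):

class FakeConnection:
    def __init__(self, opts=None):
        self.opts = opts
    def disconnect_all(self):
        print("disconnected")

connect = make_connect_cmd(FakeConnection)
with connect({'host': 'localhost'}) as mc:
    print("connected with", mc.opts)
# mc.disconnect_all() has already run here, even if the body raised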
Example #14
def bn_impossible():
    criterion = bn_impossible_criterion()
    image = bn_image()
    label = bn_label()

    cm_model = contextmanager(bn_model)
    with cm_model() as model:
        yield Adversarial(model, criterion, image, label)
Example #15
def gl_bn_adversarial():
    criterion = bn_criterion()
    image = bn_image()
    label = bn_label()

    cm_model = contextmanager(gl_bn_model)
    with cm_model() as model:
        yield Adversarial(model, criterion, image, label)
Example #16
def test_context_does_transaction_rollback():
    manager = contextmanager(TestContext())
    with manager():
        # create a thing, but it should be rolled back when the context exits
        thing = Thing.objects.create(name="foo")

    with attest.raises(Thing.DoesNotExist):
        Thing.objects.get(pk=thing.pk)
Example #17
def gl_bn_model():
    """Same as bn_model but without gradient.

    """
    cm_model = contextmanager(bn_model)
    with cm_model() as model:
        model = GradientLess(model)
        yield model
Example #18
def main():
    parser = argparse.ArgumentParser()
    parser = argparse.ArgumentParser(description="Bibliography database manipulation")

    parser.add_argument("--debug", action="store_true")

    parser.add_argument("--logging-level", "-L",
                        help="Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG",
                        metavar="LEVEL", type=str, default="WARNING")

    parser.add_argument("--data-dir", help="Path to articles directory", type=str, default=None)

    subparsers = parser.add_subparsers(title='Commands', dest='_commandName')

    for cmdType in Registry.commands:
        cmdType().args(subparsers)

    argcomplete.autocomplete(parser)

    args = parser.parse_args()
    if args.debug:
        msg.setup(level="DEBUG")
        try:
            from ipdb import launch_ipdb_on_exception
        except ModuleNotFoundError:
            from contextlib import contextmanager
            def noop():
                yield
            launch_ipdb_on_exception = contextmanager(noop)
    else:
        msg.setup(level=args.logging_level)

    ddir = Database.getDataDir(dataDir=args.data_dir)
    if ddir:
        RequestCache(os.path.join(ddir, ".cache.pkl"))

    try:
        if hasattr(args, "func"):
            if args.debug:
                with launch_ipdb_on_exception():
                    args.func(args)
            else:
                args.func(args)
        else:
            parser.print_usage()
    except UserException as e:
        msg.error("Error: %s", e)
        sys.exit(1)
    except AbortException:
        msg.error("Aborted")
        sys.exit(1)
    except (WorkExistsException, RepositoryException) as e:
        msg.error(str(e))
        sys.exit(1)
    except:
        t,v,_ = sys.exc_info()
        msg.critical("Unhandled exception: %s(%s)", t.__name__, v)
        sys.exit(1)
Example #19
 def __call__(self, filename, *args, **kwargs):
     print('MockOpen called')
     if filename.startswith(self.test_dir):
         if filename not in self.files:
             print('Mocking %s' % filename)
             self.files[filename] = StringIO.StringIO()
         return contextlib.contextmanager(yields)(self.files[filename])
     else:
         return self.old_open(filename, *args, **kwargs)
Example #20
def reentrentcontext(func):
    context = contextmanager(func)
    entered = []
    def decorated(*a, **k):
        if entered == []:
            entered.append(True) # I miss nonlocal
            return context(*a, **k)
        return noop_context()
    return decorated
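
A behavioural sketch (noop_context is referenced but not defined in the snippet above; for this sketch it is assumed to be equivalent to contextlib.nullcontext):

from contextlib import nullcontext as noop_context  # assumption, see note above

@reentrentcontext
def exclusive():
    print("enter once")
    yield
    print("exit once")

with exclusive():
    with exclusive():       # nested entry returns the no-op context manager
        print("body")
# -> enter once, body, exit once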
Example #21
def bn_adversarial_mae():
    criterion = bn_criterion()
    image = bn_image()
    label = bn_label()
    distance = MAE

    cm_model = contextmanager(bn_model)
    with cm_model() as model:
        yield Adversarial(model, criterion, image, label, distance=distance)
Example #22
  def setUp(self):
    self.a = A()
    self.a.x = 0

    def incr():
      self.a.x += 1
      yield
      self.a.x += 1
    self.incr = contextmanager(incr)
    self.with_lk = with_(self.incr, lambda: None)
Example #23
 def NewCookieJar(self):
   """Makes a context manager that sets up a cookie jar for DoGet/DoPost."""
   def SetCookieJar():
     original_cookie_jar = self.cookie_jar
     self.cookie_jar = cookielib.CookieJar()
     try:
       yield self.cookie_jar
     finally:
       self.cookie_jar = original_cookie_jar
   return contextlib.contextmanager(SetCookieJar)()
Example #24
    def _wrapper(func):
        func.renders_operations = operations

        if any(isinstance(op, ChildlessOperation) for op in operations):
            if not all(isinstance(op, ChildlessOperation) for op in operations):
                raise Exception("Cannot mix ChildlessOperations and normal Operations")

            return func

        return contextlib.contextmanager(func)
Example #25
File: util.py Project: JZQT/tornado
def subTest(test, *args, **kwargs):
    """Compatibility shim for unittest.TestCase.subTest.

    Usage: ``with tornado.test.util.subTest(self, x=x):``
    """
    try:
        subTest = test.subTest  # py34+
    except AttributeError:
        subTest = contextlib.contextmanager(lambda *a, **kw: (yield))
    return subTest(*args, **kwargs)
Example #26
File: utils.py Project: tony/case
def decorator(predicate):
    context = contextmanager(predicate)

    @wraps(predicate)
    def take_arguments(*pargs, **pkwargs):

        @wraps(predicate)
        def decorator(cls):
            if inspect.isclass(cls):
                orig_setup = cls.setUp
                orig_teardown = cls.tearDown

                @wraps(cls.setUp)
                def around_setup(*args, **kwargs):
                    try:
                        contexts = args[0].__rb3dc_contexts__
                    except AttributeError:
                        contexts = args[0].__rb3dc_contexts__ = []
                    p = context(*pargs, **pkwargs)
                    p.__enter__()
                    contexts.append(p)
                    return orig_setup(*args, **kwargs)
                around_setup.__wrapped__ = cls.setUp
                cls.setUp = around_setup

                @wraps(cls.tearDown)
                def around_teardown(*args, **kwargs):
                    try:
                        contexts = args[0].__rb3dc_contexts__
                    except AttributeError:
                        pass
                    else:
                        for context in contexts:
                            context.__exit__(*sys.exc_info())
                    orig_teardown(*args, **kwargs)
                around_teardown.__wrapped__ = cls.tearDown
                cls.tearDown = around_teardown

                return cls
            else:
                @wraps(cls)
                def around_case(self, *args, **kwargs):
                    with context(*pargs, **pkwargs) as context_args:
                        context_args = context_args or ()
                        if not isinstance(context_args, tuple):
                            context_args = (context_args,)
                        return cls(*(self,) + args + context_args, **kwargs)
                return around_case

        if len(pargs) == 1 and callable(pargs[0]):
            fun, pargs = pargs[0], ()
            return decorator(fun)
        return _CallableContext(context, pargs, pkwargs, decorator)
    assert take_arguments.__wrapped__
    return take_arguments
Example #27
def template_rendering_tracking_works():
    manager = contextmanager(TestContext())
    with manager() as client:
        response = client.get("/")
        assert response.content == b"rendered from template.html\n"
        if hasattr(response, "templates"):
            # Django >= 1.3
            assert [t.name for t in response.templates] == ["template.html"]
        else:
            # Django <= 1.2
            assert response.template.name == "template.html"
Example #28
def make_contextmanager(fn):
    if inspect.isgeneratorfunction(fn):
        return contextmanager(fn)

    if fn is None:
        fn = lambda *a, **kw: None

    @contextmanager
    @functools.wraps(fn)
    def wrapper(*a, **kw):
        yield fn(*a, **kw)
    return wrapper
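
A quick usage sketch covering the three branches (function names are illustrative):

def gen_style():
    print("setup")
    yield "resource"
    print("teardown")

def plain_style():
    return 42

with make_contextmanager(gen_style)() as res:
    print(res)          # "resource", bracketed by setup/teardown

with make_contextmanager(plain_style)() as value:
    print(value)        # 42; the wrapper simply yields the return value

with make_contextmanager(None)() as nothing:
    print(nothing)      # None; fn defaults to a do-nothing lambda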
Example #29
 def __call__(self, filename, mode, *args, **kwargs):
     if filename.startswith(self.test_dir):
         if filename not in self.files or mode in ('w', 'w+'):
             self.files[filename] = StringIO.StringIO()
         fakefile = self.files[filename]
         if mode in ('r', 'r+'):
             fakefile.seek(0)
         else:
             fakefile.seek(0, os.SEEK_END)
         return contextlib.contextmanager(yields)(fakefile)
     else:
         return self.old_open(filename, *args, **kwargs)
Example #30
def patch_modules_registry():
    """Backport the Registry.cursor helper."""
    def cursor(self, auto_commit=True):
        cr = self.db.cursor()
        try:
            yield cr
            if auto_commit:
                cr.commit()
        finally:
            cr.close()

    openerp.modules.registry.Registry.cursor = contextmanager(cursor)
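
After the patch, cursor handling written for newer versions works unchanged; a usage sketch (obtaining the registry instance is environment-specific and only hinted at here):

patch_modules_registry()
# registry = <an openerp.modules.registry.Registry instance for some database>
with registry.cursor() as cr:
    cr.execute("SELECT 1")
# cr.commit() ran automatically on clean exit (auto_commit=True); cr.close() always runs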
Example #31
def test_collectstats():
    """
    Some extra tests for `collectstats`.
    """
    # Test for getting extra data even if call raised an exception.
    class RaiseException(object):
        @staticmethod
        def __extra_stats__():
            return {"hello", "world!"}

        @collectstats(collect_extra=True)
        def doit(self):
            raise NotImplementedError("And never will be!")

    recorder = ApiCallsRecorder(suspend_save=True)
    with recorder:
        try:
            RaiseException().doit()
            assert False, "Call to RaiseException().doit() should have raised an exception"
        except NotImplementedError:
            pass

    assert recorder.api_stats[0].extra_stats == {"hello", "world!"}

    # Test for objects with `__getattr__`.
    class IHazGetAttr(object):
        def __getattr__(self, item):
            return self

        @collectstats(collect_extra=True)
        def bar(self):
            return self.__name__

    recorder = ApiCallsRecorder(suspend_save=True)
    with recorder:
        IHazGetAttr().bar()

    assert recorder.api_stats[0].extra_stats == "TypeError(\"'IHazGetAttr' object is not callable\")"

    # Test that there are no new lines in the json dumps of saved ApiCalls. Newlines will mess up reading the file.
    class LotsOfNewLines(object):
        def __extra_stats__(self):
            return "\nHello,\nWorld\n!\n"

        @collectstats("\nClass\nName\n", collect_extra=True)
        def bar(self, a, b):
            return "\nYes\nHello\n" + a + b

    import StringIO
    import mock

    stringio = StringIO.StringIO()  # ContextStringIO()

    with mock.patch.object(ApiCallsRecorder, "_get_save_file",
                           new=staticmethod(contextmanager(lambda: (yield stringio)))):
        LotsOfNewLines().bar("\nA\n", b="\nB\n")

    # TODO: separate writing logic to a different class so that we don't test implementation details here.
    assert all('\n' not in buf for buf in stringio.buflist[:-1])
    assert '\n' not in stringio.buflist[-1][:-1]
    assert stringio.buflist[-1][-1] == '\n'
Example #32
 def __init__(self):
     self.count = 0
     self.func_cm = contextmanager(func)
     self._lock = threading.RLock()
Example #33
    test_cluster.ddl_check_query(
        instance,
        "CREATE TABLE test_atomic.rvcmt ON CLUSTER cluster (n UInt64, m Int8, k UInt64) ENGINE=ReplicatedVersionedCollapsingMergeTree(m, k) ORDER BY n"
    )
    test_cluster.ddl_check_query(
        instance, "DROP DATABASE test_atomic ON CLUSTER cluster")

    test_cluster.ddl_check_query(
        instance,
        "CREATE DATABASE test_ordinary ON CLUSTER cluster ENGINE=Ordinary")
    assert "are supported only for ON CLUSTER queries with Atomic database engine" in \
           instance.query_and_get_error("CREATE TABLE test_ordinary.rmt ON CLUSTER cluster (n UInt64, s String) ENGINE=ReplicatedMergeTree ORDER BY n")
    assert "are supported only for ON CLUSTER queries with Atomic database engine" in \
           instance.query_and_get_error("CREATE TABLE test_ordinary.rmt ON CLUSTER cluster (n UInt64, s String) ENGINE=ReplicatedMergeTree('/{shard}/{uuid}/', '{replica}') ORDER BY n")
    test_cluster.ddl_check_query(
        instance,
        "CREATE TABLE test_ordinary.rmt ON CLUSTER cluster (n UInt64, s String) ENGINE=ReplicatedMergeTree('/{shard}/{table}/', '{replica}') ORDER BY n"
    )
    assert instance.query("SHOW CREATE test_ordinary.rmt FORMAT TSVRaw") == \
           "CREATE TABLE test_ordinary.rmt\n(\n    `n` UInt64,\n    `s` String\n)\nENGINE = ReplicatedMergeTree('/{shard}/rmt/', '{replica}')\nORDER BY n\nSETTINGS index_granularity = 8192\n"
    test_cluster.ddl_check_query(
        instance, "DROP DATABASE test_ordinary ON CLUSTER cluster")
    test_cluster.pm_random_drops.push_rules(rules)


if __name__ == '__main__':
    with contextmanager(test_cluster)() as ctx_cluster:
        for name, instance in list(ctx_cluster.instances.items()):
            print(name, instance.ip_address)
        input("Cluster created, press any key to destroy...")
Example #34
def binarized2_bn_model_fixture():
    cm_model = contextmanager(binarized2_bn_model)
    with cm_model() as model:
        yield model
Example #35
def main():
    parser = argparse.ArgumentParser()
    parser = argparse.ArgumentParser(
        description="Bibliography database manipulation")

    parser.add_argument("--debug", action="store_true")

    parser.add_argument(
        "--logging-level",
        "-L",
        help="Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG",
        metavar="LEVEL",
        type=str,
        default="WARNING")

    parser.add_argument("--data-dir",
                        help="Path to articles directory",
                        type=str,
                        default=None)

    subparsers = parser.add_subparsers(title='Commands', dest='_commandName')

    for cmdType in Registry.commands:
        cmdType().args(subparsers)

    argcomplete.autocomplete(parser)

    args = parser.parse_args()
    if args.debug:
        msg.setup(level="DEBUG")
        try:
            from ipdb import launch_ipdb_on_exception
        except ModuleNotFoundError:
            from contextlib import contextmanager

            def noop():
                yield

            launch_ipdb_on_exception = contextmanager(noop)
    else:
        msg.setup(level=args.logging_level)

    ddir = Database.getDataDir(dataDir=args.data_dir)
    if ddir:
        RequestCache(os.path.join(ddir, ".cache.pkl"))

    try:
        if hasattr(args, "func"):
            if args.debug:
                with launch_ipdb_on_exception():
                    args.func(args)
            else:
                args.func(args)
        else:
            parser.print_usage()
    except UserException as e:
        msg.error("Error: %s", e)
        sys.exit(1)
    except AbortException:
        msg.error("Aborted")
        sys.exit(1)
    except (WorkExistsException, RepositoryException) as e:
        msg.error(str(e))
        sys.exit(1)
    except:
        t, v, _ = sys.exc_info()
        msg.critical("Unhandled exception: %s(%s)", t.__name__, v)
        sys.exit(1)
Example #36
def create_slack_finding_sender(args, db_file):
    if not args.cache_only and args.slack_webhook:
        return SlackFindingSender(args.slack_webhook, db_file)
    else:
        return contextmanager(lambda: iter([None]))()
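
Both branches yield something usable in a `with` statement, so the call site stays uniform; a usage sketch (it assumes SlackFindingSender is itself a context manager, as the null-object branch implies, and the method name is hypothetical):

with create_slack_finding_sender(args, db_file) as sender:
    if sender is not None:        # None when cache-only or no webhook configured
        sender.send(finding)      # hypothetical method, for illustration only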
Example #37
import contextlib
import logging

import numpy as np

import nengo.utils.numpy as npext
from nengo.builder import Builder
from nengo.connection import Connection
from nengo.ensemble import Ensemble
from nengo.network import Network
from nengo.node import Node
from nengo.probe import Probe
from nengo.utils.progress import Progress

logger = logging.getLogger(__name__)
nullcontext = contextlib.contextmanager(lambda: (yield))


@Builder.register(Network)  # noqa: C901
def build_network(model, network, progress=None):
    """Builds a `.Network` object into a model.

    The network builder does this by mapping each high-level object to its
    associated signals and operators one-by-one, in the following order:

    1. Ensembles, nodes, neurons
    2. Subnetworks (recursively)
    3. Connections, learning rules
    4. Probes

    Before calling any of the individual objects' build functions, random
Example #38
from xmodule.modulestore import ModuleStoreEnum
from xmodule.modulestore.django import modulestore
from xmodule.modulestore.tests.django_utils import ModuleStoreTestCase
from xmodule.modulestore.tests.factories import CourseFactory, ItemFactory, check_mongo_calls

from openedx.core.djangolib.testing.utils import skip_unless_lms
from common.djangoapps.student.tests.factories import AdminFactory, UserFactory

from .. import DEFAULT_FIELDS, OPTIONAL_FIELDS, PathItem
from ..models import Bookmark, XBlockCache, parse_path_data
from .factories import BookmarkFactory

EXAMPLE_USAGE_KEY_1 = 'i4x://org.15/course_15/chapter/Week_1'
EXAMPLE_USAGE_KEY_2 = 'i4x://org.15/course_15/chapter/Week_2'

noop_contextmanager = contextmanager(lambda x: (yield))  # pylint: disable=invalid-name


class BookmarksTestsBase(ModuleStoreTestCase):
    """
    Test the Bookmark model.
    """
    ALL_FIELDS = DEFAULT_FIELDS + OPTIONAL_FIELDS
    STORE_TYPE = ModuleStoreEnum.Type.split
    TEST_PASSWORD = '******'

    ENABLED_CACHES = ['default', 'mongo_metadata_inheritance', 'loc_cache']

    def setUp(self):
        super().setUp()
Example #39
def documented_contextmanager(func):
    wrapper = contextmanager(func)
    wrapper.undecorated = func
    return wrapper
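
A usage sketch showing why the undecorated generator function is kept around (names are illustrative):

@documented_contextmanager
def open_section(title):
    print("<%s>" % title)
    yield
    print("</%s>" % title)

with open_section("intro"):
    print("body")

# Documentation/introspection tools can still reach the original generator:
print(open_section.undecorated)   # <function open_section at 0x...>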
Example #40
 def __init__(self, func):
     self.count = 0
     self.func_cm = contextmanager(func)
     self._lock = RLock()
Example #41
 def eg_bn_model():
     cm_model = contextmanager(bn_model)
     with cm_model() as model:
         gradient_estimator = GradientEstimator(epsilon=0.01)
         model = ModelWithEstimatedGradients(model, gradient_estimator)
         yield model
Example #42
def run(dataset: Dataset, config: TaskConfig):
    log.info("\n**** H2O AutoML ****\n")
    # Mapping of benchmark metrics to H2O metrics
    metrics_mapping = dict(acc='mean_per_class_error',
                           auc='AUC',
                           logloss='logloss',
                           mae='mae',
                           mse='mse',
                           r2='r2',
                           rmse='rmse',
                           rmsle='rmsle')
    sort_metric = metrics_mapping[
        config.metric] if config.metric in metrics_mapping else None
    if sort_metric is None:
        # TODO: Figure out if we are going to blindly pass metrics through, or if we use a strict mapping
        log.warning("Performance metric %s not supported, defaulting to AUTO.",
                    config.metric)

    try:
        training_params = {
            k: v
            for k, v in config.framework_params.items()
            if not k.startswith('_')
        }
        nthreads = config.framework_params.get('_nthreads', config.cores)
        jvm_memory = str(
            round(config.max_mem_size_mb * 2 /
                  3)) + "M"  # leaving 1/3rd of available memory for XGBoost

        log.info("Starting H2O cluster with %s cores, %s memory.", nthreads,
                 jvm_memory)
        max_port_range = 49151
        min_port_range = 1024
        rnd_port = os.getpid() % (max_port_range -
                                  min_port_range) + min_port_range
        port = config.framework_params.get('_port', rnd_port)

        h2o.init(
            nthreads=nthreads,
            port=port,
            min_mem_size=jvm_memory,
            max_mem_size=jvm_memory,
            strict_version_check=config.framework_params.get(
                '_strict_version_check', True)
            # log_dir=os.path.join(config.output_dir, 'logs', config.name, str(config.fold))
        )

        # Load train as an H2O Frame, but test as a Pandas DataFrame
        log.debug("Loading train data from %s.", dataset.train.path)
        train = h2o.import_file(dataset.train.path,
                                destination_frame=frame_name('train', config))
        # train.impute(method='mean')
        log.debug("Loading test data from %s.", dataset.test.path)
        test = h2o.import_file(dataset.test.path,
                               destination_frame=frame_name('test', config))
        # test.impute(method='mean')

        log.info("Running model on task %s, fold %s.", config.name,
                 config.fold)
        log.debug(
            "Running H2O AutoML with a maximum time of %ss on %s core(s), optimizing %s.",
            config.max_runtime_seconds, config.cores, sort_metric)

        aml = H2OAutoML(
            max_runtime_secs=config.max_runtime_seconds,
            max_runtime_secs_per_model=round(
                config.max_runtime_seconds /
                2),  # to prevent timeout on ensembles
            sort_metric=sort_metric,
            seed=config.seed,
            **training_params)

        monitor = (
            BackendMemoryMonitoring(
                frequency_seconds=rconfig().monitoring.frequency_seconds,
                check_on_exit=True,
                verbosity=rconfig().monitoring.verbosity)
            if config.framework_params.get('_monitor_backend', False)
            # else contextlib.nullcontext  # Py 3.7+ only
            else contextlib.contextmanager(iter)([0]))
        with Timer() as training:
            with monitor:
                aml.train(y=dataset.target.index, training_frame=train)

        if not aml.leader:
            raise NoResultError(
                "H2O could not produce any model in the requested time.")

        save_predictions(aml, test, dataset=dataset, config=config)
        save_artifacts(aml, dataset=dataset, config=config)

        return dict(models_count=len(aml.leaderboard),
                    training_duration=training.duration)

    finally:
        if h2o.connection():
            # h2o.remove_all()
            h2o.connection().close()
        if h2o.connection().local_server:
            h2o.connection().local_server.shutdown()
Example #43
def get_data_from_url(url: str):
    return request.urlopen(url)


def _change_directory(destination_directory):
    cwd = os.getcwd()
    os.chdir(destination_directory)
    try:
        yield
    finally:
        os.chdir(cwd)


cd = contextlib.contextmanager(_change_directory)


def get_indra_statements_from_directory(directory: str) -> Iterable[Influence]:
    """ Returns a list of INDRA statements from a directory containing JSON-LD
    output from Eidos. """
    return chain.from_iterable(
        map(
            lambda ep: ep.statements,
            map(eidos.process_json_ld_file, tqdm(glob(directory))),
        ))


def grouper(iterable, n, fillvalue=None):
    "Collect data into fixed-length chunks or blocks"
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
Example #44
class WMConnectionBase(object):
    """
    Generic db connection and transaction methods used by all of the WMCore classes.
    """
    def __init__(self, daoPackage, logger=None, dbi=None):
        """
        ___init___

        Initialize all the database connection attributes and the logging
        attributes.  Create a DAO factory for the given daoPackage as well. Finally,
        check to see if a transaction object has been created.  If none exists,
        create one but leave the transaction closed.
        """
        myThread = threading.currentThread()
        if logger:
            self.logger = logger
        else:
            self.logger = myThread.logger
        if dbi:
            self.dbi = dbi
        else:
            self.dbi = myThread.dbi

        self.daofactory = DAOFactory(package=daoPackage,
                                     logger=self.logger,
                                     dbinterface=self.dbi)

        if "transaction" not in dir(myThread):
            myThread.transaction = Transaction(self.dbi)

        return

    def getDBConn(self):
        """
        _getDBConn_

        Retrieve the database connection that is associated with the current
        database transaction.
        If a transaction exists, this will return the connection
        that the transaction belongs to.
        This won't create a transaction if one doesn't exist; it will just return
        None.
        """
        myThread = threading.currentThread()

        if "transaction" not in dir(myThread):
            return None

        return myThread.transaction.conn

    def beginTransaction(self):
        """
        _beginTransaction_

        Begin a database transaction if one does not already exist.
        """
        myThread = threading.currentThread()

        if "transaction" not in dir(myThread):
            myThread.transaction = Transaction(self.dbi)
            return False

        if myThread.transaction.transaction == None:
            myThread.transaction.begin()
            return False

        return True

    def existingTransaction(self):
        """
        _existingTransaction_

        Return True if there is an open transaction, False otherwise.
        """
        myThread = threading.currentThread()

        if "transaction" not in dir(myThread):
            return False
        elif myThread.transaction.transaction != None:
            return True

        return False

    def commitTransaction(self, existingTransaction):
        """
        _commitTransaction_

        Commit a database transaction that was begun by self.beginTransaction().
        """
        if not existingTransaction:
            myThread = threading.currentThread()
            myThread.transaction.commit()

        return

    def __getstate__(self):
        """
        __getstate__

        The database connection information isn't pickleable, so we need to kill that
        before we attempt to pickle.
        """
        self.dbi = None
        self.logger = None
        self.daofactory = None
        return self.__dict__

    def transactionContext(self):
        """
        Returns a transaction as a ContextManager

        Usage:
            with transactionContext():
                databaseCode1()
                databaseCode2()

        Equates to beginTransaction() followed by either
        commitTransaction or a rollback
        """
        existingTransaction = self.beginTransaction()
        try:
            yield existingTransaction
        except:
            # responsibility for rolling back is on the transaction starter
            if not existingTransaction:
                self.logger.error('Exception caught, rolling back transaction')
                threading.currentThread().transaction.rollback()
            raise
        else:
            # only commits if transaction started by this invocation
            self.commitTransaction(existingTransaction)

    try:
        transactionContext = contextmanager(transactionContext)
    except NameError:
        pass

    def rollbackTransaction(self, existingTransaction):
        """Rollback transaction if we started it"""
        if not existingTransaction:
            threading.currentThread().transaction.rollback()
Example #45
    __BUCKET_ERRORS = []

    def threshold_suppress():
        """Suppress errors as long as they stay below a threshold."""
        try:
            yield
        except exceptions, ex:
            ts = int(time.time())
            timebucket = ts - (ts % interval)
            if __BUCKET_ERRORS:
                (lastbucket, errors) = __BUCKET_ERRORS[0]
            else:
                lastbucket = None

            if lastbucket != timebucket:
                errors = defaultdict(int)
                __BUCKET_ERRORS.insert(0, (timebucket, errors))
                while len(__BUCKET_ERRORS) > 1:
                    __BUCKET_ERRORS.pop()

            errors[type(ex)] += 1
            if errors[type(ex)] < threshold:
                stats.incr('%s_suppressed' % type(ex).__name__)
                logging.exception("Suppressing error: %s", ex)
                return
            logging.debug("Too many %s errors, raising", type(ex))
            stats.incr('%s_suppress_failures' % type(ex).__name__)
            raise

    return contextmanager(threshold_suppress)
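
The enclosing factory is not shown above; assuming it looks something like `def make_suppressor(exceptions, threshold, interval, stats, ...):` and returns this context-manager factory, usage would be roughly (all names here are hypothetical):

# suppress_io = make_suppressor((IOError,), threshold=5, interval=60, stats=stats)
# for attempt in range(10):
#     with suppress_io():
#         flaky_network_call()   # IOErrors are swallowed until 5 pile up within
#                                # one 60-second bucket, then they propagate again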
Example #46
def get_install_actions(prefix,
                        specs,
                        env,
                        retries=0,
                        subdir=None,
                        verbose=True,
                        debug=False,
                        locking=True,
                        bldpkgs_dirs=None,
                        timeout=90,
                        disable_pip=False,
                        max_env_retry=3,
                        output_folder=None,
                        channel_urls=None):
    global cached_actions
    global last_index_ts
    actions = {}
    log = utils.get_logger(__name__)
    conda_log_level = logging.WARN
    specs = list(specs)
    if verbose:
        capture = contextlib.contextmanager(lambda: (yield))
    elif debug:
        capture = contextlib.contextmanager(lambda: (yield))
        conda_log_level = logging.DEBUG
    else:
        capture = utils.capture
    for feature, value in feature_list:
        if value:
            specs.append('%s@' % feature)

    bldpkgs_dirs = ensure_list(bldpkgs_dirs)

    index, index_ts = get_build_index(subdir,
                                      list(bldpkgs_dirs)[0],
                                      output_folder=output_folder,
                                      channel_urls=channel_urls,
                                      debug=debug,
                                      verbose=verbose,
                                      locking=locking,
                                      timeout=timeout)
    specs = tuple(utils.ensure_valid_spec(spec) for spec in specs)

    if ((specs, env, subdir, channel_urls, disable_pip) in cached_actions
            and last_index_ts >= index_ts):
        actions = cached_actions[(specs, env, subdir, channel_urls,
                                  disable_pip)].copy()
        if "PREFIX" in actions:
            actions['PREFIX'] = prefix
    elif specs:
        # this is hiding output like:
        #    Fetching package metadata ...........
        #    Solving package specifications: ..........
        with utils.LoggingContext(conda_log_level):
            with capture():
                try:
                    actions = install_actions(prefix, index, specs, force=True)
                except NoPackagesFoundError as exc:
                    raise DependencyNeedsBuildingError(exc, subdir=subdir)
                except (SystemExit, PaddingError, LinkError,
                        DependencyNeedsBuildingError, CondaError,
                        AssertionError) as exc:
                    if 'lock' in str(exc):
                        log.warn(
                            "failed to get install actions, retrying.  exception was: %s",
                            str(exc))
                    elif ('requires a minimum conda version' in str(exc)
                          or 'link a source that does not' in str(exc)
                          or isinstance(exc, AssertionError)):
                        locks = utils.get_conda_operation_locks(
                            locking, bldpkgs_dirs, timeout)
                        with utils.try_acquire_locks(locks, timeout=timeout):
                            pkg_dir = str(exc)
                            folder = 0
                            while os.path.dirname(
                                    pkg_dir) not in pkgs_dirs and folder < 20:
                                pkg_dir = os.path.dirname(pkg_dir)
                                folder += 1
                            log.warn(
                                "I think conda ended up with a partial extraction for %s. "
                                "Removing the folder and retrying", pkg_dir)
                            if pkg_dir in pkgs_dirs and os.path.isdir(pkg_dir):
                                utils.rm_rf(pkg_dir)
                    if retries < max_env_retry:
                        log.warn(
                            "failed to get install actions, retrying.  exception was: %s",
                            str(exc))
                        actions = get_install_actions(
                            prefix,
                            tuple(specs),
                            env,
                            retries=retries + 1,
                            subdir=subdir,
                            verbose=verbose,
                            debug=debug,
                            locking=locking,
                            bldpkgs_dirs=tuple(bldpkgs_dirs),
                            timeout=timeout,
                            disable_pip=disable_pip,
                            max_env_retry=max_env_retry,
                            output_folder=output_folder,
                            channel_urls=tuple(channel_urls))
                    else:
                        log.error(
                            "Failed to get install actions, max retries exceeded."
                        )
                        raise
        if disable_pip:
            for pkg in ('pip', 'setuptools', 'wheel'):
                # specs are the raw specifications, not the conda-derived actual specs
                #   We're testing that pip etc. are manually specified
                if not any(
                        re.match('^%s(?:$| .*)' % pkg, str(dep))
                        for dep in specs):
                    actions['LINK'] = [
                        spec for spec in actions['LINK'] if spec.name != pkg
                    ]
        utils.trim_empty_keys(actions)
        cached_actions[(specs, env, subdir, channel_urls,
                        disable_pip)] = actions.copy()
        last_index_ts = index_ts
    return actions
Example #47
def gl_bn_model_fixture():
    cm_model = contextmanager(gl_bn_model)
    with cm_model() as model:
        yield model
Example #48
 async def fixture(unused_port_factory):
     make_server = contextlib.contextmanager(_server_fixture(server))
     with make_server(unused_port_factory) as port:
         async with channel(f"localhost:{port}") as chan:
             yield stub(chan)
Example #49
    instance = cluster.instances['ch3']

    ddl_check_query(
        instance,
        "DROP TABLE IF EXISTS test_optimize ON CLUSTER cluster FORMAT TSV")
    ddl_check_query(
        instance,
        "CREATE TABLE test_optimize ON CLUSTER cluster (p Date, i Int32) ENGINE = MergeTree(p, p, 8192)"
    )
    ddl_check_query(
        instance, "OPTIMIZE TABLE test_optimize ON CLUSTER cluster FORMAT TSV")


def test_create_as_select(started_cluster):
    instance = cluster.instances['ch2']
    ddl_check_query(
        instance,
        "CREATE TABLE test_as_select ON CLUSTER cluster ENGINE = Memory AS (SELECT 1 AS x UNION ALL SELECT 2 AS x)"
    )
    assert TSV(instance.query(
        "SELECT x FROM test_as_select ORDER BY x")) == TSV("1\n2\n")
    ddl_check_query(instance,
                    "DROP TABLE IF EXISTS test_as_select ON CLUSTER cluster")


if __name__ == '__main__':
    with contextmanager(started_cluster)() as cluster:
        for name, instance in cluster.instances.items():
            print name, instance.ip_address
        raw_input("Cluster created, press any key to destroy...")
Example #50
                'hazards': [5, 9]
            }
        }
        object = sampledb.logic.objects.create_object(
            action_id=instrument_action.id,
            data=data,
            user_id=user.id,
            previous_object_id=None,
            schema=schema)

        os.makedirs('docs/static/img/generated', exist_ok=True)
        options = Options()
        # disable Chrome sandbox for root in GitLab CI
        if 'CI' in os.environ and getpass.getuser() == 'root':
            options.add_argument('--no-sandbox')
        with contextlib.contextmanager(
                tests.conftest.create_flask_server)(app) as flask_server:
            with contextlib.closing(Chrome(options=options)) as driver:
                time.sleep(5)
                object_permissions(flask_server.base_url, driver)
                default_permissions(flask_server.base_url, driver)
                guest_invitation(flask_server.base_url, driver)
                action(flask_server.base_url, driver, instrument_action)
                hazards_input(flask_server.base_url, driver, instrument_action)
                tags_input(flask_server.base_url, driver, object)
                comments(flask_server.base_url, driver, object)
                activity_log(flask_server.base_url, driver, object)
                files(flask_server.base_url, driver, object)
                file_information(flask_server.base_url, driver, object)
                labels(flask_server.base_url, driver, object)
                advanced_search_by_property(flask_server.base_url, driver,
                                            object)
Example #51
def get_build_index(config, subdir, clear_cache=False, omit_defaults=False):
    global local_index_timestamp
    global local_subdir
    global cached_index
    global cached_channels
    log = utils.get_logger(__name__)
    mtime = 0

    if config.output_folder:
        output_folder = config.output_folder
    else:
        output_folder = os.path.dirname(config.bldpkgs_dir)

    # check file modification time - this is the age of our index.
    index_file = os.path.join(output_folder, subdir, 'repodata.json')
    if os.path.isfile(index_file):
        mtime = os.path.getmtime(index_file)

    if (clear_cache or not os.path.isfile(index_file) or local_subdir != subdir
            or mtime > local_index_timestamp
            or cached_channels != config.channel_urls):
        log.debug(
            "Building new index for subdir '{}' with channels {}, condarc channels "
            "= {}".format(subdir, config.channel_urls, not omit_defaults))
        # priority: local by croot (can vary), then channels passed as args,
        #     then channels from config.
        capture = contextlib.contextmanager(lambda: (yield))
        if config.debug:
            log_context = partial(utils.LoggingContext, logging.DEBUG)
        elif config.verbose:
            log_context = partial(utils.LoggingContext, logging.INFO)
        else:
            log_context = partial(utils.LoggingContext, logging.CRITICAL + 1)
            capture = utils.capture

        urls = list(config.channel_urls)
        if os.path.isdir(output_folder):
            urls.insert(0, url_path(output_folder))
        ensure_valid_channel(output_folder, subdir, config)

        # silence output from conda about fetching index files
        with log_context():
            with capture():
                # replace noarch with native subdir - this ends up building an index with both the
                #      native content and the noarch content.
                if subdir == 'noarch':
                    subdir = conda_interface.subdir
                try:
                    cached_index = get_index(channel_urls=urls,
                                             prepend=not omit_defaults,
                                             use_local=False,
                                             use_cache=False,
                                             platform=subdir)
                # HACK: defaults does not have the many subfolders we support.  Omit it and
                #          try again.
                except CondaHTTPError:
                    if 'defaults' in urls:
                        urls.remove('defaults')
                    cached_index = get_index(channel_urls=urls,
                                             prepend=omit_defaults,
                                             use_local=False,
                                             use_cache=False,
                                             platform=subdir)
        local_index_timestamp = mtime
        local_subdir = subdir
        cached_channels = config.channel_urls
    return cached_index, local_index_timestamp
Example #52
 def wrapper(patcher):
     patcher = contextmanager(patcher)
     cls.INTENSIVE_CALLS_PATCHER[metric_name] = patcher
     return patcher
Example #53
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, declarative_base

import settings


engine = create_engine(
    'postgresql://{}:{}@{}:{}/{}'.format(
        settings.global_settings.db_user,
        settings.global_settings.db_password,
        settings.global_settings.db_host,
        settings.global_settings.db_port,
        settings.global_settings.db_name,
    )
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

Base = declarative_base()


def get_session():
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


context_session = contextmanager(get_session)
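
`get_session` is the kind of generator a web framework's dependency system would drive directly, while `context_session` makes the same logic usable from plain code; a usage sketch (`MyModel` is an illustrative mapped class, not part of the snippet):

with context_session() as session:
    rows = session.query(MyModel).all()
# session.close() has run here, even if the query raised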
Example #54
def eg_bn_model(request):
    eg_bn_model = eg_bn_model_factory(request)

    cm_model = contextmanager(eg_bn_model)
    with cm_model() as model:
        yield model
Example #55
        self.cur = cur

    def __enter__(self):
        # on enter create the generator
        self.gen = temptable(self.cur)
        next(self.gen)

    def __exit__(self, *args):
        # on exit go back into method, drop table
        # and then return None
        next(self.gen, None)


with connect('test.db') as conn:
    cur = conn.cursor()
    with contextmanager(cur):
        cur.execute('insert into points (x, y) values (3, 3)')
        cur.execute('insert into points (x, y) values (3, 2)')
        cur.execute('insert into points (x, y) values (3, 1)')
        cur.execute('insert into points (x, y) values (3, 4)')
        for row in cur.execute('select x, y from points'):
            print(row)
        for row in cur.execute('select sum(x * y) from points'):
            print(row)


# a more general solution is to pass the generator
# into the __init__
class contextmanager:
    def __init__(self, gen):
        self.gen = gen
Example #56
    def enter(self, fixtures, setup_callbacks=None, teardown_callbacks=None, stop_setups=False):
        """Transform each fixture_method into a context manager, enter them
        recursively, and yield any failures.

        `stop_setups` is set after a setup fixture fails. This flag prevents
        more setup fixtures from being added to the onion after a failure as we
        recurse through the list of fixtures.
        """

        # base case
        if not fixtures:
            yield []
            return

        setup_callbacks = setup_callbacks or [None, None]
        teardown_callbacks = teardown_callbacks or [None, None]

        fixture = fixtures[0]

        ctm = contextlib.contextmanager(fixture)()

        # class_teardown fixture is wrapped as
        # class_setup_teardown. We should not fire events for the
        # setup phase of this fake context manager.
        suppress_callbacks = bool(fixture._fixture_type in TEARDOWN_FIXTURES)

        # if a previous setup fixture failed, stop running new setup
        # fixtures.  this doesn't apply to teardown fixtures, however,
        # because behind the scenes they're setup_teardowns, and we need
        # to run the (empty) setup portion in order to get the teardown
        # portion later.
        if not stop_setups or fixture._fixture_type in TEARDOWN_FIXTURES:
            enter_failures = self.run_fixture(
                fixture,
                ctm.__enter__,
                enter_callback=None if suppress_callbacks else setup_callbacks[0],
                exit_callback=None if suppress_callbacks else setup_callbacks[1],
            )
            # keep skipping setups once we've had a failure
            stop_setups = stop_setups or bool(enter_failures)
        else:
            # we skipped the setup, pretend like nothing happened.
            enter_failures = []

        with self.enter(fixtures[1:], setup_callbacks, teardown_callbacks, stop_setups) as all_failures:
            all_failures += enter_failures or []
            # need to only yield one failure
            yield all_failures

        # this setup fixture got skipped due to an earlier setup fixture
        # failure, or failed itself. all of these fixtures are basically
        # represented by setup_teardowns, but because we never ran this setup,
        # we have nothing to do for teardown (if we did visit it here, that
        # would have the effect of running the setup we just skipped), so
        # instead bail out and move on to the next fixture on the stack.
        if stop_setups and fixture._fixture_type in SETUP_FIXTURES:
            return

        # class_setup fixture is wrapped as
        # class_setup_teardown. We should not fire events for the
        # teardown phase of this fake context manager.
        suppress_callbacks = bool(fixture._fixture_type in SETUP_FIXTURES)

        # this is hack to finish the remainder of the context manager without
        # calling contextlib's __exit__; doing that messes up the stack trace
        # we end up with.
        def exit():
            try:
                ctm.gen.next()
            except StopIteration:
                pass

        exit_failures = self.run_fixture(
            fixture,
            exit,
            enter_callback=None if suppress_callbacks else teardown_callbacks[0],
            exit_callback=None if suppress_callbacks else teardown_callbacks[1],
        )

        all_failures += exit_failures or []
Example #57
    def daily_stats(self, data=sample_data):
        def fake_fetch():
            yield iter(data)

        self._fetch_mock.side_effect = contextmanager(fake_fetch)
        return readers.jhucsse_reader.daily_stats()
Example #58
def run(dataset, config):
    log.info(f"\n**** H2O AutoML [v{h2o.__version__}] ****\n")
    save_metadata(config, version=h2o.__version__)
    # Mapping of benchmark metrics to H2O metrics
    metrics_mapping = dict(acc='mean_per_class_error',
                           auc='AUC',
                           logloss='logloss',
                           mae='mae',
                           mse='mse',
                           r2='r2',
                           rmse='rmse',
                           rmsle='rmsle')
    sort_metric = metrics_mapping[
        config.metric] if config.metric in metrics_mapping else None
    if sort_metric is None:
        # TODO: Figure out if we are going to blindly pass metrics through, or if we use a strict mapping
        log.warning("Performance metric %s not supported, defaulting to AUTO.",
                    config.metric)

    try:
        training_params = {
            k: v
            for k, v in config.framework_params.items()
            if not k.startswith('_')
        }
        nthreads = config.framework_params.get('_nthreads', config.cores)
        jvm_memory = str(
            round(config.max_mem_size_mb * 2 /
                  3)) + "M"  # leaving 1/3rd of available memory for XGBoost

        log.info("Starting H2O cluster with %s cores, %s memory.", nthreads,
                 jvm_memory)
        max_port_range = 49151
        min_port_range = 1024
        rnd_port = os.getpid() % (max_port_range -
                                  min_port_range) + min_port_range
        port = config.framework_params.get('_port', rnd_port)

        init_params = config.framework_params.get('_init', {})
        if "logs" in config.framework_params.get('_save_artifacts', []):
            init_params['ice_root'] = output_subdir("logs", config)

        h2o.init(nthreads=nthreads,
                 port=port,
                 min_mem_size=jvm_memory,
                 max_mem_size=jvm_memory,
                 **init_params)

        import_kwargs = {}
        # Load train as an H2O Frame, but test as a Pandas DataFrame
        log.debug("Loading train data from %s.", dataset.train.path)
        train = None
        if version.parse(h2o.__version__) >= version.parse(
                "3.32.0.3"
        ):  # previous versions may fail to parse correctly some rare arff files using single quotes as enum/string delimiters (pandas also fails on same datasets)
            import_kwargs['quotechar'] = '"'
            train = h2o.import_file(dataset.train.path,
                                    destination_frame=frame_name(
                                        'train', config),
                                    **import_kwargs)
            if not verify_loaded_frame(train, dataset):
                h2o.remove(train)
                train = None
                import_kwargs['quotechar'] = "'"

        if not train:
            train = h2o.import_file(dataset.train.path,
                                    destination_frame=frame_name(
                                        'train', config),
                                    **import_kwargs)
            # train.impute(method='mean')
        log.debug("Loading test data from %s.", dataset.test.path)
        test = h2o.import_file(dataset.test.path,
                               destination_frame=frame_name('test', config),
                               **import_kwargs)
        # test.impute(method='mean')

        log.info("Running model on task %s, fold %s.", config.name,
                 config.fold)
        log.debug(
            "Running H2O AutoML with a maximum time of %ss on %s core(s), optimizing %s.",
            config.max_runtime_seconds, config.cores, sort_metric)

        aml = H2OAutoML(max_runtime_secs=config.max_runtime_seconds,
                        sort_metric=sort_metric,
                        seed=config.seed,
                        **training_params)

        monitor = (
            BackendMemoryMonitoring(
                frequency_seconds=config.ext.monitoring.frequency_seconds,
                check_on_exit=True,
                verbosity=config.ext.monitoring.verbosity)
            if config.framework_params.get('_monitor_backend', False)
            # else contextlib.nullcontext  # Py 3.7+ only
            else contextlib.contextmanager(iter)([0]))
        with utils.Timer() as training:
            with monitor:
                aml.train(y=dataset.target.index, training_frame=train)

        if not aml.leader:
            raise FrameworkError(
                "H2O could not produce any model in the requested time.")

        with utils.Timer() as predict:
            preds = aml.predict(test)

        preds = extract_preds(preds, test, dataset=dataset)
        save_artifacts(aml, dataset=dataset, config=config)

        return result(output_file=config.output_predictions_file,
                      predictions=preds.predictions,
                      truth=preds.truth,
                      probabilities=preds.probabilities,
                      probabilities_labels=preds.probabilities_labels,
                      models_count=len(aml.leaderboard),
                      training_duration=training.duration,
                      predict_duration=predict.duration)

    finally:
        if h2o.connection():
            # h2o.remove_all()
            h2o.connection().close()
        if h2o.connection() and h2o.connection().local_server:
            h2o.connection().local_server.shutdown()
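
The monitor fallback above works because contextlib.contextmanager(iter)([0]) behaves as a do-nothing context manager; a minimal sketch of that trick, with nullcontext() as the Python 3.7+ equivalent (null_monitor is an illustrative name, not part of the benchmark code):

import contextlib

def null_monitor():
    # Do-nothing context manager: nullcontext() on Python 3.7+, otherwise the
    # same trick as above; iter([0]) supports next() like a generator, so the
    # wrapped manager enters once (yielding 0) and exits cleanly.
    try:
        return contextlib.nullcontext()
    except AttributeError:  # Python < 3.7
        return contextlib.contextmanager(iter)([0])

with null_monitor():
    pass  # the body runs with no monitoring overhead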
Ejemplo n.º 59
0
 def wrapper(*args, **kwargs):
     with contextmanager(disabler)():
         return func(*args, **kwargs)
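
A self-contained sketch of the decorator this fragment appears to belong to; disabler here is a hypothetical generator that toggles a module-level flag around each wrapped call:

import functools
from contextlib import contextmanager

FEATURE_ENABLED = True

def disabler():
    # Set-up / tear-down written as a generator so contextmanager() can turn
    # it into a "with" block around the call.
    global FEATURE_ENABLED
    FEATURE_ENABLED = False
    try:
        yield
    finally:
        FEATURE_ENABLED = True

def feature_disabled(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with contextmanager(disabler)():
            return func(*args, **kwargs)
    return wrapper

@feature_disabled
def report():
    return FEATURE_ENABLED

assert report() is False        # disabled only while the call runs
assert FEATURE_ENABLED is True  # restored afterwards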
Ejemplo n.º 60
0
class CorrectnessTest(keras_parameterized.TestCase):
    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    def test_loss_correctness(self):
        # Test that training loss is the same in eager and graph
        # (by comparing it to a reference value in a deterministic case)
        layers = [
            keras.layers.Dense(3, activation='relu',
                               kernel_initializer='ones'),
            keras.layers.Dense(2,
                               activation='softmax',
                               kernel_initializer='ones')
        ]
        model = testing_utils.get_model_from_layers(layers, input_shape=(4, ))
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=rmsprop.RMSprop(learning_rate=0.001),
                      run_eagerly=testing_utils.should_run_eagerly(),
                      experimental_run_tf_function=testing_utils.should_run_tf_function())
        x = np.ones((100, 4))
        np.random.seed(123)
        y = np.random.randint(0, 1, size=(100, 1))
        history = model.fit(x, y, epochs=1, batch_size=10)
        self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)

    @keras_parameterized.run_with_all_model_types
    @keras_parameterized.run_all_keras_modes
    def test_loss_correctness_with_iterator(self):
        # Test that training loss is the same in eager and graph
        # (by comparing it to a reference value in a deterministic case)
        layers = [
            keras.layers.Dense(3, activation='relu',
                               kernel_initializer='ones'),
            keras.layers.Dense(2,
                               activation='softmax',
                               kernel_initializer='ones')
        ]
        model = testing_utils.get_model_from_layers(layers, input_shape=(4, ))
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=rmsprop.RMSprop(learning_rate=0.001),
                      run_eagerly=testing_utils.should_run_eagerly(),
                      experimental_run_tf_function=testing_utils.should_run_tf_function())
        x = np.ones((100, 4), dtype=np.float32)
        np.random.seed(123)
        y = np.random.randint(0, 1, size=(100, 1))
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
        dataset = dataset.repeat(100)
        dataset = dataset.batch(10)
        history = model.fit(dataset, epochs=1, steps_per_epoch=10)
        self.assertAlmostEqual(history.history['loss'][-1], 0.5836, 4)

    def test_loss_in_call(self):
        class HasLoss(keras.layers.Layer):
            def call(self, x):
                self.add_loss(x)
                return x

        layer = HasLoss()
        layer(1.)  # Plain-value inputs are only valid in eager mode.
        self.assertEqual(1, len(layer.losses))

    @parameterized.named_parameters([
        ('_None', contextlib.contextmanager(lambda: iter([None])), 0., 4.),
        ('_0', lambda: keras.backend.learning_phase_scope(0), 4., 4.),
        ('_1', lambda: keras.backend.learning_phase_scope(1), 0., 0.),
    ])
    def test_nested_model_learning_phase(self, nested_scope_fn,
                                         expected_training_loss,
                                         expected_validation_loss):
        """Tests that learning phase is correctly set in an intermediate layer."""
        def _make_unregularized_model():
            inputs = keras.Input((4, ))
            # Zero out activations when `training=True`.
            x = keras.layers.Dropout(1. - 1. / (1 << 24))(inputs)
            # Just sum together all the activations.
            x = keras.layers.Dense(10,
                                   activation='relu',
                                   trainable=False,
                                   bias_initializer='zeros',
                                   kernel_initializer='ones')(x)
            outputs = keras.layers.Dense(3)(x)
            return keras.Model(inputs, outputs)

        def _regularize_model(unregularized_model):
            inputs = keras.Input(unregularized_model.inputs[0].shape[1:])
            with nested_scope_fn():
                logits = unregularized_model(inputs)
            outputs = keras.activations.softmax(logits)
            model = keras.Model(inputs, outputs)
            # Regularize the most recent activations of a post-dropout layer.
            sample_activations = unregularized_model.get_layer(
                index=-2).get_output_at(-1)
            regularization_loss = keras.backend.mean(sample_activations)
            model.add_loss(regularization_loss)
            model.add_metric(regularization_loss,
                             aggregation='mean',
                             name='regularization_loss')
            return model

        # Make and compile models.
        model = _regularize_model(_make_unregularized_model())
        model.compile('sgd', 'sparse_categorical_crossentropy')
        # Prepare fake data.
        x = np.ones((20, 4)).astype(np.float32)
        y = np.random.randint(0, 3, size=(20, )).astype(np.int64)
        dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
        evaluation_results = dict(
            zip(model.metrics_names, model.evaluate(dataset)))
        # Rate of dropout depends on the learning phase.
        self.assertEqual(evaluation_results['regularization_loss'],
                         expected_validation_loss)
        history = model.fit(dataset, epochs=2, validation_data=dataset).history
        self.assertAllEqual(history['regularization_loss'],
                            [expected_training_loss] * 2)
        self.assertAllEqual(history['val_regularization_loss'],
                            [expected_validation_loss] * 2)
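
Each entry in the named_parameters list above is a zero-argument callable returning a context manager, which is why contextlib.contextmanager(lambda: iter([None])) can stand in for a real learning-phase scope; a minimal sketch of that interchangeability with no Keras dependency (make_scope is illustrative only):

import contextlib

def make_scope(value):
    @contextlib.contextmanager
    def scope():
        yield value
    return scope

scope_factories = {
    'none': contextlib.contextmanager(lambda: iter([None])),  # yields None once
    'zero': make_scope(0),
    'one': make_scope(1),
}

for name, factory in scope_factories.items():
    with factory() as value:
        print(name, value)  # -> none None / zero 0 / one 1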