Example #1
File: emitter.py Project: pektin/jam
def Function_emitInstructions(self):
    self.emitEntry()

    # Exit stack that might contain a selfScope
    self_stack = ExitStack()
    if "self" in self.closed_context:
        assert self.closed_context["self"].llvm_context_index >= 0
        index = self.closed_context["self"].llvm_context_index
        self_value = State.builder.structGEP(self.llvm_context, index, "")
        self_value = State.builder.load(self_value, "")
        self_stack.enter_context(State.selfScope(self_value))

    with self_stack:
        for object in self.closed_context:
            if object.name == "self": continue
            if object.stats.static: continue

            index = object.llvm_context_index
            object.llvm_value = State.builder.structGEP(self.llvm_context, index, "")

        # Allocate Arguments
        for index, arg in enumerate(self.arguments):
            val = self.llvm_value.getParam(index + 1)
            arg.llvm_value = State.builder.alloca(arg.resolveType().emitType(), resolveName(arg))
            State.builder.store(val, arg.llvm_value)

        self.emitPostContext()

        return State.emitInstructions(self.instructions)
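The pattern above enters a context manager only when a condition holds, by parking it on an ExitStack that is otherwise a no-op. A minimal, self-contained sketch of that idiom (using a plain file instead of the project-specific selfScope):

from contextlib import ExitStack

def read_if_given(path=None):
    stack = ExitStack()
    source = None
    if path is not None:
        # Entered only when needed; an empty ExitStack is a harmless no-op.
        source = stack.enter_context(open(path, encoding="utf-8"))
    with stack:
        return source.read() if source is not None else ""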
Example #2
    def setUp(self):
        super(TestAzureEndpointHttpClient, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.read_file_or_url = patches.enter_context(
            mock.patch.object(azure_helper.util, 'read_file_or_url'))
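A minimal, self-contained version of this setUp idiom, pairing ExitStack with addCleanup so every patch is undone after the test (os.getcwd is used here purely as an illustrative patch target):

import unittest
from contextlib import ExitStack
from unittest import mock

class ExampleTest(unittest.TestCase):
    def setUp(self):
        patches = ExitStack()
        # addCleanup guarantees the stack is closed even if setUp fails
        # part-way through or the test itself errors out.
        self.addCleanup(patches.close)
        self.getcwd = patches.enter_context(
            mock.patch('os.getcwd', return_value='/tmp'))

    def test_patched(self):
        import os
        self.assertEqual(os.getcwd(), '/tmp')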
Example #3
class TestZip(unittest.TestCase):
    def setUp(self):
        # Find the path to the example.*.whl so we can add it to the front of
        # sys.path, where we'll then try to find the metadata thereof.
        self.resources = ExitStack()
        self.addCleanup(self.resources.close)
        wheel = self.resources.enter_context(
            path('importlib_metadata.tests.data',
                 'example-21.12-py3-none-any.whl'))
        sys.path.insert(0, str(wheel))
        self.resources.callback(sys.path.pop, 0)

    def test_zip_version(self):
        self.assertEqual(importlib_metadata.version('example'), '21.12')

    def test_zip_entry_points(self):
        parser = importlib_metadata.entry_points('example')
        entry_point = parser.get('console_scripts', 'example')
        self.assertEqual(entry_point, 'example:main')

    def test_missing_metadata(self):
        distribution = importlib_metadata.distribution('example')
        self.assertIsNone(distribution.read_text('does not exist'))

    def test_case_insensitive(self):
        self.assertEqual(importlib_metadata.version('Example'), '21.12')
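Besides enter_context(), the setUp above uses ExitStack.callback() to register an arbitrary undo step (popping the sys.path entry). A stripped-down sketch of that piece alone:

import sys
from contextlib import ExitStack

stack = ExitStack()
sys.path.insert(0, '/tmp/example-dir')
# callback() stores the function and its arguments and invokes them on
# close, so the path entry is removed however the surrounding code exits.
stack.callback(sys.path.pop, 0)
try:
    pass  # work that needs the modified sys.path
finally:
    stack.close()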
Example #4
    def setUp(self):
        super(TestFindEndpoint, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.load_file = patches.enter_context(
            mock.patch.object(azure_helper.util, 'load_file'))
Example #5
def logtail(path, offset_path=None, *, dry_run=False):
    """Yield new lines from a logfile.

    :param path: The path to the file to read from
    :param offset_path: The path to the file where offset/inode
                        information will be stored.  If not set,
                        ``<file>.offset`` will be used.
    :param dry_run: If ``True``, the offset file will not be modified
                    or created.
    """
    if offset_path is None:
        offset_path = path + '.offset'

    try:
        logfile = open(path, encoding='utf-8', errors='replace')
    except OSError as exc:
        warning_echo('Could not read: {} ({})'.format(path, exc))
        return

    closer = ExitStack()
    closer.enter_context(logfile)
    with closer:
        line_iter = iter([])
        stat = os.stat(logfile.fileno())
        debug_echo('logfile inode={}, size={}'.format(stat.st_ino, stat.st_size))
        inode, offset = _parse_offset_file(offset_path)
        if inode is not None:
            if stat.st_ino == inode:
                debug_echo('inodes are the same')
                if offset == stat.st_size:
                    debug_echo('offset points to eof')
                    return
                elif offset > stat.st_size:
                    warning_echo('File shrunk since last read: {} ({} < {})'.format(path, stat.st_size, offset))
                    offset = 0
            else:
                debug_echo('inode changed, checking for rotated file')
                rotated_path = _check_rotated_file(path, inode)
                if rotated_path is not None:
                    try:
                        rotated_file = open(rotated_path, encoding='utf-8', errors='replace')
                    except OSError as exc:
                        warning_echo('Could not read rotated file: {} ({})'.format(rotated_path, exc))
                    else:
                        closer.enter_context(rotated_file)
                        rotated_file.seek(offset)
                        line_iter = itertools.chain(line_iter, iter(rotated_file))
                offset = 0
        logfile.seek(offset)
        line_iter = itertools.chain(line_iter, iter(logfile))
        for line in line_iter:
            line = line.strip()
            yield line
        pos = logfile.tell()
        debug_echo('reached end of logfile at {}'.format(pos))
        if not dry_run:
            debug_echo('writing offset file: ' + offset_path)
            _write_offset_file(offset_path, stat.st_ino, pos)
        else:
            debug_echo('dry run - not writing offset file')
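Note how `closer` is entered once but keeps absorbing resources (the rotated file) that are only discovered while the with block is already running. A minimal sketch of that incremental pattern with ordinary files:

from contextlib import ExitStack

def count_lines(main_path, extra_paths):
    closer = ExitStack()
    main = closer.enter_context(open(main_path, encoding='utf-8'))
    total = 0
    with closer:
        total += sum(1 for _ in main)
        for extra in extra_paths:
            # Files found mid-iteration join the same stack and are all
            # closed together, in reverse order, when the block exits.
            handle = closer.enter_context(open(extra, encoding='utf-8'))
            total += sum(1 for _ in handle)
    return total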
Example #6
    def setUp(self):
        super(TestWalkerHandleHandler, self).setUp()
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir)

        self.data = {
            "handlercount": 0,
            "frequency": "",
            "handlerdir": tmpdir,
            "handlers": helpers.ContentHandlers(),
            "data": None,
        }

        self.expected_module_name = "part-handler-%03d" % (self.data["handlercount"],)
        expected_file_name = "%s.py" % self.expected_module_name
        self.expected_file_fullname = os.path.join(self.data["handlerdir"], expected_file_name)
        self.module_fake = FakeModule()
        self.ctype = None
        self.filename = None
        self.payload = "dummy payload"

        # Mock the write_file() function.  We'll assert that it got called as
        # expected in each of the individual tests.
        resources = ExitStack()
        self.addCleanup(resources.close)
        self.write_file_mock = resources.enter_context(mock.patch("cloudinit.util.write_file"))
Example #7
File: closure.py Project: pektin/jam
    def target(self):
        stack = ExitStack()

        for target, value in self.targets:
            stack.enter_context(target.resolveValue().targetValue(value))

        with stack:
            yield
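target() is a generator-based context manager that enters a variable number of nested contexts before yielding, so they all stay active for the caller's with block. The same shape with plain files, as an illustrative sketch:

from contextlib import ExitStack, contextmanager

@contextmanager
def all_open(paths):
    stack = ExitStack()
    handles = [stack.enter_context(open(p, encoding='utf-8')) for p in paths]
    # Entering the stack here keeps every file open across the yield and
    # closes them all, in reverse order, when the caller's block ends.
    with stack:
        yield handles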
Example #8
File: closure.py Project: pektin/jam
    def checkCompatibility(self, other, check_cache = None):
        stack = ExitStack()
        stack.enter_context(self.target())
        if isinstance(other, ClosedTarget):
            stack.enter_context(other.target())
            other = other.value

        with stack:
            return self.value.checkCompatibility(other, check_cache)
Example #9
class Fixture(object):

    def __init__(self, goal, workspace):
        self.fixtures = ExitStack()
        self.goal = goal
        self.workspace = workspace

    def __exit__(self, *args):
        self.fixtures.close()
Example #10
    def test_instance_bypass(self):
        class Example(object):
            pass

        cm = Example()
        cm.__exit__ = object()
        stack = ExitStack()
        self.assertRaises(AttributeError, stack.enter_context, cm)
        stack.push(cm)
        self.assertIs(tuple(stack._exit_callbacks)[-1], cm)
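This test relies on the difference between the two registration methods: enter_context() needs a real context manager and calls __enter__ immediately, while push() only records something to call during unwinding. A hedged sketch of that distinction outside the test suite:

from contextlib import ExitStack

class ExitOnly:
    # No __enter__, so enter_context() would reject this object, but
    # push() registers its __exit__ to run when the stack unwinds.
    def __exit__(self, exc_type, exc, tb):
        print('cleaned up')
        return False  # do not suppress exceptions

with ExitStack() as stack:
    stack.push(ExitOnly())
    print('doing work')
# 'cleaned up' is printed here, when the with block unwinds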
Example #11
class EventLoopThread(object):
    def __init__(self, servers_to_start):
        self.context = None
        self.executor = None
        self.loop = None
        self.servers_to_start = servers_to_start
        self.servers = []

    def __enter__(self):
        self.context = ExitStack()
        self.executor = self.context.enter_context(ThreadPoolExecutor(max_workers=1))
        self.context.enter_context(self.event_loop_context())
        return self

    def __exit__(self, *enc):
        self.context.__exit__(*enc)
        self.context = None
        self.executor = None
        self.loop = None

    def start_loop(self, event):
        logger.info("starting eventloop server")
        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)
        for server_starter in self.servers_to_start:
            server = loop.run_until_complete(server_starter)
            self.servers.append(server)
        loop.call_soon(event.set)
        loop.run_forever()

    def stop_loop(self):
        logger.info("stopping eventloop server")
        self.loop.create_task(self._close_connections())

    @contextmanager
    def event_loop_context(self):
        event = Event()
        event.clear()
        self.executor.submit(self.start_loop, event)
        event.wait()
        logger.info("started eventloop")
        try:
            yield
        finally:
            self.loop.call_soon_threadsafe(self.stop_loop)
            logger.info("stopped eventloop")

    @asyncio.coroutine
    def _close_connections(self):
        for server in self.servers:
            server.close()
            yield from server.wait_closed()
        self.loop.stop()
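EventLoopThread shows a common shape: the object owns an ExitStack, fills it in __enter__, and simply forwards __exit__ to it so all owned resources are released together. A dependency-free sketch of that shape:

from contextlib import ExitStack
from concurrent.futures import ThreadPoolExecutor

class Worker:
    def __enter__(self):
        self._stack = ExitStack()
        # Everything acquired here is released, in reverse order, by the
        # single delegated __exit__ below.
        self.executor = self._stack.enter_context(
            ThreadPoolExecutor(max_workers=1))
        return self

    def __exit__(self, *exc):
        return self._stack.__exit__(*exc)

with Worker() as w:
    print(w.executor.submit(sum, [1, 2, 3]).result())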
Example #12
 def mock_sys(self):
     "Mock system environment for InteractiveConsole"
     # use exit stack to match patch context managers to addCleanup
     stack = ExitStack()
     self.addCleanup(stack.close)
     self.infunc = stack.enter_context(mock.patch("code.input", create=True))
     self.stdout = stack.enter_context(mock.patch("code.sys.stdout"))
     self.stderr = stack.enter_context(mock.patch("code.sys.stderr"))
     prepatch = mock.patch("code.sys", wraps=code.sys, spec=code.sys)
     self.sysmod = stack.enter_context(prepatch)
     if sys.excepthook is sys.__excepthook__:
         self.sysmod.excepthook = self.sysmod.__excepthook__
Example #13
 class StreamDecoder:
     def __init__(self, file):
         self._file = file
         self._crc = 0
         self._pipe = PipeWriter()
         self._cleanup = ExitStack()
         coroutine = self._pipe.coroutine(self._receive())
         self._cleanup.enter_context(coroutine)
     
     def close(self):
         self._pipe.close()
         del self._pipe
         self._cleanup.close()
     
     def feed(self, data):
         self._pipe.write(data)
     
     def _receive(self):
         while True:
             data = self._pipe.buffer
             pos = data.find(b"=")
             if pos >= 0:
                 data = data[:pos]
             data = data.replace(b"\r", b"").replace(b"\n", b"")
             data = data.translate(self.TABLE)
             # TODO: check data size overflow
             self._crc = crc32(data, self._crc)
             self._file.write(data)
             if pos >= 0:  # Escape character (equals sign)
                 self._pipe.buffer = self._pipe.buffer[pos + 1:]
                 while True:
                     byte = yield from self._pipe.read_one()
                     if byte not in b"\r\n":
                         break
                 # TODO: check for size overflow
                 [byte] = byte
                 data = bytes(((byte - 64 - 42) & bitmask(8),))
                 self._crc = crc32(data, self._crc)
                 self._file.write(data)
             else:
                 try:
                     self._pipe.buffer = yield
                 except EOFError:
                     break
     
     def flush(self):
         pass
     
     def getCrc32(self):
         return format(self._crc, "08x")
     
     TABLE = bytes(range(256))
     TABLE = TABLE[-42:] + TABLE[:-42]
Example #14
    def setUp(self):
        super(TestOpenSSLManager, self).setUp()
        patches = ExitStack()
        self.addCleanup(patches.close)

        self.subp = patches.enter_context(
            mock.patch.object(azure_helper.util, 'subp'))
        try:
            self.open = patches.enter_context(
                mock.patch('__builtin__.open'))
        except ImportError:
            self.open = patches.enter_context(
                mock.patch('builtins.open'))
Example #15
File: variable.py Project: pektin/jam
    def targetValue(self, value):
        value = value.resolveValue()
        old_value = self.value
        self.value = value

        stack = ExitStack()
        if self._static_value_type is not None:
            targets = [(self._static_value_type, value)]
            stack.enter_context(forward.target(targets))

        with stack:
            yield

        self.value = old_value
Example #16
class TestSocket(unittest.TestCase):
    # Usually the socket will be set up from socket.getaddrinfo() but if that
    # raises socket.gaierror, then it tries to infer the IPv4/IPv6 type from
    # the host name.
    def setUp(self):
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        self._resources.enter_context(patch('aiosmtpd.main.socket.getaddrinfo',
                                            side_effect=socket.gaierror))

    def test_ipv4(self):
        bind = self._resources.enter_context(patch('aiosmtpd.main.bind'))
        mock_sock = setup_sock('host.example.com', 8025)
        bind.assert_called_once_with(socket.AF_INET, socket.SOCK_STREAM, 0)
        mock_sock.bind.assert_called_once_with(('host.example.com', 8025))

    def test_ipv6(self):
        bind = self._resources.enter_context(patch('aiosmtpd.main.bind'))
        mock_sock = setup_sock('::1', 8025)
        bind.assert_called_once_with(socket.AF_INET6, socket.SOCK_STREAM, 0)
        mock_sock.bind.assert_called_once_with(('::1', 8025, 0, 0))

    def test_bind_ipv4(self):
        self._resources.enter_context(patch('aiosmtpd.main.socket.socket'))
        mock_sock = setup_sock('host.example.com', 8025)
        mock_sock.setsockopt.assert_called_once_with(
            socket.SOL_SOCKET, socket.SO_REUSEADDR, True)

    def test_bind_ipv6(self):
        self._resources.enter_context(patch('aiosmtpd.main.socket.socket'))
        mock_sock = setup_sock('::1', 8025)
        self.assertEqual(mock_sock.setsockopt.call_args_list, [
            call(socket.SOL_SOCKET, socket.SO_REUSEADDR, True),
            call(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False),
            ])
Example #17
def context_env_update(context_list, env):
    es = ExitStack()
    for item in context_list:
        # create context manager and enter
        tmp_name = '__pw_cm'
        cm_code = compile(ast.Expression(item.context_expr), '<context_eval>', 'eval')
        env[tmp_name] = es.enter_context(eval(cm_code, env))

        # assign to its optional_vars in a separate dict
        if item.optional_vars:
            code = assign_from_ast(item.optional_vars, tmp_name)
            exec(code, env)

    return es
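Here the populated ExitStack is returned to the caller, who becomes responsible for closing it once the entered contexts are no longer needed. A reduced sketch of that contract (the AST evaluation is specific to the original project and replaced by temporary directories):

from contextlib import ExitStack
from tempfile import TemporaryDirectory

def make_workspaces(n):
    es = ExitStack()
    dirs = [es.enter_context(TemporaryDirectory()) for _ in range(n)]
    return es, dirs

# The caller owns the stack and decides when the directories disappear.
es, dirs = make_workspaces(2)
try:
    print(dirs)
finally:
    es.close()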
Example #18
class PulpWritter(object):
    """Use this to create a pulp db."""
    def __init__(self, db_name, msg_dumper=None, idx_dumpers=None):
        self.dir_path = os.path.abspath(db_name)
        self.keys_path = os.path.join(self.dir_path, 'keys')
        if os.path.isdir(self.dir_path):
            shutil.rmtree(self.dir_path)
        os.makedirs(self.dir_path)
        os.makedirs(self.keys_path)

        self.master_table = None
        self.key_tables = {}
        self.table_stack = None
        self.msg_dumper = msg_dumper
        self.idx_dumpers = idx_dumpers
    
    def __enter__(self):
        self.table_stack = ExitStack().__enter__()
        table = MasterTable(self.dir_path, 'w', dumper=self.msg_dumper)
        self.add_table(None, table)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.table_stack is not None:
            self.table_stack.close()

    def add_table(self, key, table):
        if key == None:
            self.master_table = table
        else:
            self.key_tables[key] = table
        self.table_stack.enter_context(table)

    def append(self, data, index_map):
        msg_num = self.master_table.append(data)
        for key, value in index_map.items():
            table = self.key_tables.get(key, None)
            if not table:
                if self.idx_dumpers is not None:
                    dumper = self.idx_dumpers.get(key, None)
                else:
                    dumper = None
                table = KeyTable(self.keys_path, key, 'w', dumper=dumper)
                self.add_table(key, table)

            if isinstance(value, (tuple, list, set)):
                for v in value:
                    table.append(v, msg_num)
            else:
                table.append(value, msg_num)
Example #19
 def startTestRun(self, event):
     # Create a mock for the `sudo snap prepare-image` command.  This is an
     # expensive command which hits the actual snap store.  We want this to
     # run at least once so we know our tests are valid.  We can cache the
     # results in a test-suite-wide temporary directory and simulate future
     # calls by just recursively copying the contents to the specified
     # directories.
     #
     # It's a bit more complicated than that though, because it's possible
     # that the channel and model.assertion will be different, so we need
     # to make the cache dependent on those values.
     #
     # Finally, to enable full end-to-end tests, check an environment
     # variable to see if the mocking should even be done.  This way, we
     # can make our Travis-CI job do at least one real end-to-end test.
     self.resources = ExitStack()
     # How should we mock `snap prepare-image`?  If set to 'always' (case
     # insensitive), then use the sample data in the .zip file.  Any other
     # truthy value says to use a second-and-onward mock.
     should_we_mock = os.environ.get('UBUNTU_IMAGE_MOCK_SNAP', 'yes')
     if should_we_mock.lower() == 'always':
         mock_class = AlwaysMock
     elif as_bool(should_we_mock):
         mock_class = SecondAndOnwardMock
     else:
         mock_class = None
     if mock_class is not None:
         tmpdir = self.resources.enter_context(TemporaryDirectory())
         # Record the actual snap mocker on the class so that other tests
         # can temporarily disable it.  Some tests need to run the actual
         # snap() helper function.
         self.__class__.snap_mocker = self.resources.enter_context(
             mock_class(tmpdir))
Example #20
 def __init__(self, file):
     self._file = file
     self._crc = 0
     self._pipe = PipeWriter()
     self._cleanup = ExitStack()
     coroutine = self._pipe.coroutine(self._receive())
     self._cleanup.enter_context(coroutine)
Example #21
 def __init__(self):
     # Variables which manage state transitions.
     self._next = deque()
     self._debug_step = 1
     # Manage all resources so they get cleaned up whenever the state
     # machine exits for any reason.
     self.resources = ExitStack()
Example #22
    def __enter__(self):
        # Registration doesn't work unless we're logged out.
        self.test_client.logout()
        # Now, post the registration.
        self.test_client.validate('default/user/register',
            'Support Runestone Interactive' if self.is_free else 'Payment Amount',
            data=dict(
                username=self.username,
                first_name=self.first_name,
                last_name=self.last_name,
                # The e-mail address must be unique.
                email=self.email,
                password=self.password,
                password_two=self.password,
                # Note that ``course_id`` is (on the form) actually a course name.
                course_id=self.course_name,
                accept_tcp='on',
                donate='0',
                _next='/runestone/default/index',
                _formname='register',
            )
        )

        # Schedule this user for deletion.
        self.exit_stack_object = ExitStack()
        self.exit_stack = self.exit_stack_object.__enter__()
        self.exit_stack.callback(self._delete_user)

        # Record IDs
        db = self.runestone_db_tools.db
        self.course_id = db(db.courses.course_name == self.course_name).select(db.courses.id).first().id
        self.user_id = db(db.auth_user.username == self.username).select(db.auth_user.id).first().id

        return self
Example #23
File: utils.py Project: lagka/sockeye
def determine_context(device_ids: List[int],
                      use_cpu: bool,
                      disable_device_locking: bool,
                      lock_dir: str,
                      exit_stack: ExitStack) -> List[mx.Context]:
    """
    Determine the MXNet context to run on (CPU or GPU).

    :param device_ids: List of device as defined from the CLI.
    :param use_cpu: Whether to use the CPU instead of GPU(s).
    :param disable_device_locking: Disable Sockeye's device locking feature.
    :param lock_dir: Directory to place device lock files in.
    :param exit_stack: An ExitStack from contextlib.
    :return: A list with the context(s) to run on.
    """
    if use_cpu:
        context = [mx.cpu()]
    else:
        num_gpus = get_num_gpus()
        check_condition(num_gpus >= 1,
                        "No GPUs found, consider running on the CPU with --use-cpu "
                        "(note: check depends on nvidia-smi and this could also mean that the nvidia-smi "
                        "binary isn't on the path).")
        if disable_device_locking:
            context = expand_requested_device_ids(device_ids)
        else:
            context = exit_stack.enter_context(acquire_gpus(device_ids, lock_dir=lock_dir))
        context = [mx.gpu(gpu_id) for gpu_id in context]
    return context
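The exit_stack parameter exists so that resources acquired inside the helper (the GPU locks from acquire_gpus) presumably live as long as the caller keeps the stack open, not just for the duration of the call. A generic, hedged sketch of that calling convention (acquire_lock and set_up are illustrative stand-ins, not Sockeye APIs):

from contextlib import ExitStack, contextmanager

@contextmanager
def acquire_lock(name):
    # Stand-in for a device lock; prints instead of locking anything real.
    print('acquired', name)
    try:
        yield name
    finally:
        print('released', name)

def set_up(exit_stack):
    # The helper enters the lock on the *caller's* stack, so it outlives
    # this function and is only released when that stack closes.
    return exit_stack.enter_context(acquire_lock('gpu0'))

with ExitStack() as exit_stack:
    device = set_up(exit_stack)
    print('running with', device)
# the lock is released here, not inside set_up()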
Example #24
 def setUp(self):
     super(TestCloudStackPasswordFetching, self).setUp()
     self.patches = ExitStack()
     self.addCleanup(self.patches.close)
     mod_name = 'cloudinit.sources.DataSourceCloudStack'
     self.patches.enter_context(mock.patch('{0}.ec2'.format(mod_name)))
     self.patches.enter_context(mock.patch('{0}.uhelp'.format(mod_name)))
Example #25
File: rate_limiter.py Project: jmb/geopy
 def setUp(self):
     self._stack = ExitStack()
     self.mock_clock = self._stack.enter_context(
         patch.object(RateLimiter, '_clock'))
     self.mock_sleep = self._stack.enter_context(
         patch.object(RateLimiter, '_sleep'))
     self.mock_func = MagicMock()
Example #26
 def setUp(self):
     self._resources = ExitStack()
     # Mock out the check_root_privilege call
     self._resources.enter_context(
         patch('ubuntu_image.classic_builder.check_root_privilege'))
     self.addCleanup(self._resources.close)
     self.gadget_tree = resource_filename(
         'ubuntu_image.tests.data', 'gadget_tree')
Example #27
File: emitter.py Project: pektin/jam
def Object_resetEmission(self):
    cache = set()
    stack = ExitStack()
    objects = iter([self])

    while True:
        obj = next(objects, None)
        if obj is None: break
        if obj in cache: continue

        cache.add(obj)
        local_resets = obj.resetLocalEmission()
        if local_resets is not None:
            stack.enter_context(local_resets)
        objects = chain(objects, obj.gatherEmissionResets())

    return stack
Example #28
class WPRobotBase(object):
    def __init__(self):
        self.context = None
        self.devices = []

    def __enter__(self):
        wiringpi2.wiringPiSetupGpio()
        self.context = ExitStack()
        for device in self.devices:
            self.context.enter_context(device)
        return self

    def __exit__(self, *exc):
        self.context.__exit__(*exc)

    def attach_device(self, device):
        self.devices.append(device)
        return device
Example #29
def get_folders(directory: str) -> Iterable[Tuple[str, str]]:
    """
    Get folders composing a directory.
    Yield a 2-item tuples composed of both folder name and folder path.

    :param directory: directory path.
    :return: 2-item tuples composed of both folder name and folder path.
    """
    _collection = []  # type: List[Tuple[str, str]]
    stack = ExitStack()
    try:
        stack.enter_context(ChangeLocalCurrentDirectory(directory))
    except PermissionError:
        pass
    else:
        with stack:
            _collection = [(name, os.path.join(os.getcwd(), name)) for name in os.listdir(".")]
    for name, path in _collection:
        yield name, path
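get_folders keeps the try/except tight around the enter_context() call so a PermissionError skips the directory instead of aborting the generator. The same guard in a self-contained form (os.scandir stands in for the project's ChangeLocalCurrentDirectory):

import os
from contextlib import ExitStack

def list_names(directory):
    stack = ExitStack()
    try:
        # Opening the directory may fail (e.g. PermissionError); keep the
        # guard tight around enter_context so other errors still surface.
        entries = stack.enter_context(os.scandir(directory))
    except PermissionError:
        return []
    with stack:
        return [entry.name for entry in entries]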
Example #30
 def setUp(self):
     old_log_level = log.getEffectiveLevel()
     self.addCleanup(log.setLevel, old_log_level)
     loop = asyncio.get_event_loop()
     self.resources = ExitStack()
     def run_forever(*args):                     # noqa: E306
         pass
     self.resources.enter_context(
         patch.object(loop, 'run_forever', run_forever))
     self.addCleanup(self.resources.close)
Example #31
def recall_variants(args):

    variants, alignment_file_path, target_path, mode, germline_variants_path, germline_variants_sample, germline_bam_path, window_radius, MAX_REF_MOLECULES,max_buffer_size, debug_bam_folder = args

    window_radius = 600
    MAX_REF_MOLECULES = 5_000  # Maximum amount of reference molecules to process.
    # This is capped for regions to which many reads map (mapping artefact)

    variant_calls = dict() # cell->(chrom,pos) +/- ?
    phased_variants = dict()

    ### Set up molecule iterator (1/2)
    if mode== 'NLA':
        mc = NlaIIIMolecule
        fc = NlaIIIFragment
    else:
        mc = Molecule
        fc = Fragment

    ###
    locations_done=set()
    alignments = pysam.AlignmentFile(alignment_file_path,threads=4)
    if germline_bam_path is not None:
        germline_alignments =  pysam.AlignmentFile(germline_bam_path,threads=4)

    for variant in variants:

        # Check if the variant is present in the germline bam file (if supplied)
        if germline_bam_path is not None and has_variant_reads(
                                                        germline_alignments,
                                                        variant.chrom,
                                                        variant.pos-1,
                                                        variant.alts[0],
                                                        min_reads=1,
                                                        stepper='nofilter'):
                #print(f'FOUND IN GERMLINE {variant}')
                continue

        #print(variant)
        overlap = False
        reference_start = max(0, variant.pos - window_radius)
        reference_end = variant.pos + window_radius
        contig = variant.contig

        variant_key = (contig, variant.pos, variant.ref, variant.alts[0] )

        #print(contig,reference_start,reference_end,variant.alts[0],variant.ref)
        ### Set up allele resolver
        unphased_allele_resolver = singlecellmultiomics.alleleTools.AlleleResolver(
            use_cache=False,
            phased=False,
            verbose = True)

        if germline_variants_path is not None:
            with pysam.VariantFile(germline_variants_path) as germline:
                for i, ar_variant in enumerate(germline.fetch(
                        variant.chrom, reference_start, reference_end )):

                    if germline_variants_sample is None:
                        # If any of the samples is not heterozygous: continue
                        if any(len(set(ar_variant.samples[sample].alleles)) != 2 for sample in ar_variant.samples):
                            continue
                    elif len(set(ar_variant.samples[germline_variants_sample].alleles))!=2:
                        continue
                    unphased_allele_resolver.locationToAllele[ar_variant.chrom][ar_variant.pos - 1] = {
                                ar_variant.alleles[0]: {'U'}, ar_variant.alleles[1]: {'V'}
                                }
        ####

        ref_phased = Counter()
        alt_phased = Counter()


        ###

        with ExitStack() as e_stack:

            if debug_bam_folder is not None:
                output_bam = e_stack.enter_context( singlecellmultiomics.bamProcessing.sorted_bam_file(
                    f'{debug_bam_folder}/{"_".join((str(x) for x in variant_key))}.bam', origin_bam=alignments))
            else:
                output_bam = None

            ### Set up molecule iterator (2/2)
            try:
                molecule_iter = MoleculeIterator(
                    alignments,
                    mc,
                    fc,
                    contig=contig,
                    start=reference_start,
                    end=reference_end,
                    molecule_class_args={
                       'allele_resolver':unphased_allele_resolver,
                        'max_associated_fragments':40,
                    },
                    max_buffer_size=max_buffer_size
                )

                reference_called_molecules = [] # molecule, phase

                extracted_base_call_count = 0
                alt_call_count = 0
                for mi,molecule in enumerate(molecule_iter):
                    base_call = get_molecule_base_calls(molecule, variant)
                    if base_call is None:
                        continue
                    extracted_base_call_count+=1
                    base, quality = base_call
                    call = None
                    if base==variant.alts[0]:
                        call='A'
                        alt_call_count+=1
                        if molecule.sample not in variant_calls:
                            variant_calls[molecule.sample] = {}
                        variant_calls[molecule.sample][variant_key] = 1

                    elif base==variant.ref:
                        call='R'

                    if debug_bam_folder is not None:
                        # Write allele-call
                        if call is None:
                            molecule.set_meta('ac','UNK')
                        else:
                            molecule.set_meta('ac', call if call != 'R' else 'UR') # We don't know yet if this is truly
                            # reference at the allele position or just the uninformative allele

                    if call is None:
                        if output_bam is not None:
                            molecule.write_pysam(output_bam)
                        continue

                    # Obtain all germline variants which are phased :
                    phased = get_phased_variants(molecule, unphased_allele_resolver)

                    if call == 'R' and len(phased) > 0:
                        # If we can phase the alternative allele to a germline variant
                        # the reference calls can indicate absence
                        if len(reference_called_molecules) < MAX_REF_MOLECULES:
                            reference_called_molecules.append((molecule, phased))
                        else:
                            if output_bam is not None:
                                molecule.write_pysam(output_bam)
                    else:
                        if output_bam is not None:
                            molecule.write_pysam(output_bam)

                    for chrom, pos, base in phased:
                        if call == 'A':
                            alt_phased[(chrom, pos, base)] += 1

                        elif call == 'R':
                            ref_phased[(chrom, pos, base)] += 1

            except MemoryError:
                print(f"Buffer exceeded for {variant.contig} {variant.pos}")
                continue

            #print(mi,extracted_base_call_count,alt_call_count)
            if len(alt_phased) > 0 and len(reference_called_molecules):
                # Clean the alt_phased variants for variants which are not >90% the same
                alt_phased_filtered = filter_alt_calls(alt_phased, 0.9)
                #print(alt_phased_filtered)
                phased_variants[variant_key] = alt_phased_filtered
                for molecule, phased_gsnvs in reference_called_molecules:
                    for p in phased_gsnvs:
                        if p in alt_phased_filtered:
                            if not molecule.sample in variant_calls:
                                variant_calls[molecule.sample] = {}
                            variant_calls[molecule.sample][variant_key] = 0

                            if debug_bam_folder is not None:
                                molecule.set_meta('S0', True)
                                molecule.set_meta('ac', 'R')

                            break
                # And write:
                if output_bam is not None:
                    for molecule, phased_gsnvs in reference_called_molecules:
                        molecule.write_pysam(output_bam)


            locations_done.add(variant_key)


    alignments.close()
    return variant_calls, locations_done, phased_variants
Example #32
def main_train(training_state: TrainingState,
               envs,
               evaluator: Evaluator = None,
               save_root=Path('saved', 'old_models')):
    params = training_state.training_parameters

    if not isinstance(envs, list):
        envs = [envs]

    device = torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    print("training on: {}".format(device))

    training_state.model.to(device)
    training_state.target_model.to(device)

    rewards_queue = deque(maxlen=300)

    save_root = Path(save_root, params.model_name)

    state_save_root = Path(save_root, 'states')
    state_save_root.mkdir(exist_ok=True, parents=True)

    tensorboard_save_root = Path(save_root, 'tensorboard')
    tensorboard_save_root.mkdir(exist_ok=True, parents=True)
    print(tensorboard_save_root.resolve())
    writer = SummaryWriter(tensorboard_save_root)

    # per_step_lr_drop = 0.9 / 150000
    # scheduler = LambdaLR(training_state.optimizer, lambda step: max(0.1, 1. - step * per_step_lr_drop))
    scheduler = StepLR(training_state.optimizer, step_size=250000, gamma=0.2)

    with ExitStack() as stack:
        runners = [
            stack.enter_context(env.create_runner(render=False))
            for env in envs
        ]
        runner_it = 0

        while params.current_episode < params.num_episodes:

            runner = runners[runner_it]
            runner_it = (runner_it + 1) % len(envs)

            not_none_prev_states = runner.reset()
            prev_states = not_none_prev_states.copy()

            not_none_prev_states = normalize_states(
                not_none_prev_states, training_state.state_mean_std)
            prev_states = normalize_states(prev_states,
                                           training_state.state_mean_std)

            prev_actions = [None] * len(training_state.junctions)

            ep_len = 0
            done = False
            while not done:  # and ep_len < params.max_ep_len:
                ep_len += 1
                params.total_steps += 1

                if random.random(
                ) < params.current_eps or params.total_steps < params.pre_train_steps:
                    actions = runner.action_space.sample().tolist()
                    for i, state in enumerate(prev_states):
                        if state is None:
                            actions[i] = None
                else:
                    actions = []
                    for state in prev_states:
                        if state is None:
                            actions.append(None)
                        else:
                            tensor_state = torch.tensor([state],
                                                        dtype=torch.float32,
                                                        device=device)
                            actions.extend(
                                training_state.model(tensor_state).max(1)
                                [1].cpu().detach().numpy().tolist())

                next_states, rewards, done, info = runner.step(actions)
                next_states = normalize_states(next_states,
                                               training_state.state_mean_std)
                rewards_queue.extend(info['reward'])
                print(params.total_steps, np.mean(rewards_queue))

                for s, r, a, n_s in zip(not_none_prev_states, info['reward'],
                                        prev_actions, next_states):
                    if n_s is not None:
                        if training_state.reward_mean_std is not None:
                            r = (r - training_state.reward_mean_std[0]
                                 ) / training_state.reward_mean_std[1]

                        training_state.replay_memory.add(s, a, r, n_s, done)

                prev_states = next_states

                for i, next_state in enumerate(next_states):
                    if next_state is not None:
                        not_none_prev_states[i] = next_state

                for i, action in enumerate(actions):
                    if action is not None:
                        prev_actions[i] = action

                if params.total_steps > params.pre_train_steps:
                    if params.current_eps > params.end_e:
                        params.current_eps -= params.step_drop

                    params.sampler_current_beta += params.beta_inc
                    params.sampler_current_beta = min(
                        params.sampler_current_beta, params.sampler_beta_max)

                    if params.total_steps % params.training_freq == 0:
                        mean_loss = train_all_batches(
                            training_state.replay_memory,
                            params.sampler_current_beta,
                            params.prioritized_replay_eps,
                            training_state.model,
                            training_state.target_model,
                            training_state.optimizer,
                            training_state.loss_fn,
                            params.batch_size,
                            params.disc_factor,
                            device=device)
                        writer.add_scalar('Train/Loss', mean_loss,
                                          params.total_steps)

                    if params.total_steps % params.target_update_freq == 0:
                        update_target_net(training_state.model,
                                          training_state.target_model,
                                          params.tau)

                    if evaluator is not None and params.total_steps % params.test_freq == 0:
                        # if params.total_steps < 15000 and params.total_steps % (params.test_freq*2) != 0:
                        #     continue
                        evaluator.evaluate_to_tensorboard(
                            {
                                'model':
                                ModelController(training_state.model.eval(),
                                                device)
                            }, writer, params.total_steps,
                            training_state.state_mean_std)
                        training_state.model = training_state.model.train()

                    scheduler.step()

                    if params.total_steps % params.save_freq == 0:
                        training_state.save(
                            Path(
                                state_save_root,
                                'ep_{}_{}.tar'.format(params.total_steps,
                                                      params.model_name)))

            writer.flush()

            params.current_episode += 1

        training_state.save(
            Path(state_save_root, 'final_{}.tar'.format(params.model_name)))
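main_train opens one runner per environment with a list comprehension over stack.enter_context(), so every runner is closed when the with block ends. Stripped of the RL machinery, the pattern is just:

from contextlib import ExitStack
from tempfile import NamedTemporaryFile

with ExitStack() as stack:
    # One enter_context() per resource; all of them are closed together,
    # in reverse order, when the with block ends.
    files = [stack.enter_context(NamedTemporaryFile(mode='w')) for _ in range(3)]
    for i, f in enumerate(files):
        f.write('runner {}\n'.format(i))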
Example #33
File: train.py Project: roromaniac/xfuse
def train(epochs: int = -1):
    """Trains the session model"""
    optim = require("optimizer")
    model = require("model")
    dataloader = require("dataloader")
    training_data = get("training_data")

    messengers: List[Messenger] = [
        MetagenePurger(
            period=lambda e:
            (e % 1000 == 0 and (epochs < 0 or e <= epochs - 1000)),
            num_samples=3,
        )
    ]

    if get("save_path") is not None:
        messengers.append(Checkpointer(period=1000))

        def _every(n):
            def _predicate(**_msg):
                if training_data.step % n == 0:
                    return True
                return False

            return _predicate

        writer = SummaryWriter(os.path.join(get("save_path"), "stats"))

        messengers.extend([
            stats.ELBO(writer, _every(1)),
            stats.MetageneHistogram(writer, _every(100)),
            stats.MetageneMean(writer, _every(100)),
            stats.MetageneSummary(writer, _every(1000)),
            stats.MetageneFullSummary(writer, _every(5000)),
            stats.Image(writer, _every(1000)),
            stats.Latent(writer, _every(1000)),
            stats.LogLikelihood(writer, _every(1)),
            stats.RMSE(writer, _every(1)),
            stats.Scale(writer, _every(1000)),
        ])

    @effectful(type="step")
    def _step(*, x):
        loss = pyro.infer.Trace_ELBO()
        return -pyro.infer.SVI(model.model, model.guide, optim, loss).step(x)

    @effectful(type="epoch")
    def _epoch(*, epoch):
        if isinstance(optim, pyro.optim.PyroLRScheduler):
            optim.step(epoch=epoch)
        with Progressbar(
                dataloader,
                desc=f"Epoch {epoch:05d}",
                leave=False,
        ) as iterator:
            elbo = []
            for x in iterator:
                training_data.step += 1
                elbo.append(_step(x=to_device(x)))
        return np.mean(elbo)

    with ExitStack() as stack:
        for messenger in messengers:
            stack.enter_context(messenger)

        with Progressbar(
            (it.count(training_data.epoch + 1) if epochs < 0 else range(
                training_data.epoch + 1, epochs + 1)),
                desc="Optimizing model",
                unit="epoch",
                dynamic_ncols=True,
                leave=False,
        ) as iterator:
            for epoch in iterator:
                training_data.epoch = epoch
                elbo = _epoch(epoch=epoch)
                log(
                    INFO,
                    " | ".join([
                        "Epoch %05d",
                        "ELBO %+.3e",
                        "Running ELBO %+.4e",
                        "Running RMSE %.3f",
                    ]),
                    epoch,
                    elbo,
                    training_data.elbo_long or 0.0,
                    training_data.rmse or 0.0,
                )

                if epochs < 0 and test_convergence():
                    log(DEBUG, "Model has converged, stopping")
                    break
Example #34
 def test_prune_inferior_points(self):
     for dtype in (torch.float, torch.double):
         X = torch.rand(3, 2, device=self.device, dtype=dtype)
         # the event shape is `q x t` = 3 x 1
         samples = torch.tensor([[-1.0], [0.0], [1.0]],
                                device=self.device,
                                dtype=dtype)
         mm = MockModel(MockPosterior(samples=samples))
         # test that a batched X raises errors
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm, X=X.expand(2, 3, 2))
         # test that a batched model raises errors (event shape is `q x t` = 3 x 1)
         mm2 = MockModel(MockPosterior(samples=samples.expand(2, 3, 1)))
         with self.assertRaises(UnsupportedError):
             prune_inferior_points(model=mm2, X=X)
         # test that invalid max_frac is checked properly
         with self.assertRaises(ValueError):
             prune_inferior_points(model=mm, X=X, max_frac=1.1)
         # test basic behaviour
         X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[[-1]]))
         # test custom objective
         neg_id_obj = GenericMCObjective(lambda X: -(X.squeeze(-1)))
         X_pruned = prune_inferior_points(model=mm,
                                          X=X,
                                          objective=neg_id_obj)
         self.assertTrue(torch.equal(X_pruned, X[[0]]))
         # test non-repeated samples (requires mocking out MockPosterior's rsample)
         samples = torch.tensor(
             [[[3.0], [0.0], [0.0]], [[0.0], [2.0], [0.0]],
              [[0.0], [0.0], [1.0]]],
             device=self.device,
             dtype=dtype,
         )
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X))
         # test max_frac limiting
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X, max_frac=2 / 3)
         if self.device == torch.device("cuda"):
             # sorting has different order on cuda
             self.assertTrue(
                 torch.equal(X_pruned, torch.stack([X[2], X[1]], dim=0)))
         else:
             self.assertTrue(torch.equal(X_pruned, X[:2]))
         # test that zero-probability is in fact pruned
         samples[2, 0, 0] = 10
         with mock.patch.object(MockPosterior,
                                "rsample",
                                return_value=samples):
             mm = MockModel(MockPosterior(samples=samples))
             X_pruned = prune_inferior_points(model=mm, X=X)
         self.assertTrue(torch.equal(X_pruned, X[:2]))
         # test high-dim sampling
         with ExitStack() as es:
             mock_event_shape = es.enter_context(
                 mock.patch(
                     "botorch.utils.testing.MockPosterior.event_shape",
                     new_callable=mock.PropertyMock,
                 ))
             mock_event_shape.return_value = torch.Size([1, 1, 1112])
             es.enter_context(
                 mock.patch.object(MockPosterior,
                                   "rsample",
                                   return_value=samples))
             mm = MockModel(MockPosterior(samples=samples))
             with warnings.catch_warnings(
                     record=True) as ws, settings.debug(True):
                 prune_inferior_points(model=mm, X=X)
                 self.assertTrue(
                     issubclass(ws[-1].category, SamplingWarning))
Example #35
 def __init__(self, file1, file2, fileformat=None, opener=xopen):
     with ExitStack() as stack:
         self.reader1 = stack.enter_context(_open_single(file1, opener=opener, fileformat=fileformat))
         self.reader2 = stack.enter_context(_open_single(file2, opener=opener, fileformat=fileformat))
         self._close = stack.pop_all().close
     self.delivers_qualities = self.reader1.delivers_qualities
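This constructor relies on pop_all(): if opening the second file fails, leaving the with block closes the first one, but on success the cleanup callbacks are transferred to a new stack whose close() is stashed for later. A minimal sketch of the same hand-off:

from contextlib import ExitStack

class Paired:
    def __init__(self, path1, path2):
        with ExitStack() as stack:
            self.first = stack.enter_context(open(path1, encoding='utf-8'))
            self.second = stack.enter_context(open(path2, encoding='utf-8'))
            # On success, take ownership of the cleanups so leaving the
            # with block does NOT close the files; close() does that later.
            self._close = stack.pop_all().close

    def close(self):
        self._close()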
Example #36
def inference_on_dataset(model, data_loader, evaluator):
    """
    Run model on the data_loader and evaluate the metrics with evaluator.
    Also benchmark the inference speed of `model.__call__` accurately.
    The model will be used in eval mode.

    Args:
        model (callable): a callable which takes an object from
            `data_loader` and returns some outputs.

            If it's an nn.Module, it will be temporarily set to `eval` mode.
            If you wish to evaluate a model in `training` mode instead, you can
            wrap the given model and override its behavior of `.eval()` and `.train()`.
        data_loader: an iterable object with a length.
            The elements it generates will be the inputs to the model.
        evaluator (DatasetEvaluator): the evaluator to run. Use `None` if you only want
            to benchmark, but don't want to do any evaluation.

    Returns:
        The return value of `evaluator.evaluate()`
    """
    num_devices = get_world_size()
    logger = logging.getLogger(__name__)
    logger.info("Start inference on {} images".format(len(data_loader)))

    total = len(data_loader)  # inference data loader must have a fixed length
    if evaluator is None:
        # create a no-op evaluator
        evaluator = DatasetEvaluators([])
    evaluator.reset()

    num_warmup = min(5, total - 1)
    start_time = time.perf_counter()
    total_compute_time = 0
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            stack.enter_context(inference_context(model))
        stack.enter_context(torch.no_grad())

        for idx, inputs in enumerate(data_loader):
            if idx == num_warmup:
                start_time = time.perf_counter()
                total_compute_time = 0

            start_compute_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            total_compute_time += time.perf_counter() - start_compute_time
            evaluator.process(inputs, outputs)

            iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
            seconds_per_img = total_compute_time / iters_after_start
            if idx >= num_warmup * 2 or seconds_per_img > 5:
                total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                log_every_n_seconds(
                    logging.INFO,
                    "Inference done {}/{}. {:.4f} s / img. ETA={}".format(
                        idx + 1, total, seconds_per_img, str(eta)
                    ),
                    n=5,
                )

    # Measure the time only for this worker (before the synchronization barrier)
    total_time = time.perf_counter() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    # NOTE this format is parsed by grep
    logger.info(
        "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_time_str, total_time / (total - num_warmup), num_devices
        )
    )
    total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
    logger.info(
        "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format(
            total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle
    if results is None:
        results = {}
    return results
Example #37
def matrix_server_starter(
    free_port_generator: Iterable[Port],
    *,
    count: int = 1,
    config_generator: SynapseConfigGenerator = None,
    log_context: str = None,
) -> Iterator[List[ParsedURL]]:
    with ExitStack() as exit_stack:

        if config_generator is None:
            config_generator = exit_stack.enter_context(
                generate_synapse_config())

        server_urls: List[ParsedURL] = []
        for _, port in zip(range(count), free_port_generator):
            server_name, config_file = config_generator(port)
            server_url = ParsedURL(f"http://{server_name}")
            server_urls.append(server_url)

            synapse_cmd = [
                sys.executable,
                "-m",
                "synapse.app.homeserver",
                f"--server-name={server_name}",
                f"--config-path={config_file!s}",
            ]

            synapse_io: EXECUTOR_IO = DEVNULL
            # Used in CI to capture the logs for failure analysis
            if _SYNAPSE_LOGS_PATH is not None:
                log_file_path = Path(_SYNAPSE_LOGS_PATH).joinpath(
                    f"{server_name}.log")
                log_file_path.parent.mkdir(parents=True, exist_ok=True)
                log_file = exit_stack.enter_context(log_file_path.open("at"))

                # Preface log with header
                header = datetime.utcnow().isoformat()
                if log_context:
                    header = f"{header}: {log_context}"
                header = f" {header} "
                log_file.write(f"{header:=^100}\n")
                log_file.write(f"Cmd: `{' '.join(synapse_cmd)}`\n")
                log_file.flush()

                synapse_io = DEVNULL, log_file, STDOUT

            startup_timeout = 10
            sleep = 0.1

            executor = HTTPExecutor(
                synapse_cmd,
                url=urljoin(server_url, "/_matrix/client/versions"),
                method="GET",
                timeout=startup_timeout,
                sleep=sleep,
                cwd=config_file.parent,
                verify_tls=False,
                io=synapse_io,
            )
            exit_stack.enter_context(executor)

            # The timeout_limit_teardown is necessary to prevent the build
            # being killed because of the lack of output, at the same time the
            # timeout must never happen, because if it does, not all finalizers
            # are executed, leaving dirty state behind and resulting in test
            # flakiness.
            #
            # Because of this, this value is arbitrarily smaller than the
            # teardown timeout, forcing the subprocess to be killed on a timely
            # manner, which should allow the teardown to proceed and finish
            # before the timeout elapses.
            teardown_timeout = 0.5

            # The timeout values for the startup and teardown must be
            # different, however the library doesn't support it. So here we
            # must poke at the private member and overwrite it.
            executor._timeout = teardown_timeout

        yield server_urls
Example #38
def run_whatshap(
    phase_input_files,
    variant_file,
    reference=None,
    output=sys.stdout,
    samples=None,
    chromosomes=None,
    ignore_read_groups=False,
    indels=True,
    mapping_quality=20,
    read_merging=False,
    read_merging_error_rate=0.15,
    read_merging_max_error_rate=0.25,
    read_merging_positive_threshold=1000000,
    read_merging_negative_threshold=1000,
    max_coverage=15,
    full_genotyping=False,
    distrust_genotypes=False,
    include_homozygous=False,
    ped=None,
    recombrate=1.26,
    genmap=None,
    genetic_haplotyping=True,
    recombination_list_filename=None,
    tag="PS",
    read_list_filename=None,
    gl_regularizer=None,
    gtchange_list_filename=None,
    default_gq=30,
    write_command_line_header=True,
    use_ped_samples=False,
    algorithm="whatshap",
):
    """
    Run WhatsHap.

    phase_input_files -- list of paths to BAM/CRAM/VCF files
    variant_file -- path to input VCF
    reference -- path to reference FASTA
    output -- path to output VCF or a file-like object
    samples -- names of samples to phase. an empty list means: phase all samples
    chromosomes -- names of chromosomes to phase. an empty list means: phase all chromosomes
    ignore_read_groups
    mapping_quality -- discard reads below this mapping quality
    read_merging -- whether or not to merge reads
    read_merging_error_rate -- probability that a nucleotide is wrong
    read_merging_max_error_rate -- max error rate on edge of merge graph considered
    read_merging_positive_threshold -- threshold on the ratio of the two probabilities
    read_merging_negative_threshold -- threshold on the opposite ratio of positive threshold
    max_coverage
    full_genotyping
    distrust_genotypes
    include_homozygous
    genetic_haplotyping -- in ped mode, merge disconnected blocks based on genotype status
    recombination_list_filename -- filename to write putative recombination events to
    tag -- How to store phasing info in the VCF, can be 'PS' or 'HP'
    read_list_filename -- name of file to write list of used reads to
    algorithm -- algorithm to use, can be 'whatshap' or 'hapchat'
    gl_regularizer -- float to be passed as regularization constant to GenotypeLikelihoods.as_phred
    gtchange_list_filename -- filename to write list of changed genotypes to
    default_gq -- genotype likelihood to be used when GL or PL not available
    write_command_line_header -- whether to add a ##commandline header to the output VCF
    """

    if algorithm == "hapchat" and ped is not None:
        raise CommandLineError(
            "The hapchat algorithm cannot do pedigree phasing")

    timers = StageTimer()
    logger.info(
        "This is WhatsHap %s running under Python %s",
        __version__,
        platform.python_version(),
    )
    if full_genotyping:
        distrust_genotypes = True
        include_homozygous = True
    numeric_sample_ids = NumericSampleIds()
    if write_command_line_header:
        command_line = "(whatshap {}) {}".format(__version__,
                                                 " ".join(sys.argv[1:]))
    else:
        command_line = None

    if read_merging:
        read_merger = ReadMerger(
            read_merging_error_rate,
            read_merging_max_error_rate,
            read_merging_positive_threshold,
            read_merging_negative_threshold,
        )
    else:
        read_merger = DoNothingReadMerger()

    with ExitStack() as stack:
        try:
            vcf_writer = stack.enter_context(
                PhasedVcfWriter(
                    command_line=command_line,
                    in_path=variant_file,
                    out_file=output,
                    tag=tag,
                ))
        except (OSError, VcfError) as e:
            raise CommandLineError(e)

        phased_input_reader = stack.enter_context(
            PhasedInputReader(
                phase_input_files,
                reference,
                numeric_sample_ids,
                ignore_read_groups,
                mapq_threshold=mapping_quality,
                indels=indels,
            ))
        show_phase_vcfs = phased_input_reader.has_vcfs

        # Only read genotype likelihoods from VCFs when distrusting genotypes
        vcf_reader = stack.enter_context(
            VcfReader(variant_file,
                      indels=indels,
                      genotype_likelihoods=distrust_genotypes))

        if ignore_read_groups and not samples and len(vcf_reader.samples) > 1:
            raise CommandLineError(
                "When using --ignore-read-groups on a VCF with "
                "multiple samples, --sample must also be used.")
        if not samples:
            samples = vcf_reader.samples

        # if --use-ped-samples is set, use only samples from PED file
        if ped and use_ped_samples:
            samples = PedReader(ped).samples()

        raise_if_any_sample_not_in_vcf(vcf_reader, samples)

        if ped and genmap:
            logger.info(
                "Using region-specific recombination rates from genetic map %s.",
                genmap,
            )
            try:
                recombination_cost_computer = GeneticMapRecombinationCostComputer(
                    genmap)
            except ParseError as e:
                raise CommandLineError(e)
        else:
            if ped:
                logger.info("Using uniform recombination rate of %g cM/Mb.",
                            recombrate)
            recombination_cost_computer = UniformRecombinationCostComputer(
                recombrate)

        samples = frozenset(samples)
        families, family_trios = setup_families(samples, ped,
                                                numeric_sample_ids,
                                                max_coverage)

        read_list = None
        if read_list_filename:
            read_list = stack.enter_context(ReadList(read_list_filename))
            if algorithm == "hapchat":
                logger.warning(
                    "hapchat does not yet report on which haplotype a read occurs "
                    "in the inferred solution, so the corresponding column in the "
                    "read list file contains no information")

        with timers("parse_phasing_vcfs"):
            # TODO should this be done in PhasedInputReader.__init__?
            phased_input_reader.read_vcfs()

        for variant_table in timers.iterate("parse_vcf", vcf_reader):
            chromosome = variant_table.chromosome
            if (not chromosomes) or (chromosome in chromosomes):
                logger.info("======== Working on chromosome %r", chromosome)
            else:
                logger.info(
                    "Leaving chromosome %r unchanged (present in VCF but not requested by option --chromosome)",
                    chromosome,
                )
                with timers("write_vcf"):
                    superreads, components = dict(), dict()
                    vcf_writer.write(chromosome, superreads, components)
                continue

            if full_genotyping:
                positions = [v.position for v in variant_table.variants]
                for sample in samples:
                    logger.info("---- Initial genotyping of %s", sample)
                    with timers("read_bam"):
                        bam_sample = None if ignore_read_groups else sample
                        readset, vcf_source_ids = phased_input_reader.read(
                            chromosome,
                            variant_table.variants,
                            bam_sample,
                            read_vcf=False,
                        )
                        readset.sort()  # TODO can be removed
                        genotypes, genotype_likelihoods = compute_genotypes(
                            readset, positions)
                        variant_table.set_genotypes_of(sample, genotypes)
                        variant_table.set_genotype_likelihoods_of(
                            sample,
                            [
                                GenotypeLikelihoods(gl)
                                for gl in genotype_likelihoods
                            ],
                        )

            # These two variables hold the phasing results for all samples
            superreads, components = dict(), dict()

            # Iterate over all families to process, i.e. a separate DP table is created
            # for each family.
            # TODO: Can the body of this loop be factored out into a phase_family function?
            for representative_sample, family in sorted(families.items()):
                if len(family) == 1:
                    logger.info("---- Processing individual %s",
                                representative_sample)
                else:
                    logger.info("---- Processing family with individuals: %s",
                                ",".join(family))
                max_coverage_per_sample = max(1, max_coverage // len(family))
                logger.info("Using maximum coverage per sample of %dX",
                            max_coverage_per_sample)
                trios = family_trios[representative_sample]
                assert len(family) == 1 or len(trios) > 0

                homozygous_positions, phasable_variant_table = find_phaseable_variants(
                    family, include_homozygous, trios, variant_table)

                # Get the reads belonging to each sample
                readsets = dict()  # TODO this could become a list
                for sample in family:
                    with timers("read_bam"):
                        readset, vcf_source_ids = phased_input_reader.read(
                            chromosome,
                            phasable_variant_table.variants,
                            sample,
                        )

                    # TODO: Read selection done w.r.t. all variants, where using heterozygous
                    #  variants only would probably give better results.
                    with timers("select"):
                        readset = readset.subset([
                            i for i, read in enumerate(readset)
                            if len(read) >= 2
                        ])
                        logger.info(
                            "Kept %d reads that cover at least two variants each",
                            len(readset),
                        )
                        merged_reads = read_merger.merge(readset)
                        selected_reads = select_reads(
                            merged_reads,
                            max_coverage_per_sample,
                            preferred_source_ids=vcf_source_ids,
                        )

                    readsets[sample] = selected_reads
                    if len(family) == 1 and not distrust_genotypes:
                        # When having a pedigree (len(family) > 1), blocks are also merged after
                        # phasing based on the pedigree information and these statistics are not
                        # so useful. When distrust_genotypes, genotypes can change during phasing
                        # and so can the block structure. So don't print these stats in those cases
                        log_best_case_phasing_info(readset, selected_reads)

                all_reads = merge_readsets(readsets)

                # Determine which variants can (in principle) be phased
                accessible_positions = sorted(all_reads.get_positions())
                logger.info(
                    "Variants covered by at least one phase-informative "
                    "read in at least one individual after read selection: %d",
                    len(accessible_positions),
                )
                if len(family) > 1 and genetic_haplotyping:
                    # In case of genetic haplotyping, also retain all positions homozygous
                    # in at least one individual (because they might be phased based on genotypes)
                    accessible_positions = sorted(
                        set(accessible_positions).union(homozygous_positions))
                    logger.info(
                        "Variants either covered by phase-informative read or homozygous "
                        "in at least one individual: %d",
                        len(accessible_positions),
                    )

                # Keep only accessible positions
                phasable_variant_table.subset_rows_by_position(
                    accessible_positions)
                assert len(phasable_variant_table.variants) == len(
                    accessible_positions)

                pedigree = create_pedigree(
                    default_gq,
                    distrust_genotypes,
                    family,
                    gl_regularizer,
                    numeric_sample_ids,
                    phasable_variant_table,
                    trios,
                )

                recombination_costs = recombination_cost_computer.compute(
                    accessible_positions)

                # Finally, run phasing algorithm
                with timers("phase"):
                    problem_name = "MEC" if len(family) == 1 else "PedMEC"
                    logger.info(
                        "Phasing %d sample%s by solving the %s problem ...",
                        len(family),
                        plural_s(len(family)),
                        problem_name,
                    )

                    if algorithm == "hapchat":
                        dp_table = HapChatCore(all_reads)
                    else:
                        dp_table = PedigreeDPTable(
                            all_reads,
                            recombination_costs,
                            pedigree,
                            distrust_genotypes,
                            accessible_positions,
                        )

                    superreads_list, transmission_vector = dp_table.get_super_reads(
                    )
                    optimal_cost = dp_table.get_optimal_cost()
                    logger.info("%s cost: %d", problem_name, optimal_cost)

                with timers("components"):
                    master_block = None
                    heterozygous_positions_by_sample = None
                    # If we distrusted genotypes, we need to re-determine which sites are homo-/heterozygous after phasing
                    if distrust_genotypes:
                        hom_in_any_sample = set()
                        heterozygous_positions_by_sample = {}
                        heterozygous_gts = frozenset({(0, 1), (1, 0)})
                        homozygous_gts = frozenset({(0, 0), (1, 1)})
                        for sample, sample_superreads in zip(
                                family, superreads_list):
                            hets = set()
                            for v1, v2 in zip(*sample_superreads):
                                assert v1.position == v2.position
                                if v1.position not in accessible_positions:
                                    continue
                                gt = (v1.allele, v2.allele)
                                if gt in heterozygous_gts:
                                    hets.add(v1.position)
                                elif gt in homozygous_gts:
                                    hom_in_any_sample.add(v1.position)
                            heterozygous_positions_by_sample[
                                numeric_sample_ids[sample]] = hets
                        if len(family) > 1 and genetic_haplotyping:
                            master_block = sorted(hom_in_any_sample)
                    else:
                        if len(family) > 1 and genetic_haplotyping:
                            master_block = sorted(
                                set(homozygous_positions).intersection(
                                    set(accessible_positions)))
                    overall_components = find_components(
                        accessible_positions,
                        all_reads,
                        master_block,
                        heterozygous_positions_by_sample,
                    )
                    n_phased_blocks = len(set(overall_components.values()))
                    logger.info("No. of phased blocks: %d", n_phased_blocks)
                    largest_component = find_largest_component(
                        overall_components)
                    if len(largest_component) > 0:
                        logger.info(
                            "Largest component contains %d variants (%.1f%% of accessible variants) between position %d and %d",
                            len(largest_component),
                            len(largest_component) * 100.0 /
                            len(accessible_positions),
                            largest_component[0] + 1,
                            largest_component[-1] + 1,
                        )

                if recombination_list_filename:
                    n_recombinations = write_recombination_list(
                        recombination_list_filename,
                        chromosome,
                        accessible_positions,
                        overall_components,
                        recombination_costs,
                        transmission_vector,
                        trios,
                    )
                    logger.info(
                        "Total no. of detected recombination events: %d",
                        n_recombinations,
                    )

                # Superreads in superreads_list are in the same order as individuals were added to the pedigree
                for sample, sample_superreads in zip(family, superreads_list):
                    superreads[sample] = sample_superreads
                    assert len(sample_superreads) == 2
                    assert (sample_superreads[0].sample_id ==
                            sample_superreads[1].sample_id ==
                            numeric_sample_ids[sample])
                    # identical for all samples
                    components[sample] = overall_components

                if read_list:
                    read_list.write(
                        all_reads,
                        dp_table.get_optimal_partitioning(),
                        components,
                        numeric_sample_ids,
                    )

            with timers("write_vcf"):
                logger.info("======== Writing VCF")
                changed_genotypes = vcf_writer.write(chromosome, superreads,
                                                     components)
                logger.info("Done writing VCF")
                if changed_genotypes:
                    assert distrust_genotypes
                    logger.info("Changed %d genotypes while writing VCF",
                                len(changed_genotypes))

            if gtchange_list_filename:
                logger.info("Writing list of changed genotypes to %r",
                            gtchange_list_filename)
                write_changed_genotypes(gtchange_list_filename,
                                        changed_genotypes)

            logger.debug("Chromosome %r finished", chromosome)

    log_time_and_memory_usage(timers, show_phase_vcfs=show_phase_vcfs)
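One detail worth calling out in run_whatshap above: the output writer is entered through stack.enter_context inside a try/except, so a failure while opening the output surfaces as a CommandLineError, while a successful open is still registered on the stack for cleanup. A minimal, self-contained sketch of that pattern (the file writer and the locally redefined CommandLineError are stand-ins for illustration only):

from contextlib import ExitStack


class CommandLineError(Exception):
    """User-facing error that should abort the run."""


def write_report(path, lines):
    with ExitStack() as stack:
        try:
            # enter_context both opens the file and schedules its close
            out = stack.enter_context(open(path, "w", encoding="utf-8"))
        except OSError as e:
            # Map low-level I/O failures to a single user-facing error
            raise CommandLineError(e)
        for line in lines:
            out.write(line + "\n")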
Example #39
def sdk_signed_random_requests(looper, sdk_wallet, count):
    _, did = sdk_wallet
    reqs_obj = sdk_random_request_objects(
        count, identifier=did, protocol_version=CURRENT_PROTOCOL_VERSION)
    return sdk_sign_request_objects(looper, sdk_wallet, reqs_obj)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('count', help="Count of generated txns", type=int)
    parser.add_argument('outfpath',
                        help="Path to save generated txns",
                        type=str,
                        nargs='?',
                        default='/tmp/generated_txns')
    args = parser.parse_args()
    path_to_save = os.path.realpath(args.outfpath)

    with ExitStack() as exit_stack:
        with Looper() as looper:
            sdk_wallet, did = looper.loop.run_until_complete(
                get_wallet_and_pool())
            with open(path_to_save, 'w') as outpath:
                for _ in range(args.count):
                    req = sdk_signed_random_requests(looper, (sdk_wallet, did),
                                                     1)[0]
                    txn = sdk_reqToTxn(req, int(time.time()))
                    outpath.write(json.dumps(txn))
                    outpath.write(os.linesep)
            looper.stopall()
Example #40
    def test_fixed_evaluation_qMFKG(self):
        # mock test qMFKG.evaluate() with expand, project & cost aware utility
        for dtype in (torch.float, torch.double):
            mean = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
            mm = MockModel(MockPosterior(mean=mean))
            cau = GenericCostAwareUtility(mock_util)
            n_f = 4
            mean = torch.rand(n_f, 2, 1, 1, device=self.device, dtype=dtype)
            variance = torch.rand(n_f,
                                  2,
                                  1,
                                  1,
                                  device=self.device,
                                  dtype=dtype)
            mfm = MockModel(MockPosterior(mean=mean, variance=variance))
            with ExitStack() as es:
                patch_f = es.enter_context(
                    mock.patch.object(MockModel, "fantasize",
                                      return_value=mfm))
                mock_num_outputs = es.enter_context(
                    mock.patch(NO, new_callable=mock.PropertyMock))
                es.enter_context(
                    mock.patch(
                        "botorch.optim.optimize.optimize_acqf",
                        return_value=(
                            torch.ones(1,
                                       1,
                                       1,
                                       device=self.device,
                                       dtype=dtype),
                            torch.ones(1, device=self.device, dtype=dtype),
                        ),
                    ), )
                es.enter_context(
                    mock.patch(
                        "botorch.generation.gen.gen_candidates_scipy",
                        return_value=(
                            torch.ones(1,
                                       1,
                                       1,
                                       device=self.device,
                                       dtype=dtype),
                            torch.ones(1, device=self.device, dtype=dtype),
                        ),
                    ), )

                mock_num_outputs.return_value = 1
                qMFKG = qMultiFidelityKnowledgeGradient(
                    model=mm,
                    num_fantasies=n_f,
                    X_pending=torch.rand(1,
                                         1,
                                         1,
                                         device=self.device,
                                         dtype=dtype),
                    current_value=torch.zeros(1,
                                              device=self.device,
                                              dtype=dtype),
                    cost_aware_utility=cau,
                    project=lambda X: torch.zeros_like(X),
                    expand=lambda X: torch.ones_like(X),
                )
                val = qMFKG.evaluate(
                    X=torch.zeros(1, 1, 1, device=self.device, dtype=dtype),
                    bounds=torch.tensor([[0.0], [1.0]]),
                    num_restarts=1,
                    raw_samples=1,
                )
                patch_f.assert_called_once()
                cargs, ckwargs = patch_f.call_args
                self.assertTrue(
                    torch.equal(
                        ckwargs["X"],
                        torch.ones(1, 2, 1, device=self.device, dtype=dtype),
                    ))
            self.assertEqual(
                val, cau(None, torch.ones(1, device=self.device, dtype=dtype)))
Example #41
def configure_contextual_logging(_ctx=ExitStack(), **kw):
    indentation = int(os.getenv("EASYPY_LOG_INDENTATION", "0"))
    _ctx.enter_context(THREAD_LOGGING_CONTEXT(indentation=indentation, **kw))
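The snippet above relies on a default-argument trick: the ExitStack is created once, when the function is defined, so every context entered through it stays active until the stack is explicitly closed, which here effectively means for the remaining lifetime of the process. A small sketch of the same idea, assuming a throwaway _announce context manager purely for illustration:

from contextlib import ExitStack, contextmanager


@contextmanager
def _announce(name):
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)


def enable_feature(name, _ctx=ExitStack()):
    # The default ExitStack is shared across all calls; contexts entered
    # through it are only exited if someone closes _ctx explicitly.
    _ctx.enter_context(_announce(name))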
Example #42
async def handle_debdiff(request):
    old_id = request.match_info["old_id"]
    new_id = request.match_info["new_id"]

    old_run, new_run = await get_run_pair(request.app.db, old_id, new_id)

    cache_path = request.app.debdiff_cache_path(old_run['id'], new_run['id'])
    if cache_path:
        try:
            with open(cache_path, "rb") as f:
                debdiff = f.read()
        except FileNotFoundError:
            debdiff = None
    else:
        debdiff = None

    if debdiff is None:
        logging.info(
            "Generating debdiff between %s (%s/%s/%s) and %s (%s/%s/%s)",
            old_run['id'],
            old_run['package'],
            old_run['build_version'],
            old_run['suite'],
            new_run['id'],
            new_run['package'],
            new_run['build_version'],
            new_run['suite'],
        )
        with ExitStack() as es:
            old_dir = es.enter_context(TemporaryDirectory())
            new_dir = es.enter_context(TemporaryDirectory())

            try:
                await asyncio.gather(
                    request.app.artifact_manager.retrieve_artifacts(
                        old_run['id'], old_dir, filter_fn=is_binary
                    ),
                    request.app.artifact_manager.retrieve_artifacts(
                        new_run['id'], new_dir, filter_fn=is_binary
                    ),
                )
            except ArtifactsMissing as e:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %r" % e,
                    headers={"unavailable_run_id": e.args[0]},
                )
            except asyncio.TimeoutError:
                raise web.HTTPGatewayTimeout(text="Timeout retrieving artifacts")

            old_binaries = find_binaries(old_dir)
            if not old_binaries:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %s" % old_run['id'],
                    headers={"unavailable_run_id": old_run['id']},
                )

            new_binaries = find_binaries(new_dir)
            if not new_binaries:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %s" % new_run['id'],
                    headers={"unavailable_run_id": new_run['id']},
                )

            try:
                debdiff = await run_debdiff(
                    [p for (n, p) in old_binaries], [p for (n, p) in new_binaries]
                )
            except DebdiffError as e:
                return web.Response(status=400, text=e.args[0])

        if cache_path:
            with open(cache_path, "wb") as f:
                f.write(debdiff)

    if "filter_boring" in request.query:
        debdiff = filter_debdiff_boring(
            debdiff.decode(), str(old_run['build_version']), str(new_run['build_version'])
        ).encode()

    for accept in request.headers.get("ACCEPT", "*/*").split(","):
        if accept in ("text/x-diff", "text/plain", "*/*"):
            return web.Response(body=debdiff, content_type="text/plain")
        if accept == "text/markdown":
            return web.Response(
                text=markdownify_debdiff(debdiff.decode("utf-8", "replace")),
                content_type="text/markdown",
            )
        if accept == "text/html":
            return web.Response(
                text=htmlize_debdiff(debdiff.decode("utf-8", "replace")),
                content_type="text/html",
            )
    raise web.HTTPNotAcceptable(
        text="Acceptable content types: " "text/html, text/plain, text/markdown"
    )
Example #43
async def handle_diffoscope(request):
    for accept in request.headers.get("ACCEPT", "*/*").split(","):
        if accept in ("text/plain", "*/*"):
            content_type = "text/plain"
            break
        elif accept in ("text/html",):
            content_type = "text/html"
            break
        elif accept in ("application/json",):
            content_type = "application/json"
            break
        elif accept in ("text/markdown",):
            content_type = "text/markdown"
            break
    else:
        raise web.HTTPNotAcceptable(
            text="Acceptable content types: "
            "text/html, text/plain, application/json, "
            "application/markdown"
        )

    old_id = request.match_info["old_id"]
    new_id = request.match_info["new_id"]

    old_run, new_run = await get_run_pair(request.app.db, old_id, new_id)

    cache_path = request.app.diffoscope_cache_path(old_run['id'], new_run['id'])
    if cache_path:
        try:
            with open(cache_path, "rb") as f:
                diffoscope_diff = json.load(f)
        except FileNotFoundError:
            diffoscope_diff = None
    else:
        diffoscope_diff = None

    if diffoscope_diff is None:
        logging.info(
            "Generating diffoscope between %s (%s/%s/%s) and %s (%s/%s/%s)",
            old_run['id'],
            old_run['package'],
            old_run['build_version'],
            old_run['suite'],
            new_run['id'],
            new_run['package'],
            new_run['build_version'],
            new_run['suite'],
        )
        with ExitStack() as es:
            old_dir = es.enter_context(TemporaryDirectory())
            new_dir = es.enter_context(TemporaryDirectory())

            try:
                await asyncio.gather(
                    request.app.artifact_manager.retrieve_artifacts(
                        old_run['id'], old_dir, filter_fn=is_binary
                    ),
                    request.app.artifact_manager.retrieve_artifacts(
                        new_run['id'], new_dir, filter_fn=is_binary
                    ),
                )
            except ArtifactsMissing as e:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %r" % e,
                    headers={"unavailable_run_id": e.args[0]},
                )
            except asyncio.TimeoutError:
                raise web.HTTPGatewayTimeout(text="Timeout retrieving artifacts")

            old_binaries = find_binaries(old_dir)
            if not old_binaries:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %s" % old_run['id'],
                    headers={"unavailable_run_id": old_run['id']},
                )

            new_binaries = find_binaries(new_dir)
            if not new_binaries:
                raise web.HTTPNotFound(
                    text="No artifacts for run id: %s" % new_run['id'],
                    headers={"unavailable_run_id": new_run['id']},
                )

            try:
                diffoscope_diff = await asyncio.wait_for(
                    run_diffoscope(
                        old_binaries, new_binaries,
                        lambda: _set_limits(request.app.task_memory_limit)),
                    request.app.task_timeout
                )
            except MemoryError:
                raise web.HTTPServiceUnavailable(text="diffoscope used too much memory")
            except asyncio.TimeoutError:
                raise web.HTTPGatewayTimeout(text="diffoscope timed out")

        if cache_path is not None:
            with open(cache_path, "w") as f:
                json.dump(diffoscope_diff, f)

    diffoscope_diff["source1"] = "%s version %s (%s)" % (
        old_run['package'],
        old_run['build_version'],
        old_run['suite'],
    )
    diffoscope_diff["source2"] = "%s version %s (%s)" % (
        new_run['package'],
        new_run['build_version'],
        new_run['suite'],
    )

    filter_diffoscope_irrelevant(diffoscope_diff)

    title = "diffoscope for %s applied to %s" % (new_run['suite'], new_run['package'])

    if "filter_boring" in request.query:
        filter_diffoscope_boring(
            diffoscope_diff,
            str(old_run['build_version']),
            str(new_run['build_version']),
            old_run['suite'],
            new_run['suite'],
        )
        title += " (filtered)"

    debdiff = await format_diffoscope(
        diffoscope_diff,
        content_type,
        title=title,
        jquery_url=request.query.get("jquery_url"),
        css_url=request.query.get("css_url"),
    )

    return web.Response(text=debdiff, content_type=content_type)
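Both handle_debdiff and handle_diffoscope follow the same shape: an ExitStack holds two TemporaryDirectory contexts, so both directories are removed even when artifact retrieval raises part-way through, and the directories are already gone by the time the result is cached or returned. A stripped-down sketch of that structure, with fetch_old, fetch_new and diff standing in as hypothetical coroutines:

import asyncio
from contextlib import ExitStack
from tempfile import TemporaryDirectory


async def compare_artifacts(fetch_old, fetch_new, diff):
    with ExitStack() as es:
        old_dir = es.enter_context(TemporaryDirectory())
        new_dir = es.enter_context(TemporaryDirectory())
        # Both directories are cleaned up on exit, even if one download fails.
        await asyncio.gather(fetch_old(old_dir), fetch_new(new_dir))
        return await diff(old_dir, new_dir)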
Example #44
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        self.model_assertion = resource_filename('ubuntu_image.tests.data',
                                                 'model.assertion')
Example #45
class TestMainWithGadget(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        # Capture builtin print() output.
        self._stdout = StringIO()
        self._stderr = StringIO()
        self._resources.enter_context(
            patch('argparse._sys.stdout', self._stdout))
        # Capture stderr since this is where argparse will spew to.
        self._resources.enter_context(
            patch('argparse._sys.stderr', self._stderr))
        # Set up a few other useful things for these tests.
        self._resources.enter_context(
            patch('ubuntu_image.__main__.logging.basicConfig'))
        self.model_assertion = resource_filename('ubuntu_image.tests.data',
                                                 'model.assertion')
        self.classic_gadget_tree = resource_filename('ubuntu_image.tests.data',
                                                     'gadget_tree')

    def test_output_without_subcommand(self):
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(tmpdir, 'my-disk.img')
        self.assertFalse(os.path.exists(imgfile))
        main(('--output', imgfile, self.model_assertion))
        self.assertTrue(os.path.exists(imgfile))

    def test_output(self):
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(tmpdir, 'my-disk.img')
        self.assertFalse(os.path.exists(imgfile))
        main(('snap', '--output', imgfile, self.model_assertion))
        self.assertTrue(os.path.exists(imgfile))

    def test_output_directory(self):
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        main(('snap', '--output-dir', outputdir, self.model_assertion))
        self.assertTrue(os.path.exists(os.path.join(outputdir, 'pc.img')))

    def test_output_directory_multiple_images(self):
        class Builder(DoNothingBuilder):
            gadget_yaml = 'gadget-multi.yaml'

        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder', Builder))
        # Quiet the test suite.
        self._resources.enter_context(
            patch('ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        main(('snap', '-O', outputdir, self.model_assertion))
        for name in ('first', 'second', 'third', 'fourth'):
            image_path = os.path.join(outputdir, '{}.img'.format(name))
            self.assertTrue(os.path.exists(image_path))

    def test_output_directory_multiple_images_image_file_list(self):
        class Builder(DoNothingBuilder):
            gadget_yaml = 'gadget-multi.yaml'

        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder', Builder))
        # Quiet the test suite.
        self._resources.enter_context(
            patch('ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        outputdir = os.path.join(tmpdir, 'images')
        image_file_list = os.path.join(tmpdir, 'ifl.txt')
        main(('snap', '-O', outputdir, '--image-file-list', image_file_list,
              self.model_assertion))
        with open(image_file_list, 'r', encoding='utf-8') as fp:
            img_files = set(line.rstrip() for line in fp.readlines())
        self.assertEqual(
            img_files,
            set(
                os.path.join(outputdir, '{}.img'.format(filename))
                for filename in ('first', 'second', 'third', 'fourth')))

    def test_output_image_file_list(self):
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        # Quiet the test suite.
        self._resources.enter_context(
            patch('ubuntu_image.parser._logger.warning'))
        tmpdir = self._resources.enter_context(TemporaryDirectory())
        output = os.path.join(tmpdir, 'pc.img')
        image_file_list = os.path.join(tmpdir, 'ifl.txt')
        main(('snap', '-o', output, '--image-file-list', image_file_list,
              self.model_assertion))
        with open(image_file_list, 'r', encoding='utf-8') as fp:
            img_files = set(line.rstrip() for line in fp.readlines())
        self.assertEqual(img_files, {output})

    def test_tmp_okay_for_classic_snap(self):
        # For reference see:
        # http://snapcraft.io/docs/reference/env
        self._resources.enter_context(envar('SNAP_NAME', 'crack-pop'))
        self._resources.enter_context(chdir('/tmp'))
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        code = main(('snap', '--output-dir', '/tmp/images', '--extra-snaps',
                     '/tmp/extra.snap', '/tmp/model.assertion'))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists('/tmp/images/pc.img'))

    def test_resume_and_model_assertion(self):
        with self.assertRaises(SystemExit) as cm:
            main(('snap', '--resume', self.model_assertion))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_and_model_assertion_without_subcommand(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--resume', self.model_assertion))
        self.assertEqual(cm.exception.code, 2)

    def test_no_resume_and_no_model_assertion(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--until', 'whatever'))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_without_workdir(self):
        with self.assertRaises(SystemExit) as cm:
            main(('snap', '--resume'))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_without_workdir_without_subcommand(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--resume', ))
        self.assertEqual(cm.exception.code, 2)

    def test_resume_and_gadget_tree(self):
        with self.assertRaises(SystemExit) as cm:
            main(('classic', '--resume', self.classic_gadget_tree))
        self.assertEqual(cm.exception.code, 2)

    def test_no_resume_and_no_gadget_tree(self):
        with self.assertRaises(SystemExit) as cm:
            main(('classic', '--until', 'whatever'))
        self.assertEqual(cm.exception.code, 2)

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_save_resume(self):
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  XXXModelAssertionBuilder))
        workdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(workdir, 'my-disk.img')
        main(('--until', 'prepare_filesystems', '--channel', 'edge',
              '--workdir', workdir, '--output', imgfile, self.model_assertion))
        self.assertTrue(
            os.path.exists(os.path.join(workdir, '.ubuntu-image.pck')))
        self.assertFalse(os.path.exists(imgfile))
        main(('snap', '--resume', '--workdir', workdir))
        self.assertTrue(os.path.exists(imgfile))

    def test_until(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        main(('snap', '--until', 'populate_rootfs_contents', '--channel',
              'edge', '--workdir', workdir, self.model_assertion))
        # The pickle file will tell us how far the state machine got.
        with open(os.path.join(workdir, '.ubuntu-image.pck'), 'rb') as fp:
            pickle_state = load(fp).__getstate__()
        # This is the *next* state to execute.
        self.assertEqual(pickle_state['state'], ['populate_rootfs_contents'])

    def test_thru(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        main(('snap', '--thru', 'populate_rootfs_contents', '--workdir',
              workdir, '--channel', 'edge', self.model_assertion))
        # The pickle file will tell us how far the state machine got.
        with open(os.path.join(workdir, '.ubuntu-image.pck'), 'rb') as fp:
            pickle_state = load(fp).__getstate__()
        # This is the *next* state to execute.
        self.assertEqual(pickle_state['state'],
                         ['populate_rootfs_contents_hooks'])

    def test_resume_loads_pickle_snap(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  EarlyExitLeaveATraceAssertionBuilder))
        main(('snap', '--until', 'prepare_image', '--workdir', workdir,
              self.model_assertion))
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        main(('--workdir', workdir, '--resume'))
        self.assertTrue(os.path.exists(os.path.join(workdir, 'success')))

    def test_resume_loads_pickle_classic(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ClassicBuilder',
                  EarlyExitLeaveATraceClassicBuilder))
        self._resources.enter_context(
            patch('ubuntu_image.classic_builder.check_root_privilege'))
        main(('classic', '--until', 'prepare_image', '--workdir', workdir,
              '--project', 'ubuntu-cpc', self.classic_gadget_tree))
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        main(('--workdir', workdir, '--resume'))
        self.assertTrue(os.path.exists(os.path.join(workdir, 'success')))

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_does_not_fit(self):
        # The contents of a structure is too large for the image size.
        workdir = self._resources.enter_context(TemporaryDirectory())
        # See LP: #1666580
        main(('snap', '--workdir', workdir, '--thru', 'load_gadget_yaml',
              self.model_assertion))
        # Make the gadget's mbr contents too big.
        path = os.path.join(workdir, 'unpack', 'gadget', 'pc-boot.img')
        os.truncate(path, 512)
        mock = self._resources.enter_context(
            patch('ubuntu_image.__main__._logger.error'))
        code = main(('snap', '--workdir', workdir, '--resume'))
        self.assertEqual(code, 1)
        self.assertEqual(
            mock.call_args_list[-1],
            call('Volume contents do not fit (72B over): '
                 'volumes:<pc>:structure:<mbr> [#0]'))

    def test_classic_not_privileged(self):
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ClassicBuilder',
                  EarlyExitLeaveATraceClassicBuilder))
        self._resources.enter_context(patch('os.geteuid', return_value=1))
        self._resources.enter_context(
            patch('pwd.getpwuid', return_value=['test']))
        mock = self._resources.enter_context(
            patch('ubuntu_image.__main__._logger.error'))
        code = main(('classic', '--workdir', workdir, '--project',
                     'ubuntu-cpc', self.classic_gadget_tree))
        self.assertEqual(code, 1)
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        self.assertEqual(
            mock.call_args_list[-1],
            call('Current user(test) does not have root privilege to build '
                 'classic image. Please run ubuntu-image as root.'))

    def test_classic_cross_build_no_static(self):
        # We need to check that a DependencyError is raised when
        # find_executable does not find the qemu-<ARCH>-static binary in
        # PATH (and no path env is set)
        workdir = self._resources.enter_context(TemporaryDirectory())
        livecd_rootfs = self._resources.enter_context(TemporaryDirectory())
        auto = os.path.join(livecd_rootfs, 'auto')
        os.mkdir(auto)
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ClassicBuilder',
                  CallLBLeaveATraceClassicBuilder))
        self._resources.enter_context(
            envar('UBUNTU_IMAGE_LIVECD_ROOTFS_AUTO_PATH', auto))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.run', return_value=None))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.find_executable', return_value=None))
        self._resources.enter_context(
            patch('ubuntu_image.helpers.get_host_arch', return_value='amd64'))
        self._resources.enter_context(
            patch('ubuntu_image.__main__.get_host_distro',
                  return_value='bionic'))
        self._resources.enter_context(
            patch('ubuntu_image.classic_builder.check_root_privilege',
                  return_value=None))
        mock = self._resources.enter_context(
            patch('ubuntu_image.__main__._logger.error'))
        code = main(
            ('classic', '--workdir', workdir, '--project', 'ubuntu-cpc',
             '--arch', 'armhf', self.classic_gadget_tree))
        self.assertEqual(code, 1)
        self.assertFalse(os.path.exists(os.path.join(workdir, 'success')))
        self.assertEqual(
            mock.call_args_list[-1],
            call('Required dependency qemu-arm-static seems to be missing. '
                 'Use UBUNTU_IMAGE_QEMU_USER_STATIC_PATH in case of '
                 'non-standard archs or custom paths.'))

    def test_hook_fired(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        # Let's make sure that, with the use of post-populate-rootfs, we can
        # modify the rootfs contents.
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
echo "[MAGIC_STRING_FOR_U-I_HOOKS]" > $UBUNTU_IMAGE_HOOK_ROOTFS/foo
""")
        os.chmod(hookfile, 0o744)
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        code = main(('--hooks-directory', hookdir, '--workdir', workdir,
                     '--output-dir', workdir, self.model_assertion))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists(os.path.join(workdir, 'root', 'foo')))
        imagefile = os.path.join(workdir, 'pc.img')
        self.assertTrue(os.path.exists(imagefile))
        # Map the image and grep through it to see if our hook change actually
        # landed in the final image.
        with open(imagefile, 'r+b') as fp:
            m = self._resources.enter_context(mmap(fp.fileno(), 0))
            self.assertGreaterEqual(m.find(b'[MAGIC_STRING_FOR_U-I_HOOKS]'), 0)

    def test_hook_error(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
echo -n "Failed" 1>&2
exit 1
""")
        os.chmod(hookfile, 0o744)
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        mock = self._resources.enter_context(
            patch('ubuntu_image.__main__._logger.error'))
        code = main(('--hooks-directory', hookdir, self.model_assertion))
        self.assertEqual(code, 1)
        self.assertEqual(
            mock.call_args_list[-1],
            call('Hook script in path {} failed for the post-populate-rootfs '
                 'hook with return code 1. Output of stderr:\nFailed'.format(
                     hookfile)))

    def test_hook_fired_after_resume(self):
        # For the purpose of testing, we will be using the post-populate-rootfs
        # hook as we made sure it's still executed as part of the
        # DoNothingBuilder.
        hookdir = self._resources.enter_context(TemporaryDirectory())
        hookfile = os.path.join(hookdir, 'post-populate-rootfs')
        with open(hookfile, 'w') as fp:
            fp.write("""\
#!/bin/sh
touch {}/success
""".format(hookdir))
        os.chmod(hookfile, 0o744)
        workdir = self._resources.enter_context(TemporaryDirectory())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  DoNothingBuilder))
        main(('--until', 'prepare_image', '--hooks-directory', hookdir,
              '--workdir', workdir, self.model_assertion))
        self.assertFalse(os.path.exists(os.path.join(hookdir, 'success')))
        # Check if after a resume the hook path is still correct and the hooks
        # are fired as expected.
        code = main(('--workdir', workdir, '--resume'))
        self.assertEqual(code, 0)
        self.assertTrue(os.path.exists(os.path.join(hookdir, 'success')))

    @skipIf('UBUNTU_IMAGE_TESTS_NO_NETWORK' in os.environ,
            'Cannot run this test without network access')
    def test_hook_official_support(self):
        # This test is responsible for checking if all the officially declared
        # hooks are called as intended, making sure none get dropped by
        # accident.
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  XXXModelAssertionBuilder))
        fire_mock = self._resources.enter_context(
            patch('ubuntu_image.hooks.HookManager.fire'))
        code = main(('--channel', 'edge', self.model_assertion))
        self.assertEqual(code, 0)
        called_hooks = [x[0][0] for x in fire_mock.call_args_list]
        self.assertListEqual(called_hooks, supported_hooks)
Example #46
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._detach_stack = ExitStack()
        self._orig_checkpoint = self.scm.get_ref(EXEC_CHECKPOINT)
Example #47
def unite_blocks(counter,
                 blocks_dir,
                 res_dir,
                 encoding,
                 encode=to_gamma_str,
                 block_len=4):
    with ExitStack() as stack:
        files = [
            stack.enter_context(open(f, 'r', encoding=encoding)) for f in
            [blocks_dir + 'block' + str(i) + '.txt' for i in range(counter)]
        ]

        queue = PriorityQueue()
        for i, f in enumerate(files):
            queue.put(to_set(f.readline().strip(), i))

        min_term = ''
        min_term_ids = []
        min_term_freq = 0

        with open(res_dir + '/vocab.txt', 'w', encoding=encoding) as v_file, \
                open(res_dir + '/postings.txt', 'w') as p_file, \
                open(res_dir + '/gamma_vocab.txt', 'wb') as gv_file, \
                open(res_dir + '/gamma_postings.txt', 'wb') as gp_file, \
                open(res_dir + '/table.txt', 'wb') as table_file:
            min_term_id = 0
            block = []
            table_builder = IndexBuilder(encode, encoding)

            def save():
                pointer_bytes, term_str, posting_bytes = table_builder.add_block(
                    block)
                gv_file.write(term_str)
                gp_file.write(posting_bytes)
                table_file.write(pointer_bytes)
                for part in block:
                    v_file.write(part[0] + '\n')
                    p_file.write(','.join(str(id) for id in part[2]) + '\n')

            while queue.qsize():
                min_set = queue.get()
                (term, f_id, freq, ids) = min_set

                if term == min_term:
                    if min_term_ids[-1] == ids[0]:
                        ids = ids[1:]
                    min_term_ids.extend(ids)
                    min_term_freq += freq

                else:
                    if min_term:
                        block.append([min_term, min_term_freq, min_term_ids])
                        if len(block) == block_len:
                            save()
                            block = []
                        min_term_id += 1
                    min_term = term
                    min_term_ids = ids
                    min_term_freq = freq
                next_line = files[f_id].readline().strip()
                if next_line:
                    queue.put(to_set(next_line, f_id))

            block.append([min_term, min_term_freq, min_term_ids])
            save()
            end_pointer, last_byte = table_builder.get_last_byte()
            gp_file.write(last_byte)
            table_file.write(end_pointer)
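The opening of unite_blocks shows the usual way to open a variable number of files with ExitStack: each open() is registered through enter_context inside a list comprehension, so every file opened so far is closed even if a later open() fails. A minimal sketch of just that piece, with the paths invented for illustration:

from contextlib import ExitStack


def read_first_lines(paths):
    with ExitStack() as stack:
        # Files opened earlier in the comprehension are still closed if a
        # later open() raises.
        files = [stack.enter_context(open(p, encoding='utf-8')) for p in paths]
        return [f.readline().strip() for f in files]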
Example #48
def multi_contexts():
    with ExitStack() as stack:
        yield [stack.enter_context(ctx()) for ctx in contexts]
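As written, multi_contexts closes over an enclosing contexts sequence; decorated with contextlib.contextmanager it turns a list of context-manager factories into a single with statement. A self-contained usage sketch, with a hypothetical make_multi wrapper supplying that closure:

from contextlib import ExitStack, contextmanager
from tempfile import TemporaryDirectory


def make_multi(contexts):
    @contextmanager
    def multi_contexts():
        with ExitStack() as stack:
            yield [stack.enter_context(ctx()) for ctx in contexts]
    return multi_contexts


# Enter three temporary directories at once; all are removed together on exit.
with make_multi([TemporaryDirectory] * 3)() as dirs:
    print(dirs)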
Example #49
class TestMain(TestCase):
    def setUp(self):
        super().setUp()
        self._resources = ExitStack()
        self.addCleanup(self._resources.close)
        # Capture builtin print() output.
        self._stdout = StringIO()
        self._stderr = StringIO()
        self._resources.enter_context(
            patch('argparse._sys.stdout', self._stdout))
        # Capture stderr since this is where argparse will spew to.
        self._resources.enter_context(
            patch('argparse._sys.stderr', self._stderr))

    def test_help(self):
        with self.assertRaises(SystemExit) as cm:
            main(('--help', ))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(lines[0].startswith('Usage'), lines[0])
        self.assertTrue(lines[1].startswith('  ubuntu-image'), lines[1])

    def test_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            resources.enter_context(
                patch('ubuntu_image.__main__.ModelAssertionBuilder',
                      EarlyExitModelAssertionBuilder))
            # Prevent actual main() from running.
            resources.enter_context(patch('ubuntu_image.__main__.main'))
            code = main(('--debug', 'model.assertion'))
        self.assertEqual(code, 0)
        mock.assert_called_once_with(level=logging.DEBUG)

    def test_no_debug(self):
        with ExitStack() as resources:
            mock = resources.enter_context(
                patch('ubuntu_image.__main__.logging.basicConfig'))
            resources.enter_context(
                patch('ubuntu_image.__main__.ModelAssertionBuilder',
                      EarlyExitModelAssertionBuilder))
            # Prevent actual main() from running.
            resources.enter_context(patch('ubuntu_image.__main__.main'))
            code = main(('model.assertion', ))
        self.assertEqual(code, 0)
        mock.assert_not_called()

    def test_state_machine_exception(self):
        with ExitStack() as resources:
            resources.enter_context(
                patch('ubuntu_image.__main__.ModelAssertionBuilder',
                      CrashingModelAssertionBuilder))
            mock = resources.enter_context(
                patch('ubuntu_image.__main__._logger.exception'))
            code = main(('model.assertion', ))
            self.assertEqual(code, 1)
            self.assertEqual(mock.call_args_list[-1],
                             call('Crash in state machine'))

    def test_state_machine_snap_command_fails(self):
        # The `snap prepare-image` command fails and main exits with non-zero.
        #
        # This test needs to run the actual snap() helper function, not
        # the testsuite-wide mock.  This is appropriate since we're
        # mocking it ourselves here.
        if NosePlugin.snap_mocker is not None:
            NosePlugin.snap_mocker.patcher.stop()
            self._resources.callback(NosePlugin.snap_mocker.patcher.start)
        self._resources.enter_context(
            patch('ubuntu_image.helpers.subprocess_run',
                  return_value=SimpleNamespace(
                      returncode=1,
                      stdout='command stdout',
                      stderr='command stderr',
                      check_returncode=check_returncode,
                  )))
        self._resources.enter_context(LogCapture())
        self._resources.enter_context(
            patch('ubuntu_image.__main__.ModelAssertionBuilder',
                  XXXModelAssertionBuilder))
        workdir = self._resources.enter_context(TemporaryDirectory())
        imgfile = os.path.join(workdir, 'my-disk.img')
        code = main(
            ('--until', 'prepare_filesystems', '--channel', 'edge',
             '--workdir', workdir, '--output', imgfile, 'model.assertion'))
        self.assertEqual(code, 1)

    def test_no_arguments(self):
        with self.assertRaises(SystemExit) as cm:
            main(())
        self.assertEqual(cm.exception.code, 2)
        lines = self._stderr.getvalue().splitlines()
        self.assertTrue(
            lines[0].startswith('Warning: for backwards compatibility'),
            lines[0])
        self.assertEqual(lines[1], 'Usage:')
        self.assertEqual(lines[2], '  ubuntu-image COMMAND [OPTIONS]...')

    def test_with_none(self):
        with self.assertRaises(SystemExit) as cm:
            main((None))  # code coverage __main__.py 308-309
        self.assertEqual(cm.exception.code, 2)

    def test_snap_subcommand_help(self):
        with self.assertRaises(SystemExit) as cm:
            main((
                'snap',
                '--help',
            ))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(lines[0].startswith('usage: ubuntu-image snap'),
                        lines[0])

    def test_classic_subcommand_help(self):
        with self.assertRaises(SystemExit) as cm:
            main((
                'classic',
                '--help',
            ))
        self.assertEqual(cm.exception.code, 0)
        lines = self._stdout.getvalue().splitlines()
        self.assertTrue(lines[0].startswith('usage: ubuntu-image classic'),
                        lines[0])
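The tests above stack several mock.patch contexts on a single contextlib.ExitStack so that every patch is unwound when the with-block ends, whether or not an assertion fails. A minimal, self-contained sketch of the same pattern (the class, test name and patched target are illustrative, not part of ubuntu-image):

import logging
import unittest
from contextlib import ExitStack
from unittest.mock import patch


class ExitStackPatchExample(unittest.TestCase):
    def test_patched_call(self):
        with ExitStack() as resources:
            # Each enter_context() registers a patch that is undone when the
            # with-block exits, even if an assertion fails first.
            mock_config = resources.enter_context(
                patch('logging.basicConfig'))
            logging.basicConfig(level=logging.DEBUG)
        # The mock keeps its call record after the patch has been unwound.
        mock_config.assert_called_once_with(level=logging.DEBUG)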
示例#50
0
from contextlib import ExitStack
from typing import List

from hse2 import hse

from utility import lifecycle


def check_keys(cursor: hse.KvsCursor, expected: List[bytes]):
    actual = [k for k, _ in cursor.items()]
    assert len(actual) == len(expected)
    for x, y in zip(expected, actual):
        assert x == y


hse.init()

try:
    with ExitStack() as stack:
        kvdb_ctx = lifecycle.KvdbContext()
        kvdb = stack.enter_context(kvdb_ctx)
        kvs_ctx = lifecycle.KvsContext(kvdb, "nostale").rparams("transactions_enable=1")
        kvs = stack.enter_context(kvs_ctx)

        # Insert some keys
        with kvdb.transaction() as txn:
            kvs.put(b"a", b"1", txn=txn)
            kvs.put(b"b", b"2", txn=txn)
            kvs.put(b"c", b"3", txn=txn)
            kvs.put(b"d", b"4", txn=txn)

        # Begin three transactions
        txn1 = kvdb.transaction()
        txn1.begin()
示例#51
0
class TestConfig(TestCase):
    def setUp(self):
        super(TestConfig, self).setUp()
        self.name = "ca-certs"
        self.paths = None
        self.log = logging.getLogger("TestNoConfig")
        self.args = []

    def _fetch_distro(self, kind):
        cls = distros.fetch(kind)
        paths = helpers.Paths({})
        return cls(kind, {}, paths)

    def _mock_init(self):
        self.mocks = ExitStack()
        self.addCleanup(self.mocks.close)

        # Mock out the functions that actually modify the system
        self.mock_add = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'add_ca_certs'))
        self.mock_update = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'update_ca_certs'))
        self.mock_remove = self.mocks.enter_context(
            mock.patch.object(cc_ca_certs, 'remove_default_ca_certs'))

    def test_no_trusted_list(self):
        """
        Test that no certificates are written if the 'trusted' key is not
        present.
        """
        config = {"ca-certs": {}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.assertEqual(self.mock_add.call_count, 0)
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 0)

    def test_empty_trusted_list(self):
        """Test that no certificate are written if 'trusted' list is empty."""
        config = {"ca-certs": {"trusted": []}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.assertEqual(self.mock_add.call_count, 0)
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 0)

    def test_single_trusted(self):
        """Test that a single cert gets passed to add_ca_certs."""
        config = {"ca-certs": {"trusted": ["CERT1"]}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            conf = cc_ca_certs._distro_ca_certs_configs(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.mock_add.assert_called_once_with(conf, ['CERT1'])
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 0)

    def test_multiple_trusted(self):
        """Test that multiple certs get passed to add_ca_certs."""
        config = {"ca-certs": {"trusted": ["CERT1", "CERT2"]}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            conf = cc_ca_certs._distro_ca_certs_configs(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.mock_add.assert_called_once_with(conf, ['CERT1', 'CERT2'])
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 0)

    def test_remove_default_ca_certs(self):
        """Test remove_defaults works as expected."""
        config = {"ca-certs": {"remove-defaults": True}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.assertEqual(self.mock_add.call_count, 0)
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 1)

    def test_no_remove_defaults_if_false(self):
        """Test remove_defaults is not called when config value is False."""
        config = {"ca-certs": {"remove-defaults": False}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.assertEqual(self.mock_add.call_count, 0)
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 0)

    def test_correct_order_for_remove_then_add(self):
        """Test remove_defaults is not called when config value is False."""
        config = {"ca-certs": {"remove-defaults": True, "trusted": ["CERT1"]}}

        for distro_name in cc_ca_certs.distros:
            self._mock_init()
            cloud = get_cloud(distro_name)
            conf = cc_ca_certs._distro_ca_certs_configs(distro_name)
            cc_ca_certs.handle(self.name, config, cloud, self.log, self.args)

            self.mock_add.assert_called_once_with(conf, ['CERT1'])
            self.assertEqual(self.mock_update.call_count, 1)
            self.assertEqual(self.mock_remove.call_count, 1)
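_mock_init() above keeps all of its patches on one ExitStack and hands the close to addCleanup, so each test gets fresh mocks that are always torn down. A minimal sketch of that setUp pattern (the class name and patch target are illustrative only):

import os
import unittest
from contextlib import ExitStack
from unittest import mock


class MockStackSetUpExample(unittest.TestCase):
    def setUp(self):
        # One stack for every patch; a single cleanup unwinds them all after
        # the test finishes, pass or fail.
        self.mocks = ExitStack()
        self.addCleanup(self.mocks.close)
        self.mock_getcwd = self.mocks.enter_context(
            mock.patch('os.getcwd', return_value='/tmp'))

    def test_patched_call(self):
        self.assertEqual(os.getcwd(), '/tmp')
        self.mock_getcwd.assert_called_once_with()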
示例#52
0
class Workspace:
    def __init__(self, workspace_load_target):
        from .cli_target import WorkspaceLoadTarget

        self._stack = ExitStack()

        self._workspace_load_target = check.opt_inst_param(
            workspace_load_target, "workspace_load_target", WorkspaceLoadTarget
        )
        self._grpc_server_registry = self._stack.enter_context(
            ProcessGrpcServerRegistry(reload_interval=0, heartbeat_ttl=30)
        )
        self._load_workspace()

    def _load_workspace(self):
        repository_location_origins = (
            self._workspace_load_target.create_origins() if self._workspace_load_target else []
        )

        self._location_origin_dict = OrderedDict()
        check.list_param(
            repository_location_origins,
            "repository_location_origins",
            of_type=RepositoryLocationOrigin,
        )

        self._location_dict = {}
        self._location_error_dict = {}
        for origin in repository_location_origins:
            check.invariant(
                self._location_origin_dict.get(origin.location_name) is None,
                'Cannot have multiple locations with the same name, got multiple "{name}"'.format(
                    name=origin.location_name,
                ),
            )

            self._location_origin_dict[origin.location_name] = origin
            self._load_location(origin.location_name)

    # Can be overridden in subclasses that need different logic for loading repository
    # locations from origins
    def create_location_from_origin(self, origin):
        if not self._grpc_server_registry.supports_origin(origin):  # pylint: disable=no-member
            return origin.create_location()
        else:
            endpoint = self._grpc_server_registry.reload_grpc_endpoint(  # pylint: disable=no-member
                origin
            )
            return GrpcServerRepositoryLocation(
                origin=origin,
                server_id=endpoint.server_id,
                port=endpoint.port,
                socket=endpoint.socket,
                host=endpoint.host,
                heartbeat=True,
                watch_server=False,
                grpc_server_registry=self._grpc_server_registry,
            )

    def _load_location(self, location_name):
        if self._location_dict.get(location_name):
            del self._location_dict[location_name]

        if self._location_error_dict.get(location_name):
            del self._location_error_dict[location_name]

        origin = self._location_origin_dict[location_name]
        try:
            location = self.create_location_from_origin(origin)
            self._location_dict[location_name] = location
        except Exception:  # pylint: disable=broad-except
            error_info = serializable_error_info_from_exc_info(sys.exc_info())
            self._location_error_dict[location_name] = error_info
            warnings.warn(
                "Error loading repository location {location_name}:{error_string}".format(
                    location_name=location_name, error_string=error_info.to_string()
                )
            )

    def create_snapshot(self):
        return WorkspaceSnapshot(
            self._location_origin_dict.copy(), self._location_error_dict.copy()
        )

    @property
    def repository_locations(self):
        return list(self._location_dict.values())

    def has_repository_location(self, location_name):
        check.str_param(location_name, "location_name")
        return location_name in self._location_dict

    def get_repository_location(self, location_name):
        check.str_param(location_name, "location_name")
        return self._location_dict[location_name]

    def has_repository_location_error(self, location_name):
        check.str_param(location_name, "location_name")
        return location_name in self._location_error_dict

    def reload_repository_location(self, location_name):
        self._load_location(location_name)

    def reload_workspace(self):
        for location in self.repository_locations:
            location.cleanup()
        self._load_workspace()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        for location in self.repository_locations:
            location.cleanup()
        self._stack.close()
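The Workspace above owns its gRPC server registry through an instance-level ExitStack that is only closed in __exit__, so the registry outlives any single method call. A minimal sketch of that ownership pattern with a stand-in resource (the class name and TemporaryDirectory are illustrative, not Dagster APIs):

from contextlib import ExitStack
from tempfile import TemporaryDirectory


class StackOwnerExample:
    def __init__(self):
        # Contexts entered on this stack stay open for the object's lifetime.
        self._stack = ExitStack()
        self.workdir = self._stack.enter_context(TemporaryDirectory())

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        # Closing the stack unwinds everything entered in __init__.
        self._stack.close()


with StackOwnerExample() as owner:
    print(owner.workdir)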
示例#53
0
def test_show(unittest, builtin_pkg):
    from dtale.app import show, get_instance, instances
    import dtale.views as views
    import dtale.global_state as global_state

    class MockDtaleFlask(Flask):

        def __init__(self, import_name, reaper_on=True, url=None, *args, **kwargs):
            kwargs.pop('instance_relative_config', None)
            kwargs.pop('static_url_path', None)
            super(MockDtaleFlask, self).__init__(import_name, *args, **kwargs)

        def run(self, *args, **kwargs):
            pass

    instances()
    test_data = pd.DataFrame([dict(a=1, b=2)])
    with ExitStack() as stack:
        mock_run = stack.enter_context(mock.patch('dtale.app.DtaleFlask.run', mock.Mock()))
        mock_find_free_port = stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        mock_requests = stack.enter_context(mock.patch('requests.get', mock.Mock()))
        instance = show(data=test_data, subprocess=False, name='foo', ignore_duplicate=True)
        assert 'http://localhost:9999' == instance._url
        mock_run.assert_called_once()
        mock_find_free_port.assert_called_once()

        pdt.assert_frame_equal(instance.data, test_data)
        tmp = test_data.copy()
        tmp['biz'] = 2.5
        instance.data = tmp
        unittest.assertEqual(
            global_state.DTYPES[instance._data_id],
            views.build_dtypes_state(tmp),
            'should update app data/dtypes'
        )

        instance2 = get_instance(instance._data_id)
        assert instance2._url == instance._url
        instances()

        assert get_instance(20) is None  # should return None for invalid data ids

        instance.kill()
        mock_requests.assert_called_once()
        assert mock_requests.call_args[0][0] == 'http://localhost:9999/shutdown'
        assert global_state.METADATA['1']['name'] == 'foo'

    with ExitStack() as stack:
        mock_run = stack.enter_context(mock.patch('dtale.app.DtaleFlask.run', mock.Mock()))
        mock_find_free_port = stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        mock_data_loader = mock.Mock(return_value=test_data)
        instance = show(data_loader=mock_data_loader, subprocess=False, port=9999, force=True, debug=True,
                        ignore_duplicate=True)
        assert 'http://localhost:9999' == instance._url
        mock_run.assert_called_once()
        mock_find_free_port.assert_not_called()
        mock_data_loader.assert_called_once()
        _, kwargs = mock_run.call_args

        assert '9999' in instance._url

    with ExitStack() as stack:
        mock_run = stack.enter_context(mock.patch('dtale.app.DtaleFlask.run', mock.Mock()))
        stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=True)))
        mock_data_loader = mock.Mock(return_value=test_data)
        mock_webbrowser = stack.enter_context(mock.patch('webbrowser.get'))
        instance = show(data_loader=mock_data_loader, subprocess=False, port=9999, open_browser=True,
                        ignore_duplicate=True)
        mock_run.assert_not_called()
        webbrowser_instance = mock_webbrowser.return_value
        assert 'http://localhost:9999/dtale/main/3' == webbrowser_instance.open.call_args[0][0]
        instance.open_browser()
        assert 'http://localhost:9999/dtale/main/3' == webbrowser_instance.open.mock_calls[1][1][0]

    # RangeIndex test
    test_data = pd.DataFrame([1, 2, 3])
    with ExitStack() as stack:
        stack.enter_context(mock.patch('dtale.app.DtaleFlask', MockDtaleFlask))
        stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        stack.enter_context(mock.patch('dtale.app.logger', mock.Mock()))
        instance = show(data=test_data, subprocess=False, name='foo', ignore_duplicate=True)
        assert np.array_equal(instance.data['0'].values, test_data[0].values)

    with ExitStack() as stack:
        stack.enter_context(mock.patch('dtale.app.DtaleFlask', MockDtaleFlask))
        stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        stack.enter_context(mock.patch('dtale.app.logger', mock.Mock()))
        stack.enter_context(mock.patch('dtale.views.in_ipython_frontend', mock.Mock(return_value=False)))

        get_calls = {'ct': 0}
        getter = namedtuple('get', 'ok')

        def mock_requests_get(url, verify=True):
            if url.endswith('/health'):
                is_ok = get_calls['ct'] > 0
                get_calls['ct'] += 1
                return getter(is_ok)
            return getter(True)
        stack.enter_context(mock.patch('requests.get', mock_requests_get))
        mock_display = stack.enter_context(mock.patch('IPython.display.display', mock.Mock()))
        mock_iframe = stack.enter_context(mock.patch('IPython.display.IFrame', mock.Mock()))
        instance = show(data=test_data, subprocess=True, name='foo', notebook=True, ignore_duplicate=True)
        mock_display.assert_called_once()
        mock_iframe.assert_called_once()
        assert mock_iframe.call_args[0][0] == 'http://localhost:9999/dtale/iframe/5'

        assert type(instance.__str__()).__name__ == 'str'
        assert type(instance.__repr__()).__name__ == 'str'

    class MockDtaleFlaskRunTest(Flask):

        def __init__(self, import_name, reaper_on=True, url=None, *args, **kwargs):
            kwargs.pop('instance_relative_config', None)
            kwargs.pop('static_url_path', None)
            super(MockDtaleFlaskRunTest, self).__init__(import_name, *args, **kwargs)

        def run(self, *args, **kwargs):
            assert self.jinja_env.auto_reload
            assert self.config['TEMPLATES_AUTO_RELOAD']

    with mock.patch('dtale.app.DtaleFlask', MockDtaleFlaskRunTest):
        show(data=test_data, subprocess=False, port=9999, debug=True, ignore_duplicate=True)

    with mock.patch('dtale.app._thread.start_new_thread', mock.Mock()) as mock_thread:
        show(data=test_data, subprocess=True, ignore_duplicate=True)
        mock_thread.assert_called()

    test_data = pd.DataFrame([dict(a=1, b=2)])

    with ExitStack() as stack:
        mock_build_app = stack.enter_context(mock.patch('dtale.app.build_app', mock.Mock()))
        stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        stack.enter_context(mock.patch('requests.get', mock.Mock()))
        show(data=test_data, subprocess=False, name='foo', ignore_duplicate=True)

        _, kwargs = mock_build_app.call_args
        unittest.assertEqual(
            {'host': 'localhost', 'reaper_on': True}, kwargs, 'build_app should be called with defaults'
        )

    # test adding duplicate column
    with ExitStack() as stack:
        stack.enter_context(mock.patch('dtale.app.DtaleFlask', MockDtaleFlask))
        stack.enter_context(mock.patch('dtale.app.find_free_port', mock.Mock(return_value=9999)))
        stack.enter_context(mock.patch('socket.gethostname', mock.Mock(return_value='localhost')))
        stack.enter_context(mock.patch('dtale.app.is_up', mock.Mock(return_value=False)))
        stack.enter_context(mock.patch('requests.get', mock.Mock()))
        instance = show(data=pd.DataFrame([dict(a=1, b=2)]), subprocess=False, name='foo',
                        ignore_duplicate=True)
        with pytest.raises(Exception):
            instance.data = instance.data.rename(columns={'b': 'a'})

        curr_instance_ct = len(global_state.DATA)
        show(data=pd.DataFrame([dict(a=1, b=2)]), subprocess=False, name='foo')
        assert curr_instance_ct == len(global_state.DATA)

    # cleanup
    global_state.cleanup()
示例#54
0
def train(args: argparse.Namespace):
    # TODO: make training compatible with full net
    args.image_preextracted_features = True  # override this for now

    utils.seed_rngs(args.seed)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training = check_resume(args, output_folder)

    global logger
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    utils.log_basic_info(args)
    with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
        json.dump(vars(args), fp)

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    # The maximum length is the length before we add the BOS/EOS symbols
    max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
    max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
    logger.info(
        "Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
        max_seq_len_source, max_seq_len_target)

    with ExitStack() as exit_stack:
        context = utils.determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)
        if args.batch_type == C.BATCH_TYPE_SENTENCE:
            check_condition(
                args.batch_size % len(context) == 0,
                "When using multiple devices the batch size must be "
                "divisible by the number of devices. Choose a batch "
                "size that is a multiple of %d." % len(context))
        logger.info("Training Device(s): %s",
                    ", ".join(str(c) for c in context))

        # Read feature size
        if args.image_preextracted_features:
            _, args.source_image_size = read_feature_shape(args.source_root)

        train_iter, eval_iter, config_data, target_vocab = create_data_iters_and_vocab(
            args=args,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            resume_training=resume_training,
            output_folder=output_folder)
        max_seq_len_source = config_data.max_seq_len_source
        max_seq_len_target = config_data.max_seq_len_target

        # Dump the vocabularies if we're just starting up
        if not resume_training:
            vocab.vocab_to_json(target_vocab,
                                os.path.join(output_folder, C.VOCAB_TRG_NAME))

        target_vocab_size = len(target_vocab)
        logger.info("Vocabulary sizes: target=%d", target_vocab_size)

        model_config = create_model_config(
            args=args,
            vocab_target_size=target_vocab_size,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            config_data=config_data)
        model_config.freeze()

        training_model = create_training_model(config=model_config,
                                               context=context,
                                               output_dir=output_folder,
                                               train_iter=train_iter,
                                               args=args)

        # Handle options that override training settings
        min_updates = args.min_updates
        max_updates = args.max_updates
        min_samples = args.min_samples
        max_samples = args.max_samples
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_epochs = args.min_num_epochs
        max_epochs = args.max_num_epochs
        if min_epochs is not None and max_epochs is not None:
            check_condition(
                min_epochs <= max_epochs,
                "Minimum number of epochs must be smaller than maximum number of epochs"
            )
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            min_updates = None
            max_updates = sum(num_updates
                              for (_,
                                   num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_samples = None
            max_samples = None
            min_epochs = None
            max_epochs = None

        # Get initialization from encoders (useful for pretrained models)
        extra_initializers = get_preinit_encoders(
            training_model.encoder.encoders)
        if len(extra_initializers) == 0:
            extra_initializers = None

        trainer = training.EarlyStoppingTrainer(
            model=training_model,
            optimizer_config=create_optimizer_config(args, [1.0],
                                                     extra_initializers),
            max_params_files_to_keep=args.keep_last_params,
            source_vocabs=[None],
            target_vocab=target_vocab)

        trainer.fit(train_iter=train_iter,
                    validation_iter=eval_iter,
                    early_stopping_metric=args.optimized_metric,
                    metrics=args.metrics,
                    checkpoint_frequency=args.checkpoint_frequency,
                    max_num_not_improved=max_num_checkpoint_not_improved,
                    min_samples=min_samples,
                    max_samples=max_samples,
                    min_updates=min_updates,
                    max_updates=max_updates,
                    min_epochs=min_epochs,
                    max_epochs=max_epochs,
                    lr_decay_param_reset=args.learning_rate_decay_param_reset,
                    lr_decay_opt_states_reset=args.
                    learning_rate_decay_optimizer_states_reset,
                    decoder=create_checkpoint_decoder(args, exit_stack,
                                                      context),
                    mxmonitor_pattern=args.monitor_pattern,
                    mxmonitor_stat_func=args.monitor_stat_func,
                    allow_missing_parameters=args.allow_missing_params,
                    existing_parameters=args.params)
示例#55
0
    def __init__(self,
                 model_folder: str,
                 inputs: List[str],
                 references: List[str],
                 source_vocabs: List[vocab.Vocab],
                 target_vocabs: List[vocab.Vocab],
                 model: model.SockeyeModel,
                 device: torch.device,
                 max_input_len: Optional[int] = None,
                 batch_size: int = 16,
                 beam_size: int = C.DEFAULT_BEAM_SIZE,
                 nbest_size: int = C.DEFAULT_NBEST_SIZE,
                 bucket_width_source: int = 10,
                 length_penalty_alpha: float = 1.0,
                 length_penalty_beta: float = 0.0,
                 max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
                 ensemble_mode: str = 'linear',
                 sample_size: int = -1,
                 random_seed: int = 42) -> None:
        self.max_input_len = max_input_len
        self.max_output_length_num_stds = max_output_length_num_stds
        self.ensemble_mode = ensemble_mode
        self.beam_size = beam_size
        self.nbest_size = nbest_size
        self.batch_size = batch_size
        self.bucket_width_source = bucket_width_source
        self.length_penalty_alpha = length_penalty_alpha
        self.length_penalty_beta = length_penalty_beta
        self.model = model

        with ExitStack() as exit_stack:
            inputs_fins = [exit_stack.enter_context(utils.smart_open(f)) for f in inputs]
            references_fins = [exit_stack.enter_context(utils.smart_open(f)) for f in references]

            inputs_sentences = [f.readlines() for f in inputs_fins]
            targets_sentences = [f.readlines() for f in references_fins]

            utils.check_condition(all(len(l) == len(targets_sentences[0])
                                      for l in chain(inputs_sentences, targets_sentences)),
                                  "Sentences differ in length.")
            utils.check_condition(all(len(sentence.strip()) > 0 for sentence in targets_sentences[0]),
                                  "Empty target validation sentence.")

            if sample_size <= 0:
                sample_size = len(inputs_sentences[0])
            if sample_size < len(inputs_sentences[0]):
                sentences = parallel_subsample(
                    inputs_sentences + targets_sentences, sample_size, random_seed)
                self.inputs_sentences = sentences[0:len(inputs_sentences)]
                self.targets_sentences = sentences[len(inputs_sentences):]
            else:
                self.inputs_sentences, self.targets_sentences = inputs_sentences, targets_sentences

            if sample_size < self.batch_size:
                self.batch_size = sample_size
        for factor_idx, factor in enumerate(self.inputs_sentences):
            write_to_file(factor, os.path.join(model_folder, C.DECODE_IN_NAME.format(factor=factor_idx)))
        for factor_idx, factor in enumerate(self.targets_sentences):
            write_to_file(factor, os.path.join(model_folder, C.DECODE_REF_NAME.format(factor=factor_idx)))

        self.inputs_sentences = list(zip(*self.inputs_sentences))  # type: ignore

        scorer = inference.CandidateScorer(
            length_penalty_alpha=length_penalty_alpha,
            length_penalty_beta=length_penalty_beta,
            brevity_penalty_weight=0.0)

        self.translator = inference.Translator(
            batch_size=self.batch_size,
            device=device,
            ensemble_mode=self.ensemble_mode,
            scorer=scorer,
            beam_search_stop='all',
            nbest_size=self.nbest_size,
            models=[self.model],
            source_vocabs=source_vocabs,
            target_vocabs=target_vocabs,
            restrict_lexicon=None)

        logger.info("Created CheckpointDecoder(max_input_len=%d, beam_size=%d, num_sentences=%d)",
                    max_input_len if max_input_len is not None else -1, beam_size, len(self.targets_sentences[0]))
示例#56
0
def run_translate(args: argparse.Namespace):

    if args.output is not None:
        global logger
        logger = setup_main_logger(__name__,
                                   console=not args.quiet,
                                   file_logging=True,
                                   path="%s.%s" % (args.output, C.LOG_NAME))

    if args.checkpoints is not None:
        check_condition(len(args.checkpoints) == len(args.models), "must provide checkpoints for each model")

    if args.skip_topk:
        check_condition(args.beam_size == 1, "--skip-topk has no effect if beam size is larger than 1")
        check_condition(len(args.models) == 1, "--skip-topk has no effect for decoding with more than 1 model")

    log_basic_info(args)

    output_handler = get_output_handler(args.output_type,
                                        args.output,
                                        args.sure_align_threshold)

    with ExitStack() as exit_stack:
        check_condition(len(args.device_ids) == 1, "translate only supports single device for now")
        context = determine_context(device_ids=args.device_ids,
                                    use_cpu=args.use_cpu,
                                    disable_device_locking=args.disable_device_locking,
                                    lock_dir=args.lock_dir,
                                    exit_stack=exit_stack)[0]
        logger.info("Translate Device: %s", context)

        if args.override_dtype == C.DTYPE_FP16:
            logger.warning('Experimental feature \'--override-dtype float16\' has been used. '
                           'This feature may be removed or change its behaviour in future. '
                           'DO NOT USE IT IN PRODUCTION!')

        models, source_vocabs, target_vocab = inference.load_models(
            context=context,
            max_input_len=args.max_input_len,
            beam_size=args.beam_size,
            batch_size=args.batch_size,
            model_folders=args.models,
            checkpoints=args.checkpoints,
            softmax_temperature=args.softmax_temperature,
            max_output_length_num_stds=args.max_output_length_num_stds,
            decoder_return_logit_inputs=args.restrict_lexicon is not None,
            cache_output_layer_w_b=args.restrict_lexicon is not None,
            override_dtype=args.override_dtype)
        restrict_lexicon = None  # type: Optional[TopKLexicon]
        if args.restrict_lexicon:
            restrict_lexicon = TopKLexicon(source_vocabs[0], target_vocab)
            restrict_lexicon.load(args.restrict_lexicon, k=args.restrict_lexicon_topk)
        store_beam = args.output_type == C.OUTPUT_HANDLER_BEAM_STORE

        inference_adapt_model = None
        if args.inference_adapt:
            model = models[0] # for now, just use the first of the loaded models
            bucketing = True # for now, just set this here; modify decode CLI args later
            model_config = model.config
            default_bucket_key = (model_config.config_data.max_seq_len_source, model_config.config_data.max_seq_len_target)
            provide_data = [mx.io.DataDesc(name=C.SOURCE_NAME,
                           shape=(args.batch_size, default_bucket_key[0], model_config.config_data.num_source_factors),
                           layout=C.BATCH_MAJOR),
            mx.io.DataDesc(name=C.TARGET_NAME,
                           shape=(args.batch_size, default_bucket_key[1]),
                           layout=C.BATCH_MAJOR)]
            provide_label = [mx.io.DataDesc(name=C.TARGET_LABEL_NAME,
                           shape=(args.batch_size, default_bucket_key[1]),
                           layout=C.BATCH_MAJOR)]
            inference_adapt_model = inference_adapt_train.create_inference_adapt_model(config=model_config,
                                                                       context=context,
                                                                       provide_data=provide_data,
                                                                       provide_label=provide_label,
                                                                       default_bucket_key=default_bucket_key,
                                                                       bucketing=bucketing,
                                                                       args=args)
        translator = inference.Translator(context=context,
                                          ensemble_mode=args.ensemble_mode,
                                          bucket_source_width=args.bucket_width,
                                          length_penalty=inference.LengthPenalty(args.length_penalty_alpha,
                                                                                 args.length_penalty_beta),
                                          beam_prune=args.beam_prune,
                                          beam_search_stop=args.beam_search_stop,
                                          models=models,
                                          source_vocabs=source_vocabs,
                                          target_vocab=target_vocab,
                                          inference_adapt_model=inference_adapt_model,
                                          restrict_lexicon=restrict_lexicon,
                                          avoid_list=args.avoid_list,
                                          store_beam=store_beam,
                                          strip_unknown_words=args.strip_unknown_words,
                                          skip_topk=args.skip_topk,
                                          adapt_args=args)
        read_and_translate(translator=translator,
                           output_handler=output_handler,
                           chunk_size=args.chunk_size,
                           input_file=args.input,
                           input_factors=args.input_factors,
                           input_is_json=args.json_input)
示例#57
0
def main():
    parser = argparse.ArgumentParser(
        description='Aggregate Agilent data from multiple CSV files.')
    parser.add_argument('config_file_path',
                        nargs='?',
                        metavar='config',
                        default='config.json',
                        help='Path to the config json file')
    args = parser.parse_args()

    configs = {}
    with open(args.config_file_path) as config_json_file:
        configs = json.load(config_json_file)

    name_map = get_name_map(configs["name_map_file"],
                            configs["name_start_line"])
    common_headers = set(configs['common_headers'])
    additional_headers = set(configs['additional_headers'])
    common_columns = set([])
    additional_columns = set([])
    supported_transforms = {'log': log}
    transform = None
    if configs['transform'] and configs['transform'].lower() in supported_transforms:
        transform = supported_transforms[configs['transform'].lower()]

    # Open the write context first, so we always have access to the output
    with open(configs['output_file'], 'w', newline='\n') as output_file:
        csv_writer = csv.writer(output_file, delimiter=',')
        # Then open an ExitStack of read contexts to open all readers
        with ExitStack() as stack:
            source_file_names = [
                os.path.splitext(os.path.basename(filename))[0]
                for filename in configs['source_files']
            ]
            input_files = [
                stack.enter_context(open(filename, 'r', newline='\n'))
                for filename in configs['source_files']
            ]
            csv_readers = [
                csv.reader(csv_file, delimiter='\t')
                for csv_file in input_files
            ]
            row_num = 0
            to_write = []
            first_reader = csv_readers.pop(0)
            first_file_name = source_file_names.pop(0)
            for row in first_reader:
                to_write = []
                row_num += 1
                if row_num < configs["start_line"]:
                    # Make sure we move through the other csv files too
                    for csv_reader in csv_readers:
                        next(csv_reader)
                    continue

                for i in range(len(row)):
                    # If we're at the header row, set which columns we need.
                    # This is our first file, so set the common headers and the
                    # additional headers.
                    if row_num == configs["start_line"]:
                        if row[i] in common_headers:
                            common_columns.add(i)
                            to_write.append(row[i])
                        elif row[i] in additional_headers:
                            additional_columns.add(i)
                            # Map name, and write new name here
                            to_write.append(name_map[first_file_name] + '_' +
                                            row[i])
                    else:
                        if i in common_columns:
                            to_write.append(row[i])
                        elif i in additional_columns:
                            if transform:
                                to_write.append(transform(row[i]))
                            else:
                                to_write.append(row[i])

                for i in range(len(csv_readers)):
                    csv_reader = csv_readers[i]
                    filename = source_file_names[i]
                    additional_row = next(csv_reader)
                    for j in range(len(additional_row)):
                        # If we're at the header row, set which columns we need.
                        # Here we just need the new line
                        if row_num == configs["start_line"]:
                            if additional_row[j] in additional_headers:
                                to_write.append(name_map[filename] + '_' +
                                                additional_row[j])
                        else:
                            if j in additional_columns:
                                if transform:
                                    to_write.append(
                                        transform(additional_row[j]))
                                else:
                                    to_write.append(additional_row[j])
                csv_writer.writerow(to_write)
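The aggregation script above opens its single writer first, then pushes a variable number of readers onto one ExitStack so that every input file is closed when the block exits. A minimal sketch of that read/write layout (the function name and the column merging are illustrative simplifications):

import csv
from contextlib import ExitStack


def merge_csv_columns(source_paths, output_path):
    with open(output_path, 'w', newline='') as output_file:
        writer = csv.writer(output_file)
        # A variable number of readers, all closed together with the stack.
        with ExitStack() as stack:
            readers = [
                csv.reader(stack.enter_context(open(path, newline='')))
                for path in source_paths
            ]
            # zip() stops at the shortest file, keeping the rows aligned.
            for rows in zip(*readers):
                writer.writerow([cell for row in rows for cell in row])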
示例#58
0
def run(agent_modules, player_names, config=None, recorder=None, watch=False):
    # Create a new game
    row_count = config.get('rows')
    column_count = config.get('columns')
    iteration_limit = config.get('max_iterations')
    is_interactive = config.get('interactive')

    # Check the max number of players supported by the map:
    squares_per_player = 6
    max_players = row_count * column_count / squares_per_player
    if max_players < len(agent_modules):
        raise TooManyPlayers(
            f"Game map ({column_count}x{row_count}) supports at most {max_players} players while {len(agent_modules)} agents were requested."
        )

    # Load agent modules
    with ExitStack() as stack:
        agent_drivers = __load_agent_drivers(stack,
                                             agent_modules,
                                             watch=watch,
                                             config=config)
        if not agent_drivers:
            return None  # Exiting with an error, no contest

        game = Game(row_count=row_count,
                    column_count=column_count,
                    max_iterations=iteration_limit,
                    recorder=recorder)

        # Add all agents to the game
        agents: List[AgentProxy] = []
        names_len = len(player_names) if player_names else 0
        for i, agent_driver in enumerate(agent_drivers):
            agent = agent_driver.agent()
            agents.append(agent)
            game.add_agent(
                agent, player_names[i] if i < names_len else agent_driver.name)

        # Add a player for the user if running in interactive mode or configured interactive
        user_pid = game.add_player("Player") if is_interactive else None
        game.generate_map()

        wait_time = AGENT_READY_WAIT_TIMEOUT
        time.sleep(0.1)  # Give sub-processes a chance to start and initialise their agents
        agents_not_ready = [a.name for a in agents if not a.is_ready]
        while agents_not_ready and wait_time > 0:
            logger.info(
                f"Waiting for slowpoke agents [{wait_time} sec]: {agents_not_ready}"
            )
            time.sleep(AGENT_READY_WAIT_SEC)
            wait_time -= AGENT_READY_WAIT_SEC
            agents_not_ready = [a.name for a in agents if not a.is_ready]

        if agents_not_ready:
            logger.info(
                f"Agents {agents_not_ready} are still not ready even after {AGENT_READY_WAIT_TIMEOUT} sec. Starting the match anyway"
            )

        tick_step = config.get('tick_step')
        if config.get('headless'):
            from .headless_client import Client

            client = Client(game=game, config=config)
            client.run(tick_step)
        else:
            if config.get('hack'):
                from .hack_client import Client
                screen_width = 80
                screen_height = 24
            else:
                from .arcade_client import Client, WIDTH, HEIGHT, PADDING

                screen_width = PADDING[0] * 2 + WIDTH * 12
                screen_height = PADDING[1] * 3 + HEIGHT * 10

            window = Client(width=screen_width,
                            height=screen_height,
                            title=SCREEN_TITLE,
                            game=game,
                            config=config,
                            interactive=is_interactive,
                            user_pid=user_pid)
            window.run(tick_step)

        # Announce game winner and exit
        return game.stats
示例#59
0
def run(run_dir, file_in, file_out, ignore_file_type, data_server):
    """
    Runs the main process of the getmodel process.

    Args:
        run_dir: (str) the directory of where the process is running
        file_in: (Optional[str]) the path to the input directory
        file_out: (Optional[str]) the path to the output directory
        ignore_file_type: set(str) file extension to ignore when loading
        data_server: (bool) if set to True runs the data server

    Returns: None
    """
    static_path = os.path.join(run_dir, 'static')
    input_path = os.path.join(run_dir, 'input')
    ignore_file_type = set(ignore_file_type)

    if data_server:
        logger.debug("data server active")
        FootprintLayerClient.register()
        logger.debug("registered with data server")
        atexit.register(FootprintLayerClient.unregister)
    else:
        logger.debug("data server not active")

    with ExitStack() as stack:
        if file_in is None:
            streams_in = sys.stdin.buffer
        else:
            streams_in = stack.enter_context(open(file_in, 'rb'))

        if file_out is None:
            stream_out = sys.stdout.buffer
        else:
            stream_out = stack.enter_context(open(file_out, 'wb'))

        event_id_mv = memoryview(bytearray(4))
        event_ids = np.ndarray(1, buffer=event_id_mv, dtype='i4')

        logger.debug('init items')
        vuln_dict, areaperil_to_vulns_idx_dict, areaperil_to_vulns_idx_array, areaperil_to_vulns = get_items(
            input_path, ignore_file_type)

        logger.debug('init footprint')
        footprint_obj = stack.enter_context(
            Footprint.load(static_path, ignore_file_type))

        if data_server:
            num_intensity_bins: int = FootprintLayerClient.get_number_of_intensity_bins(
            )
            logger.info(f"got {num_intensity_bins} intensity bins from server")
        else:
            num_intensity_bins: int = footprint_obj.num_intensity_bins

        logger.debug('init vulnerability')

        vuln_array, vulns_id, num_damage_bins = get_vulns(
            static_path, vuln_dict, num_intensity_bins, ignore_file_type)
        convert_vuln_id_to_index(vuln_dict, areaperil_to_vulns)
        logger.debug('init mean_damage_bins')
        mean_damage_bins = get_mean_damage_bins(static_path, ignore_file_type)

        # event_id, areaperil_id, vulnerability_id, num_result, [oasis_float] * num_result
        max_result_relative_size = 1 + areaperil_int_relative_size + 1 + 1 + num_damage_bins * results_relative_size

        mv = memoryview(bytearray(buff_size))

        int32_mv = np.ndarray(buff_size // np.int32().itemsize,
                              buffer=mv,
                              dtype=np.int32)

        # header
        stream_out.write(np.uint32(1).tobytes())

        logger.debug('doCdf starting')
        while True:
            len_read = streams_in.readinto(event_id_mv)
            if len_read == 0:
                break

            if data_server:
                event_footprint = FootprintLayerClient.get_event(event_ids[0])
            else:
                event_footprint = footprint_obj.get_event(event_ids[0])

            if event_footprint is not None:
                for cursor_bytes in doCdf(
                        event_ids[0], num_intensity_bins, event_footprint,
                        areaperil_to_vulns_idx_dict,
                        areaperil_to_vulns_idx_array, areaperil_to_vulns,
                        vuln_array, vulns_id, num_damage_bins,
                        mean_damage_bins, int32_mv, max_result_relative_size):

                    if cursor_bytes:
                        stream_out.write(mv[:cursor_bytes])
                    else:
                        break
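run() above uses its ExitStack to choose between an opened file and the process's binary stdio, so only streams it actually opened are closed on exit. A minimal sketch of that selection logic (the function name is illustrative):

import sys
from contextlib import ExitStack


def copy_stream(file_in=None, file_out=None):
    with ExitStack() as stack:
        # Fall back to binary stdio when no path is given; stdin/stdout are
        # never registered on the stack, so they are never closed by mistake.
        stream_in = (sys.stdin.buffer if file_in is None
                     else stack.enter_context(open(file_in, 'rb')))
        stream_out = (sys.stdout.buffer if file_out is None
                      else stack.enter_context(open(file_out, 'wb')))
        for chunk in iter(lambda: stream_in.read(65536), b''):
            stream_out.write(chunk)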
示例#60
0
def run(agent_modules, player_names, config=None, recorder=None, watch=False):
    # Create a new game
    row_count = config.get('rows')
    column_count = config.get('columns')
    iteration_limit = config.get('max_iterations')
    is_interactive = config.get('interactive')

    # Check the max number of players supported by the map:
    squares_per_player = 6
    max_players = row_count * column_count / squares_per_player
    if max_players < len(agent_modules):
        raise TooManyPlayers(
            f"Game map ({column_count}x{row_count}) supports at most {max_players} players while {len(agent_modules)} agents were requested."
        )

    # Load agent modules
    with ExitStack() as stack:
        agents = __load_agent_drivers(stack,
                                      agent_modules,
                                      watch=watch,
                                      config=config)
        if not agents:
            return None  # Exiting with an error, no contest

        game = Game(row_count=row_count,
                    column_count=column_count,
                    max_iterations=iteration_limit,
                    recorder=recorder)

        # Add all agents to the game
        names_len = len(player_names) if player_names else 0
        for i, agent_driver in enumerate(agents):
            game.add_agent(
                agent_driver.agent(),
                player_names[i] if i < names_len else agent_driver.name)

        # Add a player for the user if running in interactive mode or configured interactive
        user_pid = game.add_player("Player") if is_interactive else None

        game.generate_map()

        tick_step = config.get('tick_step')
        if config.get('headless'):
            from .headless_client import Client

            client = Client(game=game, config=config)
            client.run(tick_step)
        else:
            if config.get('hack'):
                from .hack_client import Client
                screen_width = 80
                screen_height = 24
            else:
                from .arcade_client import Client, WIDTH, HEIGHT, PADDING

                screen_width = PADDING[0] * 2 + WIDTH * 12
                screen_height = PADDING[1] * 3 + HEIGHT * 10

            window = Client(width=screen_width,
                            height=screen_height,
                            title=SCREEN_TITLE,
                            game=game,
                            config=config,
                            interactive=is_interactive,
                            user_pid=user_pid)
            window.run(tick_step)

        # Announce game winner and exit
        return game.stats