Example #1
def do_autofocus(path_to_nis,
                 step_coarse=None,
                 step_fine=None,
                 focus_criterion=None,
                 focus_with_piezo=False):
    try:
        ntf = NamedTemporaryFile(suffix='.mac', delete=False)
        cmd = '''
        StgZ_SetActiveZ({});
        StgFocusSetCriterion({});
        StgFocusAdaptiveTwoPasses({},{});
        Freeze();
        '''.format(
            1 if focus_with_piezo else 0,
            focus_criterion
            if focus_criterion is not None else DEFAULT_FOCUS_CRITERION,
            step_coarse
            if step_coarse is not None else DEFAULT_FOCUS_STEP_COARSE,
            step_fine if step_fine is not None else DEFAULT_FOCUS_STEP_FINE,
        )
        ntf.writelines([bytes(cmd, 'utf-8')])
        ntf.close()
        subprocess.call(' '.join([quote(path_to_nis), '-mw', quote(ntf.name)]))
    finally:
        os.remove(ntf.name)
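A minimal usage sketch, assuming the module-level DEFAULT_FOCUS_* constants from the snippet; the NIS-Elements executable path below is a hypothetical example:

# Hypothetical call: piezo autofocus using the module defaults.
do_autofocus(r'C:\Program Files\NIS-Elements\nis_ar.exe',
             focus_with_piezo=True)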
Example #2
def api_call(endpoint,
             headers=None,
             json_data=None,
             method=requests.get,
             api_version='v1',
             limit=1000,
             offset=0,
             org_id=None,
             verbose=False):
    # Copy headers to avoid the mutable-default-argument pitfall:
    # the org_id branch below mutates the dict in place.
    headers = dict(headers) if headers else {}
    json_data = json_data or {}
    endpoint = "{}/{}".format(api_version, endpoint)
    # print("Endpoint:", endpoint)
    # print("Data:", json_data)

    temp_pem = NamedTemporaryFile(suffix='.pem')
    temp_key = NamedTemporaryFile(suffix='.key')

    pem_env_var = config('SEARCH-ADS-PEM')
    key_env_var = config('SEARCH-ADS-KEY')

    try:
        if pem_env_var.endswith('.pem'):  # env var is the path to a .pem file
            call_kwargs = {
                "cert": (pem_env_var, key_env_var),
                "headers": headers,
            }
        else:  # env var holds the key material itself
            pem_lines = pem_env_var.split("\\n")
            temp_pem.writelines([
                "{}\n".format(item).encode('utf-8')
                for item in pem_lines
            ])
            temp_pem.flush()  # ensure all data written

            key_lines = key_env_var.split("\\n")
            temp_key.writelines([
                "{}\n".format(item).encode('utf-8')
                for item in key_lines
            ])
            temp_key.flush()  # ensure all data written

            call_kwargs = {
                "cert": (temp_pem.name, temp_key.name),
                "headers": headers,
            }
        if json_data:
            call_kwargs['json'] = json_data
        if org_id:
            call_kwargs['headers']["Authorization"] = "orgId={org_id}".format(
                org_id=org_id)
        req = method(
            "https://api.searchads.apple.com/api/{endpoint}".format(
                endpoint=endpoint), **call_kwargs)
    finally:
        # Automatically cleans up the file
        temp_pem.close()
        temp_key.close()

    if verbose:
        print(req.text)
    return req.json()
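A usage sketch under the snippet's assumptions (requests imported, SEARCH-ADS-PEM and SEARCH-ADS-KEY configured); the endpoint and org id are hypothetical:

# Hypothetical call: GET /v1/campaigns for one organization.
campaigns = api_call('campaigns',
                     method=requests.get,
                     org_id=123456,
                     verbose=True)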
Example #3
def file_name():
    file = NamedTemporaryFile(mode="w")
    data = ["name=NAME\n", "new_1=NEW_1\n", "int_value=10\n"]
    file.writelines(data)
    file.seek(0)
    yield file.name
    file.close()
    def _certificate_helper(self, private_key=None):
        """
        Returns the path to a helper script which can be used in the GIT_SSH env
        var to use a custom private key file.
        """
        opts = {
            'StrictHostKeyChecking': 'no',
            'PasswordAuthentication': 'no',
            'KbdInteractiveAuthentication': 'no',
            'ChallengeResponseAuthentication': 'no',
        }

        # Create identity file
        identity = NamedTemporaryFile(delete=False)
        ecm.chmod(identity.name, 0o600)
        identity.writelines([private_key])
        identity.close()

        # Create helper script
        helper = NamedTemporaryFile(delete=False)
        helper.writelines([
            '#!/bin/sh\n',
            'exec ssh ' + 
            ' '.join('-o%s=%s' % (key, value) for key, value in opts.items()) + 
            ' -i ' + identity.name + 
            ' $*\n'
        ])

        helper.close()
    ecm.chmod(helper.name, 0o750)

        return helper.name, identity.name
Example #5
    def test_two_files_are_different_as_string_is_true___result_is_concatenated_list_of_differences(
            self):
        first = NamedTemporaryFile(mode='w', delete=False)
        second = NamedTemporaryFile(mode='w', delete=False)
        try:
            first.writelines(
                ['HEADING\n', 'first\n', 'same\n', 'second\n', 'FOOTER\n'])
            first.close()

            second.writelines(
                ['HEADING\n', 'third\n', 'same\n', 'fourth\n', 'FOOTER\n'])
            second.close()

            diff = unified_diff(first.name, second.name, as_string=True)

            self.assertEqual(
                diff, ''.join([
                    '--- {}\n'.format(first.name),
                    '+++ {}\n'.format(second.name),
                    '@@ -1,5 +1,5 @@\n',
                    ' HEADING\n',
                    '-first\n',
                    '+third\n',
                    ' same\n',
                    '-second\n',
                    '+fourth\n',
                    ' FOOTER\n',
                ]))
        finally:
            os.remove(first.name)
            os.remove(second.name)
Example #6
def get_ap(inds,
           dists,
           query_name,
           index_names,
           groundtruth_dir,
           ranked_dir=None):
    if ranked_dir is not None:
        # Create dir for ranked results if needed
        if not os.path.exists(ranked_dir):
            os.makedirs(ranked_dir)
        rank_file = os.path.join(ranked_dir, '%s.txt' % query_name)
        f = open(rank_file, 'w')
    else:
        f = NamedTemporaryFile(delete=False)
        rank_file = f.name

    f.writelines([index_names[i] + '\n' for i in inds])
    f.close()

    groundtruth_prefix = os.path.join(groundtruth_dir, query_name)
    cmd = './compute_ap %s %s' % (groundtruth_prefix, rank_file)
    ap = os.popen(cmd).read()

    # Delete temp file
    if ranked_dir is None:
        os.remove(rank_file)

    return float(ap.strip())
class TestTrio(object):
    """Test class for testing how the individual class behave"""
    
    def setup_class(self):
        """Setup a standard trio."""
        trio_lines = ['#Standard trio\n', 
                    '#FamilyID\tSampleID\tFather\tMother\tSex\tPhenotype\n', 
                    'healthyParentsAffectedSon\tproband\tfather\tmother\t1\t2\n',
                    'healthyParentsAffectedSon\tmother\t0\t0\t2\t1\n', 
                    'healthyParentsAffectedSon\tfather\t0\t0\t1\t1\n'
                    ]
        self.trio_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
        self.trio_file.writelines(trio_lines)
        self.trio_file.seek(0)
        self.trio_file.close()
        
    
    def test_standard_trio(self):
        """Test if the file is parsed in a correct way."""
        family_parser = parser.FamilyParser(open(self.trio_file.name, 'r'))
        assert family_parser.header == [
                                    'family_id', 
                                    'sample_id', 
                                    'father_id', 
                                    'mother_id', 
                                    'sex', 
                                    'phenotype'
                                    ]
        assert 'healthyParentsAffectedSon' in family_parser.families
        assert set(['proband', 'mother', 'father']) == set(family_parser.families['healthyParentsAffectedSon'].individuals.keys())
        assert set(['proband', 'mother', 'father']) == set(family_parser.families['healthyParentsAffectedSon'].trios[0])
Example #8
def set_position(path_to_nis,
                 pos_xy=None,
                 pos_z=None,
                 pos_piezo=None,
                 relative_xy=False,
                 relative_z=False,
                 relative_piezo=False):

    # nothing to do
    if pos_xy is None and pos_z is None and pos_piezo is None:
        return

    cmd = []
    if pos_xy is not None:
        cmd.append('StgMoveXY({},{},{});'.format(pos_xy[0], pos_xy[1],
                                                 1 if relative_xy else 0))
    if pos_z is not None:
        cmd.append('StgMoveMainZ({},{});'.format(pos_z,
                                                 1 if relative_z else 0))
    if pos_piezo is not None:
        cmd.append('StgMovePiezoZ({},{});'.format(pos_piezo,
                                                  1 if relative_piezo else 0))

    try:
        ntf = NamedTemporaryFile(suffix='.mac', delete=False)
        ntf.writelines([bytes('\n'.join(cmd), 'utf-8')])
        ntf.close()
        subprocess.call(' '.join([quote(path_to_nis), '-mw', quote(ntf.name)]))
    finally:
        os.remove(ntf.name)
Example #9
def load_button_pixbufs(color):
    global BUTTONS_SVG

    if BUTTONS_SVG is None:
        image_path = os.path.join(MODULE_DIR, 'images', 'mouse.svg')
        with open(image_path) as svg_file:
            BUTTONS_SVG = svg_file.readlines()

    if not isinstance(color, str):
        # Gdk.Color
        color = 'rgb({}, {}, {})'.format(round(color.red_float * 255),
                                         round(color.green_float * 255),
                                         round(color.blue_float * 255))
    button_pixbufs = []
    svg = NamedTemporaryFile(mode='w', suffix='.svg')
    for line in BUTTONS_SVG[1:-1]:
        svg.seek(0)
        svg.truncate()
        svg.writelines((
            BUTTONS_SVG[0],
            line.replace('#fff', color),
            BUTTONS_SVG[-1],
        ))
        svg.flush()
        os.fsync(svg.fileno())
        button_pixbufs.append(GdkPixbuf.Pixbuf.new_from_file(svg.name))
    svg.close()
    return button_pixbufs
Example #10
def test_pattern_patching():
    of = NamedTemporaryFile('wt')
    of.writelines([
        'one line\n', 'this pattern will be patched: defbbahij\n',
        'third line\n', 'another pattern: jihaabfed'
    ])
    of.flush()

    files = FileList()
    f = FileInfo(files, of.name)
    f.load()
    f.scan_for_matches()
    matches = f.matches_of_type(BasicPattern)
    assert len(matches) == 2
    p2 = matches[1]

    # manually add patch, to see if .append() works:
    f.patches.append(p2.append('XXX'))

    # apply all patches:
    f.gen_patches()
    patched = f.get_patched_content()
    assert patched == ('one line\n' +
                       'this pattern will be patched: defBBBBBhij\n' +
                       'third line\n' + 'another pattern: jihAAAAAXXXfed')
Example #11
    async def run_this(self, opsdroid, config, message):
        languageCode = message.regex.group('lang')
        language = next((lang for lang in config["containers"]
                         if languageCode in lang["language"]), None)

        if language is None:
            await message.respond(f"Sorry, I don't know {languageCode}")
            return

        container = language["container"]
        code = message.regex.group('code')

        if code is None or len(code) == 0:
            await message.respond("Wait, run what?")
            return

        await message.respond(
            f"<p>Let me try that in <code>{container}</code></p>")

        # This requires you to have a mounted volume
        codefile = NamedTemporaryFile(mode='w+t',
                                      suffix=language["extension"],
                                      dir=config["workdir"],
                                      delete=False)
        if (language["extension"] == '.ps1'):
            codefile.writelines("$ProgressPreference='SilentlyContinue'\n")
        codefile.writelines(code)
        codefile.close()
        volume, workdir = config["volume"].split(":")
        filename = "{}/{}".format(workdir, os.path.split(codefile.name)[1])
        await self.invoke_docker(message.respond, language["container"],
                                 config["volume"], language["command"],
                                 filename)
Example #12
class TestTrio(object):
    """Test class for testing how the individual class behave"""
    def setup_class(self):
        """Setup a standard trio."""
        trio_lines = [
            '#Standard trio\n',
            '#FamilyID\tSampleID\tFather\tMother\tSex\tPhenotype\n',
            'healthyParentsAffectedSon\tproband\tfather\tmother\t1\t2\n',
            'healthyParentsAffectedSon\tmother\t0\t0\t2\t1\n',
            'healthyParentsAffectedSon\tfather\t0\t0\t1\t1\n',
            'healthyParentsAffectedSon\tdaughter\tfather\tmother\t2\t1\n',
        ]
        self.trio_file = NamedTemporaryFile(mode='w+t',
                                            delete=False,
                                            suffix='.vcf')
        self.trio_file.writelines(trio_lines)
        self.trio_file.seek(0)
        self.trio_file.close()

    def test_standard_trio_extra_daughter(self):
        """Test if the file is parsed in a correct way."""
        family_parser = FamilyParser(open(self.trio_file.name, 'r'))
        trio_family = family_parser.families['healthyParentsAffectedSon']

        assert family_parser.header == [
            'family_id', 'sample_id', 'father_id', 'mother_id', 'sex',
            'phenotype'
        ]
        assert set(['proband', 'mother', 'father', 'daughter']) == set(
            family_parser.families['healthyParentsAffectedSon'].individuals.
            keys())
        assert set(['proband', 'mother', 'father']) in trio_family.trios
        assert set(['daughter', 'mother', 'father']) in trio_family.trios
        assert 'daughter' in trio_family.individuals['proband'].siblings
    def test_build_from_texts(self):
        sample = 'chi ka nu chi nu kachika kanu unka'
        file_even_dist = NamedTemporaryFile(mode='w', delete=False)
        file_even_dist.writelines(sample*200)
        file_even_dist.close()

        sample = 'uka'
        file_odd_dist = NamedTemporaryFile(mode='w', delete=False)
        file_odd_dist.writelines(sample * 600)
        file_odd_dist.close()

        bld = LangVectorDistributionBuilder()
        distr = bld.build_files_reference_distribution([file_even_dist.name, file_odd_dist.name])
        os.unlink(file_even_dist.name)
        os.unlink(file_odd_dist.name)

        self.assertIsNotNone(distr)

        # test texts
        calc = CosineSimilarityOcrRatingCalculator()
        calc.distribution_by_lang['chi'] = distr

        sample_text = 'Chika ka nuunchi chi ka' * 5000
        grade = calc.get_rating(sample_text, 'chi')
        self.assertGreater(grade, 5)
Example #14
    def is_installed(cls, version):
        from tempfile import NamedTemporaryFile

        ret = True
        try:
            temp = NamedTemporaryFile(suffix='.g')
            temp.writelines([b'echo "version: "{version} \n', b'quit \n'])
            temp.seek(0)
            out = sp.check_output(
                ['genesis', '-nox', '-batch', '-notty', temp.name])
            m = re.search(rb'version:\s*([0-9]*\.?[0-9]+)\s*', out)
            if m:
                ver = m.groups()[0]

                if isinstance(ver, bytes):
                    ver = ver.decode('utf-8')
                ret = 'v%s' % ver
                inform("Found GENESIS in path, version %s" % ret,
                       verbosity=1,
                       indent=2)
        except OSError:
            ret = False
        finally:
            temp.close()

        return ret
Example #15
    def cargo_toml_context():
        tmp_file = NamedTemporaryFile(buffering=0)  # unbuffered, binary by default
        with open("Cargo.toml", "rb") as f:
            tmp_file.writelines(f.readlines())

        cargo_file = toml.load("Cargo.toml")

        cargo_file.setdefault("patch",
                              {}).setdefault("crates-io", {})["jsonschema"] = {
                                  "path":
                                  os.environ["UNRELEASED_JSONSCHEMA_PATH"],
                              }

        with open("Cargo.toml", "w") as f:
            toml.dump(cargo_file, f)

        try:
            print(
                "Modified Cargo.toml file by patching jsonschema dependency to {}"
                .format(os.environ["UNRELEASED_JSONSCHEMA_PATH"]),
                file=sys.stderr,
            )
            yield
        except:
            print("Cargo.toml used during the build", file=sys.stderr)
            with open("Cargo.toml", "r") as f:
                print(f.read(), file=sys.stderr)

            raise
        finally:
            with open("Cargo.toml", "wb") as f:
                tmp_file.seek(0)
                f.writelines(tmp_file.readlines())
Example #17
class JobOptionsCmd(JobOptions):
    def __init__(self, cmds=[]):

        # massaging of input variables
        if isinstance(cmds, str):
            cmds = [cmds]
            pass
        if not isinstance(cmds, list):
            cmds = [cmds]
            pass

        JobOptions.__init__(self, fileName=None)
        self.cmds = cmds
        self.tmpFile = NamedTemporaryFile(suffix=".py", mode='w+')
        return

    def name(self):

        # wipeout any previous content of the file
        self.tmpFile.file.truncate(0)

        # go at the beginning of the file to write our commands
        self.tmpFile.file.seek(0)

        for l in self.cmds:
            self.tmpFile.writelines(l + os.linesep)
            pass

        # commit our changes
        self.tmpFile.file.flush()

        return self.tmpFile.name

    pass  # JobOptionsCmd
Example #18
def write_xar(fn, hdr, tocdata, heap, keep_old=False):
    ztocdata = zlib.compress(tocdata)
    digest = toc_digest(hdr, ztocdata)
    newhdr = dict(hdr,
                  toc_length_uncompressed=len(tocdata),
                  toc_length_compressed=len(ztocdata))
    outf = NamedTemporaryFile(prefix='.' + os.path.basename(fn),
                              dir=os.path.dirname(fn),
                              delete=False)
    try:
        st_mode = os.stat(fn).st_mode
        if os.fstat(outf.fileno()).st_mode != st_mode:
            os.fchmod(outf.fileno(), st_mode)
    except OSError:
        pass
    try:
        outf.writelines([HEADER.pack(newhdr), ztocdata, digest])
        copyfileobj(heap, outf)
        outf.close()
    except:
        outf.close()
        os.unlink(outf.name)
        raise
    if keep_old:
        oldfn = fn + '.old'
        if os.path.exists(oldfn):
            os.unlink(oldfn)
        os.link(fn, oldfn)
    os.rename(outf.name, fn)
Example #19
def do_large_image_scan(path_to_nis,
                        save_path,
                        left,
                        right,
                        top,
                        bottom,
                        overlap=0,
                        registration=False,
                        z_count=1,
                        z_step=1.5,
                        close=True):
    try:
        ntf = NamedTemporaryFile(suffix='.mac', delete=False)
        cmd = '''
        Stg_SetLargeImageStageZParams({}, {}, {});
        Stg_LargeImageScanArea({},{},{},{},0,{},0,{},0,"{}");
        {}
        '''.format(0 if (z_count <= 1) else 1, z_step, z_count, left, right,
                   top, bottom, overlap, 1 if registration else 0, save_path,
                   'CloseCurrentDocument(2);' if close else '')

        ntf.writelines([bytes(cmd, 'utf-8')])

        ntf.close()
        subprocess.call(' '.join([quote(path_to_nis), '-mw', quote(ntf.name)]))
    finally:
        os.remove(ntf.name)
Example #20
def as_script(func):
    # create source code
    code_str = inspect.getsource(func)
    src_lines = code_str.split('\n')

    # remove function signature
    i = 0
    while f"def {func.__name__}(" not in src_lines[i]:
        i += 1
    src_lines = src_lines[i + 1:]

    src_lines = [line + '\n' for line in src_lines]

    # add exception relay machinery
    src_lines.insert(0, "import sys\n")
    src_lines.insert(1, "import traceback\n")
    src_lines.insert(2, "try:\n")
    src_lines.append("\nexcept Exception as e:\n"
                     "    sys.stderr.write(traceback.format_exc())\n")

    # write file
    source_file = NamedTemporaryFile(prefix=f"{func.__name__}-",
                                     suffix=".py",
                                     mode="w+",
                                     delete=False)
    source_file.writelines(src_lines)
    source_file.close()
    source_path = source_file.name
    func.source = ''.join(src_lines)

    def direct_call(*args, **kwargs):
        return func(*args, **kwargs)

    direct_call.path = source_path

    return direct_call
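A brief usage sketch for as_script; the decorated function is illustrative, and running the generated file out of process is one plausible use:

import subprocess
import sys

@as_script
def greet():
    print("hello from a temporary script")

# greet.path points at the generated .py file; execute it in a fresh interpreter.
subprocess.run([sys.executable, greet.path], check=True)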
def setup_vcf_file():
    """
    Print some variants to a vcf file and return the filename
    """
    vcf_lines = [
        '##fileformat=VCFv4.1\n',
        '##INFO=<ID=MQ,Number=1,Type=Float,Description="RMS Mapping Quality">\n',
        '##contig=<ID=1,length=249250621,assembly=b37>\n',
        '##reference=file:///humgen/gsa-hpprojects/GATK/bundle'\
        '/current/b37/human_g1k_v37.fasta\n',
        '#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tfather\tmother\tproband\n',
        '1\t11900\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/1:60\t0/1:60\t1/1:60\n',
        '1\t879585\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/1:60\t0/0:60\t0/1:60\n',
        '1\t879586\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/1:60\t0/1:60\n',
        '1\t947378\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/0:60\t0/1:60\n',
        '1\t973348\t.\tG\tA\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/0:60\t0/1:60\n',
        '3\t879585\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/1:60\t0/0:60\t0/1:60\n',
        '3\t879586\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/1:60\t0/1:60\n',
        '3\t947378\t.\tA\tT\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/0:60\t0/1:60\n',
        '3\t973348\t.\tG\tA\t100\tPASS\tMQ=1\tGT:GQ\t0/0:60\t0/0:60\t0/1:60\n'
        ]
    vcf_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
    vcf_file.writelines(vcf_lines)
    vcf_file.seek(0)
    vcf_file.close()
    
    return vcf_file.name
Example #22
    def track(self,
              img_files,
              box,
              show_progress=False,
              output_video=None,
              output_video_fps=30,
              visualize_subsample=1,
              visualize_threshold=0.1,
              return_masks=False,
              **tracker_args):
        x0, y0, x1, y1 = box
        w = x1 - x0
        h = y1 - y0
        region = [x0, y0, w, h]
        images_list = NamedTemporaryFile('w')
        images_list.writelines([f'{x}\n' for x in img_files])
        images_list.seek(0)
        # print(images_list.name)
        # print('hi')
        # subprocess.run(['cat', images_list.name], stderr=subprocess.STDOUT)
        # print('hi')
        # print([f'{x}\n' for x in img_files][:5])
        # print('hi')
        # print(img_files)
        # print('hi')

        output = NamedTemporaryFile('w', suffix='.mat')
        command = [
            'matlab', '-r',
            f"run_SRDCF_TAO('{images_list.name}', {region}, '{output.name}'); "
            f"quit"
        ]
        # Conda is clashing with MATLAB here, causing an error in C++ ABIs.
        # Unsetting LD_LIBRARY_PATH fixes this.
        env = os.environ.copy()
        env['LD_LIBRARY_PATH'] = ''
        try:
            subprocess.check_output(command,
                                    stderr=subprocess.STDOUT,
                                    cwd=str(SRDCF_ROOT),
                                    env=env)
        except subprocess.CalledProcessError as e:
            logging.fatal('Failed command.\nException: %s\nOutput %s',
                          e.returncode, e.output.decode('utf-8'))
            raise

        result = loadmat(output.name)['results'].squeeze()
        images_list.close()
        output.close()

        boxes = result['res'].item()
        # width, height -> x1, y1
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]
        # scores = result['scores'].item()
        scores = np.ones((boxes.shape[0], 1))
        scores[0] = float('inf')
        boxes = np.hstack((boxes, scores))
        return boxes, None, None
class Test3dBuild(TestCase):
    """
    Tests for the build method with a 3d mesh.
    """
    def setUp(self):
        """
        Create a function builder and input file.
        """
        self.mesh = UnitCubeMesh(10, 10, 10)
        self.V = FunctionSpace(self.mesh, 'CG', 1)
        self.fb = FileBuilder(self.mesh, self.V)

        self.input = NamedTemporaryFile(mode='w+')
        self.fb.assign('path', self.input.name)

    def test_3d_constant(self):
        """
        Test interpolation is correct for a 3d constant valued file.
        """
        lines = ['#   x,   y,   z,    v\n']
        lines += [
            f'{x}, {y}, {z}, 10.0\n' for x in np.linspace(0.0, 1.0, 11)
            for y in np.linspace(0.0, 1.0, 11)
            for z in np.linspace(0.0, 1.0, 11)
        ]

        self.input.writelines(lines)
        self.input.flush()

        f = self.fb.build()

        self.assertAlmostEqual(f((0.0, 0.0, 0.0)), 10.0)
        self.assertAlmostEqual(f((0.0, 0.15, 0.94)), 10.0)
        self.assertAlmostEqual(f((0.38, 0.62, 0.01)), 10.0)
        self.assertAlmostEqual(f((0.99, 0.99, 0.54)), 10.0)

    def test_3d_linear(self):
        """
        Test interpolation is correct for a 3d linear valued file.
        """
        lines = ['#   x,   y,   z,    v\n']
        lines += [
            f'{x}, {y}, {z}, {10*(x+y+z)}\n'
            for x in np.linspace(0.0, 1.0, 11)
            for y in np.linspace(0.0, 1.0, 11)
            for z in np.linspace(0.0, 1.0, 11)
        ]

        self.input.writelines(lines)
        self.input.flush()

        f = self.fb.build()

        self.assertAlmostEqual(f((0.0, 0.0, 0.0)), 0.0)
        self.assertAlmostEqual(f((0.0, 0.15, 0.94)), 10.9)
        self.assertAlmostEqual(f((0.38, 0.62, 0.01)), 10.1)
        self.assertAlmostEqual(f((0.99, 0.99, 0.54)), 25.2)
Example #24
def set_optical_configuration(path_to_nis, oc_name):
    try:
        ntf = NamedTemporaryFile(suffix='.mac', delete=False)
        cmd = 'SelectOptConf("{0}");'.format(oc_name)
        ntf.writelines([bytes(cmd, 'utf-8')])

        ntf.close()
        subprocess.call(' '.join([quote(path_to_nis), '-mw', quote(ntf.name)]))
    finally:
        os.remove(ntf.name)
Example #25
def create_template(contents):
    """
    Generate a temporary file with the specified contents as a list of strings
    and yield its path as the context.
    """
    template = NamedTemporaryFile(mode="w", prefix="certtool-template")
    template.writelines(map(lambda l: l + "\n", contents))
    template.flush()
    yield template.name
    template.close()
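A usage sketch, assuming create_template is wrapped with contextlib.contextmanager in its original module (the docstring and yield imply it); the template contents are illustrative:

# Hypothetical certtool template: the file exists inside the with-block.
with create_template(["organization = Example", "ca"]) as template_path:
    print(open(template_path).read())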
Example #26
def daemon(acl, addr='localhost'):
    """ Create an Rbldnsd instance with given ACL
    """
    acl_zone = NamedTemporaryFile()
    acl_zone.writelines("%s\n" % line for line in acl)
    acl_zone.flush()

    dnsd = Rbldnsd(daemon_addr=addr)
    dnsd.add_dataset('acl', acl_zone)
    dnsd.add_dataset('generic', ZoneFile(['test TXT "Success"']))
    return dnsd
Example #28
def backup_optical_configurations(path_to_nis, backup_path):
    """
    export all optical configurations as XML
    """
    try:
        ntf = NamedTemporaryFile(suffix='.mac', delete=False)
        cmd = '''BackupOptConf("{}");'''.format(backup_path)
        ntf.writelines([bytes(cmd, 'utf-8')])
        ntf.close()
        subprocess.call(' '.join([quote(path_to_nis), '-mw', quote(ntf.name)]))
    finally:
        os.remove(ntf.name)
Example #29
def concat_videos(videos: List[Path], connect_dir: Path) -> Path:
    temp = NamedTemporaryFile("w")
    temp.writelines([f"file '{str(p)}'\n" for p in videos])
    temp.flush()

    concat_output = connect_dir / "all_videos.mkv"
    if not concat_output.exists():
        subprocess.run(
            ["ffmpeg", "-f", "concat", "-safe", "0", "-i", temp.name, "-c", "copy", str(concat_output.absolute())]
        )

    return concat_output
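A quick usage sketch for concat_videos; the directory names are hypothetical:

from pathlib import Path

videos = sorted(Path("clips").glob("*.mkv"))     # hypothetical input clips
merged = concat_videos(videos, Path("connect"))  # writes connect/all_videos.mkv
print(merged)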
def modify_props(device, local_prop_file, target_prop_file, new_props):
    """To change the props if need
    Args:
        device: the device to modify props
        local_prop_file : the local file to save the old props
        target_prop_file : the target prop file to change
        new_props  : the new props
    Returns:
        True : prop file changed
        False : prop file no need to change
    """
    is_changed = False
    device.pull_file(target_prop_file, local_prop_file)
    old_props = {}
    changed_prop_key = []
    lines = []
    with open(local_prop_file, 'r') as old_file:
        lines = old_file.readlines()
        if lines:
            lines[-1] = lines[-1] + '\n'
        for line in lines:
            line = line.strip()
            if not line.startswith("#") and line.find("=") > 0:
                key_value = line.split("=")
                if len(key_value) == 2:
                    old_props[line.split("=")[0]] = line.split("=")[1]

    for key, value in new_props.items():
        if key not in old_props.keys():
            lines.append("".join([key, "=", value, '\n']))
            is_changed = True
        elif old_props.get(key) != value:
            changed_prop_key.append(key)
            is_changed = True

    if is_changed:
        local_temp_prop_file = NamedTemporaryFile(mode='w',
                                                  prefix='build',
                                                  suffix='.tmp',
                                                  delete=False)
        for index, line in enumerate(lines):
            if not line.startswith("#") and line.find("=") > 0:
                key = line.split("=")[0]
                if key in changed_prop_key:
                    lines[index] = "".join([key, "=", new_props[key], '\n'])
        local_temp_prop_file.writelines(lines)
        local_temp_prop_file.close()
        device.push_file(local_temp_prop_file.name, target_prop_file)
        device.execute_shell_command(" ".join(["chmod 644", target_prop_file]))
        LOG.info("Changed the system property as required successfully")
        os.remove(local_temp_prop_file.name)

    return is_changed
Example #31
def fixture_temp_config_file():
    temp_file = NamedTemporaryFile(mode="w")
    temp_file.writelines([
        "version: 1\n",
        "changelog:\n",
        "  default_title: 'Uncategorized'\n",
        "  folder: changelogs\n",
        "  start: master\n",
        "  end: HEAD\n",
    ])
    temp_file.seek(0)
    return temp_file
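A sketch of a test consuming this fixture, assuming it is registered with pytest under this name; the assertion is illustrative:

def test_config_contains_version(fixture_temp_config_file):
    with open(fixture_temp_config_file.name) as fh:
        assert "version: 1" in fh.read()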
Example #32
def get_ap(inds,
           query_name,
           index_names,
           groundtruth_dir,
           ranked_dir=None,
           disp_each=True):
    """
    Given a query, index data, and path to groundtruth data, perform the query,
    and evaluate average precision for the results by calling to the compute_ap
    script. Optionally save ranked results in a file.

    :param ndarray inds:
        the indices of index vectors in ascending order of distance
    :param str query_name:
        the name of the query
    :param list index_names:
        the name of index items
    :param str groundtruth_dir:
        directory of groundtruth files
    :param str ranked_dir:
        optional path to a directory to save ranked list for query

    :returns float:
        the average precision for this query
    """

    if ranked_dir is not None:
        # Create dir for ranked results if needed
        if not os.path.exists(ranked_dir):
            os.makedirs(ranked_dir)
        rank_file = os.path.join(ranked_dir, '%s.txt' % query_name)
        f = open(rank_file, 'w')
    else:
        f = NamedTemporaryFile(delete=False)
        rank_file = f.name

    f.writelines([index_names[i] + '\n' for i in inds])
    f.close()

    groundtruth_prefix = os.path.join(groundtruth_dir, query_name)
    oxford_benchmark_prog = os.path.join('cpp', 'compute_ap')
    cmd = ' %s %s' % (groundtruth_prefix, rank_file)
    cmd = oxford_benchmark_prog + cmd
    ap = os.popen(cmd).read()

    # Delete temp file
    if ranked_dir is None:
        os.remove(rank_file)

    if disp_each:
        print(Notify.UNDERLINE, query_name, float(ap.strip()), Notify.ENDC)
    return float(ap.strip())
Example #33
    def run(self, path_to_nis):
        try:
            ntf = NamedTemporaryFile(suffix='.mac', delete=False)

            # run cmd
            ntf.writelines([bytes('ND_RunExperiment(0);', 'utf-8')])

            ntf.close()
            subprocess.call(' '.join(
                [quote(path_to_nis), '-mw',
                 quote(ntf.name)]))
        finally:
            os.remove(ntf.name)
Example #34
def layout_graph(filename):
    out = NamedTemporaryFile(mode="w", delete=False)
    out.writelines(_convert_dot_file(filename))
    out.close()  # flushes the cache
    cmd = []
    cmd.append("dot")
    cmd.append("-Kfdp")
    cmd.append("-Tpdf")
    cmd.append("-Gcharset=utf-8")
    cmd.append("-o{0}.pdf".format(os.path.splitext(filename)[0]))
    cmd.append(out.name)
    execute_command(cmd)
    # Manually remove the temporary file
    os.unlink(out.name)
def TemporaryGenomeFile(genome_name):
    """
    returns a file-like object pointing to a temporary file containing
    the chromosome names and sizes

    the current file position will be 0

    it will be deleted when the object is garbage collected
    """
    f = NamedTemporaryFile()
    genome_rows = genome(genome_name).iteritems()
    f.writelines('%s\t%d\n' % (chrom, size) for chrom, size in genome_rows)
    f.flush()
    f.seek(0)
    return f
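A usage sketch; the genome name is a hypothetical argument to the snippet's genome() helper:

f = TemporaryGenomeFile("hg19")  # hypothetical genome name
print(f.readline())              # first "chrom<TAB>size" row (bytes in binary mode)
f.close()                        # also cleaned up on garbage collection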
Example #37
def get_vcf_file(vcf_lines):
    """
    Take an iterator with vcf lines and prints them to a temporary file.
    
    Arguments:
        vcf_lines (iterator): An iterator with vcf lines
    
    Returns:
        filename (str): The path to the vcf file
    """
    vcf_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
    vcf_file.writelines(vcf_lines)
    vcf_file.seek(0)
    vcf_file.close()
    
    return vcf_file.name
Example #38
def get_ap(Q, data, query_name, index_names, groundtruth_dir, ranked_dir=None):
    """
    Given a query, index data, and path to groundtruth data, perform the query,
    and evaluate average precision for the results by calling to the compute_ap
    script. Optionally save ranked results in a file.

    :param ndarray Q:
        query vector
    :param ndarray data:
        index data vectors
    :param str query_name:
        the name of the query
    :param list index_names:
        the name of index items
    :param str groundtruth_dir:
        directory of groundtruth files
    :param str ranked_dir:
        optional path to a directory to save ranked list for query

    :returns float:
        the average precision for this query
    """
    inds, dists = get_nn(Q, data)

    if ranked_dir is not None:
        # Create dir for ranked results if needed
        if not os.path.exists(ranked_dir):
            os.makedirs(ranked_dir)
        rank_file = os.path.join(ranked_dir, '%s.txt' % query_name)
        f = open(rank_file, 'w')
    else:
        f = NamedTemporaryFile(delete=False)
        rank_file = f.name

    f.writelines([index_names[i] + '\n' for i in inds])
    f.close()

    groundtruth_prefix = os.path.join(groundtruth_dir, query_name)
    cmd = './compute_ap %s %s' % (groundtruth_prefix, rank_file)
    ap = os.popen(cmd).read()

    # Delete temp file
    if ranked_dir is None:
        os.remove(rank_file)

    return float(ap.strip())
    def get_earth_model(self, model):
        """
        Check whether the specified Earth density profile has a correct
        NuCraft preface. If not, create a temporary file that does.
        """
        logging.debug('Trying to construct Earth model from "%s"' % model)
        try:
            resource_path = find_resource(model)
            self.earth_model = EarthModel(resource_path)
            logging.info("Loaded Earth model from %s" % model)
        except SyntaxError:
            # Probably the file is lacking the correct preamble
            logging.info(
                "Failed to construct NuCraft Earth model directly from"
                " %s! Adding default preamble..." % resource_path
            )
            # Generate tempfile with preamble
            with open(resource_path, "r") as infile:
                profile_lines = infile.readlines()
            preamble = [
                "# nuCraft Earth model with PREM density "
                "values for use as template; keep structure "
                "of the first six lines unmodified!\n",
                "(0.4656,0.4656,0.4957)   # tuple of (relative) "
                # '(0.5, 0.5, 0.5)   # tuple of (relative) '
                "electron numbers for mantle, outer core, and inner core\n",
                "6371.    # radius of the Earth\n",
                "3480.    # radius of the outer core\n",
                "1121.5   # radius of the inner core\n",
                "# two-columned list of radii and corresponding "
                "matter density values in km and kg/dm^3; "
                "add, remove or modify lines as necessary\n",
            ]
            tfile = NamedTemporaryFile()
            tfile.writelines(preamble + profile_lines)
            tfile.flush()
            try:
                self.earth_model = EarthModel(tfile.name)
            except:
                logging.error("Could not construct Earth model from %s: %s"
                              % (model, sys.exc_info()[1]))
                sys.exit(1)
            logging.info("Successfully constructed Earth model")
            tfile.close()
        except IOError:
            logging.info('Using NuCraft built-in Earth model "%s"' % model)
            self.earth_model = EarthModel(model)
    def __setup(self, inputlines):
        """
        Create a temporary input file and a temporary output-file.
        """
        # create output file fist, so it is older
        outputfile = NamedTemporaryFile("w", suffix='.json', delete=False)
        outputfile.write('--- empty marker ---')
        outputfile.close()
        self.output_filename = outputfile.name

        time.sleep(1) # ensure a time-difference between files

        inputfile = NamedTemporaryFile("w", suffix='.txt', delete=False)
        for line in inputlines:
            inputfile.writelines((line, '\n'))
        inputfile.close()
        self.input_filename = inputfile.name
Example #41
def main():
    argparser = argparse.ArgumentParser(description="Call denovo variants on a VCF file containing a trio")

    argparser.add_argument('--denovogear', 
            type=str, nargs=1,
            required=True,
            help='Path to the denovogear binary for example: /Users/timothyh/home/binHTS/denovogear 0.5.4/bin/denovogear'
        )
        
    argparser.add_argument('vcf_file',
        type=str, nargs=1,
        help='A variant file in VCF format containing the genotypes of the trio'
    )
    
    argparser.add_argument('ped_file',
        type=str, nargs=1,
        help='A pedigree file in .ped format containing the samples to be extracted from the VCF file'
    )
    
    
    print("Hello")
    # Check that the PED file only contains a trio and identify sex of child as this is required for calling denovogear correctly

    # Check that the VCF file contains the same samples as thePED file

    # Run denovogear

    #fam_id = '1'
    #someFamily = family.Family(family_id = fam_id)
    #print(someFamily)
    trio_lines = ['#Standard trio\n', 
                    '#FamilyID\tSampleID\tFather\tMother\tSex\tPhenotype\n', 
                    'healthyParentsAffectedSon\tproband\tfather\tmother\t1\t2\n',
                    'healthyParentsAffectedSon\tmother\t0\t0\t2\t1\n', 
                    'healthyParentsAffectedSon\tfather\t0\t0\t1\t1\n',
                    'healthyParentsAffectedSon\tdaughter\tfather\tmother\t2\t1\n',
                    ]
    trio_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
    trio_file.writelines(trio_lines)
    trio_file.seek(0)
    trio_file.close()

    family_parser = parser.FamilyParser(trio_file.name)
    trio_family = family_parser.families['healthyParentsAffectedSon']
    print(trio_family)
Example #42
def opengzip(fn,cb):
    from tempfile import NamedTemporaryFile
    from gzip import GzipFile
    gzip = GzipFile(fn,'rb')
    try:
        tmp = NamedTemporaryFile('w+b')
        try:
            tmp.writelines(gzip)
            tmp.flush()
            cb(tmp.name)
        finally:
            tmp.close()
    except IOError:
        cb(fn)
    except KeyboardInterrupt:
        raise
    finally:
        gzip.close()
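A small usage sketch for opengzip; the file name is hypothetical, and the callback just reports the path it receives:

def report(path):
    print("reading from", path)

opengzip('corpus.txt.gz', report)  # falls back to the raw path on IOError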
Example #43
    def create_job(self, commands, name=None, walltime=None, restartable=False, stagein_files=None):
        if not commands:
            raise ValueError('SerialJobFactory.create_job: commands '
                             'should be a non-empty list of strings')
        if not name:
            name = self._name
        else:
            name = self._time_name(name)
        job_file = NamedTemporaryFile(prefix=name + '-',
                                      suffix='.job',
                                      delete=False)
        # name, stdin/stdout, nodes, etc...
        job_file.writelines(['#!/bin/bash\n',
                             '#PBS -N %s\n'     % name,
                             '#PBS -o %s.out\n' % name,
                             '#PBS -e %s.err\n' % name,
                             '#PBS -c enabled,shutdown,periodic\n',
                             '#PBS -r %s\n' % ('y' if restartable else 'n'),
                             '#PBS -l nodes=%s\n' % self._host])
        # optional
        if walltime:
            job_file.write('#PBS -l walltime=%s\n' % walltime)
        if self._email:
            job_file.write('#PBS -m abe -M %s\n' % self._email)
        # stageins
        if stagein_files:
            stagein_files = self._prepare_stagein(stagein_files)
            for f in stagein_files.values():
                job_file.write('#PBS -W stagein="$TMPDIR/%(basename)s@%(runner)s:%(path)s"\n'
                               % {'basename': f,
                                  'runner': self._hostname,
                                  'path': os.path.join(self._tmpdir, f)})
        job_file.write('\n')
        # write commands
        job_file.write('cd $TMPDIR\n')
        for cmd in commands:
            try:
                cmd = cmd % stagein_files
            except TypeError:
                pass
            job_file.write(cmd + '\n')
        # copy all output files back to the working directory
        job_file.write(('if [[ "%(runner)s" == $(hostname) ]];\n'
                        'then  cp -r * %(wdir)s/;\n'
                        'else scp -r * %(runner)s:%(wdir)s/;\n'
                        'fi') % {'runner': self._hostname,
                                 'wdir':   os.getcwd().replace(' ', '\\ ')})
        job_file.close()
        return job_file.name
Example #44
    def get_receiver(self, sensor):
        results = list(self.sensors.aggregate([ { "$match": { "SERVER": sensor }},{ "$project": { "ip": "$receiver", "port": "$receiver_port", "STATUS": "$STATUS" }}]))
        if len(results) == 0:
            return 'Cannot find sensor', '', ''

        elif results[0]['STATUS'] != 'APPROVED':
            return 'Sensor not approved', '', ''

        else:
            results = results[0]

        ip = results['ip']
        port = results['port']
        #ip, port = list(self.sensors.aggregate([ { "$match": { "SERVER": sensor }},{ "$project": { "ip": "$receiver", "port": "$receiver_port" }}]))
        cert = list(self.certs.aggregate([ { "$match": { "type": "receiver", "ip": ip }}, { "$project": { "cert": "$cert" }}]))[0]['cert']
        cert_file = NamedTemporaryFile(mode='w+b', suffix='.pem')
        cert_file.writelines(cert)
        cert_file.flush()
        return ip, port, cert_file
Example #45
class ZoneFile(object):
    def __init__(self, lines=None, no_header=False):
        self._file = NamedTemporaryFile()
        if not no_header:
            self._file.write(DUMMY_ZONE_HEADER)
        if lines is not None:
            self.writelines(lines)
        self._file.flush()

    @property
    def name(self):
        return self._file.name

    def write(self, data):
        self._file.write(data)
        self._file.flush()

    def writelines(self, lines):
        self._file.writelines("%s\n" % line for line in lines)
        self._file.flush()
class TestTrio(object):
    """Test class for testing how the individual class behave"""
    
    def setup_class(self):
        """Setup a standard trio with extra column in the 'proband' row."""
        trio_lines = ['#Standard trio\n', 
                    '#FamilyID\tSampleID\tFather\tMother\tSex\tPhenotype\n', 
                    'healthyParentsAffectedSon\tproband\tfather\tmother\t1\n',
                    'healthyParentsAffectedSon\tmother\t0\t0\t2\t1\n', 
                    'healthyParentsAffectedSon\tfather\t0\t0\t1\t1\n'
                    ]
        self.trio_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
        self.trio_file.writelines(trio_lines)
        self.trio_file.seek(0)
        self.trio_file.close()
        
    
    def test_standard_trio_proband_missing_column(self):
        """Test if the file is parsed in a correct way."""
        with pytest.raises(WrongLineFormat):
            family_parser = parser.FamilyParser(open(self.trio_file.name, 'r'))
Example #47
    def is_installed(cls, version):
        from tempfile import NamedTemporaryFile

        ret = True
        try:
            temp = NamedTemporaryFile(suffix='.g')
            temp.writelines(['echo "version: "{version} \n', 'quit \n'])
            temp.seek(0)
            out = sp.check_output(
                ['genesis', '-nox', '-batch', '-notty', temp.name])
            m = re.search(r'version:\s*([0-9]*\.?[0-9]+)\s*', out)
            if m:
                ret = m.groups()[0]
                inform("Found GENESIS in path, version %s" % ret,
                        verbosity=1, indent=2)
        except OSError:
            ret = False
        finally:
            temp.close()

        return ret
def convert_2_bed(txDict):
    """
    GIVEN:
    1) txDict: a dict of Tx g2g dicts
    
    DO:
    1) construct a bed entry for every Tx in list with start,stop,chrom,strand
    2) write this bed to a tmp file sorted by chrom:minCoord
    
    RETURN:
    1) the NamedTemporaryFile object for the out file (its .name is the path)
    """
    # NOTE: because gtf_to_genes uses the same coord scheme there is no need for coord
    # mangling in a simple BED representation (we don't care about exon blocks here)
    
    tmpFile = NamedTemporaryFile(mode='w+t',suffix=".bed",delete=False)
    
    bedLines = []
    
    for tx in txDict:
        chrom      = str(txDict[tx].gene.contig)
        chromStart = str(txDict[tx].beg)
        chromEnd   = str(txDict[tx].end)
        name       = str(txDict[tx].cdna_id)
        score      = str(999)
        if txDict[tx].gene.strand:
            strand = '+'
        else:
            strand = '-'
        
        bedLines.append([chrom,chromStart,chromEnd,name,score,strand])
        
    sortedLines = sorted(sorted(bedLines,key=lambda x: min(int(x[1]),int(x[2]))),key=lambda x: x[0])
    
    tmpFile.writelines(['\t'.join(l)+'\n' for l in sortedLines])
    tmpFile.close()
    
    return tmpFile
class DbConnection(object):
  '''Wraps a DB API 2 connection. Instances should only be obtained through the
     DbConnector.create_connection(...) method.

  '''

  @staticmethod
  def describe_common_tables(db_connections):
    '''Find and return a TableExprList containing Table objects that the given connections
       have in common.
    '''
    common_table_names = None
    for db_connection in db_connections:
      table_names = set(db_connection.list_table_names())
      if common_table_names is None:
        common_table_names = table_names
      else:
        common_table_names &= table_names
    common_table_names = sorted(common_table_names)

    tables = TableExprList()
    for table_name in common_table_names:
      common_table = None
      mismatch = False
      for db_connection in db_connections:
        table = db_connection.describe_table(table_name)
        if common_table is None:
          common_table = table
          continue
        if not table.cols:
          LOG.debug('%s has no remaining columns', table_name)
          mismatch = True
          break
        if len(common_table.cols) != len(table.cols):
          LOG.debug('Ignoring table %s.'
              ' It has a different number of columns across databases.', table_name)
          mismatch = True
          break
        for left, right in izip(common_table.cols, table.cols):
          if not (left.name == right.name and left.type == right.type):
            LOG.debug('Ignoring table %s. It has different columns %s vs %s.' %
                (table_name, left, right))
            mismatch = True
            break
        if mismatch:
          break
      if not mismatch:
        tables.append(common_table)

    return tables

  SQL_TYPE_PATTERN = compile(r'([^()]+)(\((\d+,? ?)*\))?')
  TYPE_NAME_ALIASES = \
      dict((type_.name().upper(), type_.name().upper()) for type_ in EXACT_TYPES)
  TYPES_BY_NAME = dict((type_.name().upper(), type_) for type_ in EXACT_TYPES)
  EXACT_TYPES_TO_SQL = dict((type_, type_.name().upper()) for type_ in EXACT_TYPES)

  def __init__(self, connector, connection, db_name=None):
    self.connector = connector
    self.connection = connection
    self.db_name = db_name

    self._bulk_load_table = None
    self._bulk_load_data_file = None   # If not set by the user, a temp file will be used
    self._bulk_load_col_delimiter = b'\x01'
    self._bulk_load_row_delimiter = '\n'
    self._bulk_load_null_val = '\\N'

  @property
  def db_type(self):
    return self.connector.db_type

  @property
  def supports_kill_connection(self):
    return False

  def kill_connection(self):
    '''Kill the current connection and any currently running queries associated with the
       connection.
    '''
    raise Exception('Killing connection is not supported')

  @property
  def supports_index_creation(self):
    return True

  def create_cursor(self):
    return DatabaseCursor(self.connection.cursor(), self)

  @contextmanager
  def open_cursor(self):
    '''Returns a new cursor for use in a "with" statement. When the "with" statement ends,
       the cursor will be closed.

    '''
    cursor = None
    try:
      cursor = self.create_cursor()
      yield cursor
    finally:
      self.close_cursor_quietly(cursor)

  def close_cursor_quietly(self, cursor):
    if cursor:
      try:
        cursor.close()
      except Exception as e:
        LOG.debug('Error closing cursor: %s', e, exc_info=True)

  def execute(self, sql):
    with self.open_cursor() as cursor:
      cursor.execute(sql)

  def execute_and_fetchall(self, sql):
    with self.open_cursor() as cursor:
      cursor.execute(sql)
      return cursor.fetchall()

  def close(self):
    '''Close the underlying connection.'''
    self.connection.close()

  def reconnect(self):
    try:
      self.close()
    except Exception as e:
      LOG.warn('Error closing connection: %s' % e)
    other = self.connector.create_connection(db_name=self.db_name)
    self.connection = other.connection

  #########################################
  # Databases                             #
  #########################################

  def list_db_names(self):
    '''Return a list of database names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_list_db_names_sql())
    return [row[0].lower() for row in rows]

  def make_list_db_names_sql(self):
    return 'SHOW DATABASES'

  def create_database(self, db_name):
    db_name = db_name.lower()
    with self.open_cursor() as cursor:
      cursor.execute('CREATE DATABASE ' + db_name)

  def drop_db_if_exists(self, db_name):
    '''This should not be called from a connection to the database being dropped.'''
    db_name = db_name.lower()
    if db_name not in self.list_db_names():
      return
    if self.db_name and self.db_name.lower() == db_name:
      raise Exception('Cannot drop database while still connected to it')
    self.drop_database(db_name)

  def drop_database(self, db_name):
    db_name = db_name.lower()
    self.execute('DROP DATABASE ' + db_name)

  #########################################
  # Tables                                #
  #########################################

  def list_table_names(self):
    '''Return a list of table names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_list_table_names_sql())
    return [row[0].lower() for row in rows]

  def make_list_table_names_sql(self):
    return 'SHOW TABLES'

  def describe_table(self, table_name):
    '''Return a Table with table and col names always in lowercase.'''
    rows = self.execute_and_fetchall(self.make_describe_table_sql(table_name))
    table = Table(table_name.lower())
    for row in rows:
      col_name, data_type = row[:2]
      match = self.SQL_TYPE_PATTERN.match(data_type)
      if not match:
        raise Exception('Unexpected data type format: %s' % data_type)
      type_name = self.TYPE_NAME_ALIASES.get(match.group(1).upper())
      if not type_name:
        raise Exception('Unknown data type: ' + match.group(1))
      if len(match.groups()) > 1 and match.group(2) is not None:
        type_size = [int(size) for size in match.group(2)[1:-1].split(',')]
      else:
        type_size = None
      table.cols.append(
          Column(table, col_name.lower(), self.parse_data_type(type_name, type_size)))
    self.load_unique_col_metadata(table)
    return table

  def make_describe_table_sql(self, table_name):
    return 'DESCRIBE ' + table_name

  def parse_data_type(self, type_name, type_size):
    if type_name in ('DECIMAL', 'NUMERIC'):
      return get_decimal_class(*type_size)
    if type_name == 'CHAR':
      return get_char_class(*type_size)
    if type_name == 'VARCHAR':
      if type_size and type_size[0] <= VarChar.MAX:
        return get_varchar_class(*type_size)
      type_name = 'STRING'
    return self.TYPES_BY_NAME[type_name]

  def create_table(self, table):
    if not table.cols:
      raise Exception('At least one col is required')
    table_sql = self.make_create_table_sql(table)
    self.execute(table_sql)

  def make_create_table_sql(self, table):
    sql = 'CREATE TABLE %s (%s)' % (
        table.name,
        ', '.join('%s %s' %
            (col.name, self.get_sql_for_data_type(col.exact_type)) +
            ('' if (self.db_type == IMPALA or self.db_type == HIVE) else ' NULL')
            for col in table.cols))
    return sql

  def get_sql_for_data_type(self, data_type):
    if issubclass(data_type, VarChar):
      return 'VARCHAR(%s)' % data_type.MAX
    if issubclass(data_type, Char):
      return 'CHAR(%s)' % data_type.MAX
    if issubclass(data_type, Decimal):
      return 'DECIMAL(%s, %s)' % (data_type.MAX_DIGITS, data_type.MAX_FRACTIONAL_DIGITS)
    return self.EXACT_TYPES_TO_SQL[data_type]

  def drop_table(self, table_name, if_exists=True):
    self.execute('DROP TABLE %s%s'
        % ('IF EXISTS ' if if_exists else '', table_name.lower()))

  def drop_view(self, view_name, if_exists=True):
    self.execute('DROP VIEW %s%s'
        % ('IF EXISTS ' if if_exists else '', view_name.lower()))

  def index_table(self, table_name):
    table = self.describe_table(table_name)
    with self.open_cursor() as cursor:
      for col in table.cols:
        index_name = '%s_%s' % (table_name, col.name)
        if self.db_name:
          index_name = '%s_%s' % (self.db_name, index_name)
        cursor.execute('CREATE INDEX %s ON %s(%s)' % (index_name, table_name, col.name))

  #########################################
  # Data loading                          #
  #########################################

  def make_insert_sql_from_data(self, table, rows):
    if not rows:
      raise Exception('At least one row is required')
    if not table.cols:
      raise Exception('At least one col is required')

    sql = 'INSERT INTO %s VALUES ' % table.name
    for row_idx, row in enumerate(rows):
      if row_idx > 0:
        sql += ', '
      sql += '('
      for col_idx, col in enumerate(table.cols):
        if col_idx > 0:
          sql += ', '
        val = row[col_idx]
        if val is None:
          sql += 'NULL'
        elif issubclass(col.type, Timestamp):
          sql += "TIMESTAMP '%s'" % val
        elif issubclass(col.type, Char):
          sql += "'%s'" % val.replace("'", "''")
        else:
          sql += str(val)
      sql += ')'

    return sql
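
  # For example (illustrative data): a table t with cols (id INT, name CHAR(8))
  # and rows [(1, "o'hara"), (None, 'x')] yields
  #   INSERT INTO t VALUES (1, 'o''hara'), (NULL, 'x')
  # note the doubled quote inside Char values and NULL for None.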

  def begin_bulk_load_table(self, table, create_tables):
    if create_tables:
      self.create_table(table)
    self._bulk_load_table = table
    if not self._bulk_load_data_file:
      self._bulk_load_data_file = NamedTemporaryFile()

  def handle_bulk_load_table_data(self, rows):
    if not rows:
      raise Exception('At least one row is required')

    data = list()
    for row in rows:
      for col_idx, col in enumerate(self._bulk_load_table.cols):
        if col_idx > 0:
          data.append(self._bulk_load_col_delimiter)
        val = row[col_idx]
        if val is None:
          file_val = self._bulk_load_null_val
        else:
          file_val = str(val)
        data.append(file_val)
      data.append(self._bulk_load_row_delimiter)
    if data:
      self._bulk_load_data_file.writelines(data)

  def end_bulk_load_table(self, create_tables):
    self._bulk_load_data_file.flush()

  #########################################
  # Data analysis                         #
  #########################################

  def search_for_unique_cols(self, table=None, table_name=None, depth=2):
    if not table:
      table = self.describe_table(table_name)
    sql_templ = 'SELECT COUNT(*) FROM %s GROUP BY %%s HAVING COUNT(*) > 1' % table.name
    unique_cols = list()
    with self.open_cursor() as cursor:
      for current_depth in xrange(1, depth + 1):
        for cols in combinations(table.cols, current_depth):   # redundant combos excluded
          cols = set(cols)
          if any(ifilter(lambda unique_subset: unique_subset < cols, unique_cols)):
            # cols contains a combo known to be unique
            continue
          col_names = ', '.join(col.name for col in cols)
          sql = sql_templ % col_names
          LOG.debug('Checking column combo (%s) for uniqueness' % col_names)
          cursor.execute(sql)
          if not cursor.fetchone():
            LOG.debug('Found unique column combo (%s)' % col_names)
            unique_cols.append(cols)
    return unique_cols
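
  # Pruning example (illustrative): if {a} was already found unique, the combo
  # {a, b} is skipped by the proper-subset check above, since any superset of a
  # unique column set is itself unique and its GROUP BY query would be wasted work.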

  def persist_unique_col_metadata(self, table):
    if not table.unique_cols:
      return
    with closing(shelve.open('/tmp/query_generator.shelve', writeback=True)) as store:
      if self.db_type not in store:
        store[self.db_type] = dict()
      db_type_store = store[self.db_type]
      if self.db_name not in db_type_store:
        db_type_store[self.db_name] = dict()
      db_store = db_type_store[self.db_name]
      db_store[table.name] = [[col.name for col in cols] for cols in table.unique_cols]

  def load_unique_col_metadata(self, table):
    with closing(shelve.open('/tmp/query_generator.shelve')) as store:
      db_type_store = store.get(self.db_type)
      if not db_type_store:
        return
      db_store = db_type_store.get(self.db_name)
      if not db_store:
        return
      unique_col_names = db_store.get(table.name)
      if not unique_col_names:
        return
      unique_cols = list()
      for entry in unique_col_names:
        col_names = set(entry)
        cols = set((col for col in table.cols if col.name in col_names))
        if len(col_names) != len(cols):
          raise Exception("Incorrect unique column data for %s" % table.name)
        unique_cols.append(cols)
      table.unique_cols = unique_cols
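
A minimal sketch of the shelve layout shared by persist_unique_col_metadata and
load_unique_col_metadata (hypothetical db/table names; the store nests
db_type -> db_name -> table name -> list of unique column-name lists):

import shelve
from contextlib import closing

with closing(shelve.open('/tmp/query_generator_demo.shelve', writeback=True)) as store:
    store['MYSQL'] = {'functional_test_db': {
        'customers': [['id'], ['email', 'region']],
    }}

with closing(shelve.open('/tmp/query_generator_demo.shelve')) as store:
    print(store['MYSQL']['functional_test_db']['customers'])
    # -> [['id'], ['email', 'region']]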
Beispiel #50
0
def make_file(lines):
    f = NamedTemporaryFile()
    f.writelines(lines)
    f.flush()
    return f
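
# Example use (assuming Python 3, where the default binary-mode temp file
# expects bytes):
#   tmp = make_file([b'alpha\n', b'beta\n'])
#   open(tmp.name, 'rb').read()   # -> b'alpha\nbeta\n'
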
class _FopenClass(object):

    def __init__(self, path, mode):
        '''
        @param path: file path
        @param mode: file open mode
        '''
        self.ftemp = None
        self.master = {
            'path' : abspath(path),
            'mode' : mode
        }

    def __del__(self):
        '''
        @summary: destructor; close the temp file if a transaction was started
        '''
        if self.ftemp is not None:
            self.ftemp.close()

    def __enter__(self):
        '''
        @summary: enter with-block
        '''
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        '''
        @summary: exit with-block; commit on success, rollback on error
        '''
        if exc_type:
            self.rollback()
        else:
            self.commit()

    def start(self):
        '''
        @summary: start transaction; all writes go to a temp file first
        '''
        self.ftemp = NamedTemporaryFile()

    def write(self, value):
        '''
        @summary: write to the temp file
        '''
        self.ftemp.write(value)

    def writelines(self, seq):
        '''
        @summary: write lines to the temp file
        '''
        self.ftemp.writelines(seq)

    def rollback(self):
        '''
        @summary: rollback; discard the temp file, leaving the master untouched
        '''
        self.ftemp.close()

    def commit(self):
        '''
        @summary: commit transaction; copy the temp file over the master path
        '''
        self.ftemp.flush()
        self.ftemp.seek(0)
        copy2(self.ftemp.name, self.master['path'])
Beispiel #52
0
def main(apiurl, opts, argv):

    repo = argv[0]
    arch = argv[1]
    build_descr = argv[2]
    xp = []
    build_root = None
    cache_dir  = None
    build_uid = ''
    vm_type = config['build-type']
    vm_telnet = None

    build_descr = os.path.abspath(build_descr)
    build_type = os.path.splitext(build_descr)[1][1:]
    if os.path.basename(build_descr) == 'PKGBUILD':
        build_type = 'arch'
    if os.path.basename(build_descr) == 'build.collax':
        build_type = 'collax'
    if build_type not in ['spec', 'dsc', 'kiwi', 'arch', 'collax', 'livebuild']:
        raise oscerr.WrongArgs(
                'Unknown build type: \'%s\'. Build description should end in '
                '.spec, .dsc, .kiwi or .livebuild, or be named PKGBUILD or build.collax.'
                % build_type)
    if not os.path.isfile(build_descr):
        raise oscerr.WrongArgs('Error: build description file named \'%s\' does not exist.' % build_descr)

    buildargs = []
    if not opts.userootforbuild:
        buildargs.append('--norootforbuild')
    if opts.clean:
        buildargs.append('--clean')
    if opts.noinit:
        buildargs.append('--noinit')
    if opts.nochecks:
        buildargs.append('--no-checks')
    if not opts.no_changelog:
        buildargs.append('--changelog')
    if opts.root:
        build_root = opts.root
    if opts.target:
        buildargs.append('--target=%s' % opts.target)
    if opts.threads:
        buildargs.append('--threads=%s' % opts.threads)
    if opts.jobs:
        buildargs.append('--jobs=%s' % opts.jobs)
    elif config['build-jobs'] > 1:
        buildargs.append('--jobs=%s' % config['build-jobs'])
    if opts.icecream or config['icecream'] != '0':
        if opts.icecream:
            num = opts.icecream
        else:
            num = config['icecream']

        if int(num) > 0:
            buildargs.append('--icecream=%s' % num)
            xp.append('icecream')
            xp.append('gcc-c++')
    if opts.ccache:
        buildargs.append('--ccache')
        xp.append('ccache')
    if opts.linksources:
        buildargs.append('--linksources')
    if opts.baselibs:
        buildargs.append('--baselibs')
    if opts.debuginfo:
        buildargs.append('--debug')
    if opts._with:
        for o in opts._with:
            buildargs.append('--with=%s' % o)
    if opts.without:
        for o in opts.without:
            buildargs.append('--without=%s' % o)
    if opts.define:
        for o in opts.define:
            buildargs.append('--define=%s' % o)
    if config['build-uid']:
        build_uid = config['build-uid']
    if opts.build_uid:
        build_uid = opts.build_uid
    if build_uid:
        buildidre = re.compile('^[0-9]{1,5}:[0-9]{1,5}$')
        if build_uid == 'caller':
            buildargs.append('--uid=%s:%s' % (os.getuid(), os.getgid()))
        elif buildidre.match(build_uid):
            buildargs.append('--uid=%s' % build_uid)
        else:
            print('Error: build-uid arg must be 2 colon separated numerics: "uid:gid" or "caller"', file=sys.stderr)
            return 1
    if opts.vm_type:
        vm_type = opts.vm_type
    if opts.vm_telnet:
        vm_telnet = opts.vm_telnet
    if opts.alternative_project:
        prj = opts.alternative_project
        pac = '_repository'
    else:
        prj = store_read_project(os.curdir)
        if opts.local_package:
            pac = '_repository'
        else:
            pac = store_read_package(os.curdir)
    if opts.shell:
        buildargs.append("--shell")

    orig_build_root = config['build-root']
    # make it possible to override configuration of the rc file
    for var in ['OSC_PACKAGECACHEDIR', 'OSC_SU_WRAPPER', 'OSC_BUILD_ROOT']:
        val = os.getenv(var)
        if val:
            if var.startswith('OSC_'): var = var[4:]
            var = var.lower().replace('_', '-')
            if var in config:
                print('Overriding config value for %s=\'%s\' with \'%s\'' % (var, config[var], val))
            config[var] = val

    pacname = pac
    if pacname == '_repository':
        if not opts.local_package:
            try:
                pacname = store_read_package(os.curdir)
            except oscerr.NoWorkingCopy:
                opts.local_package = True
        if opts.local_package:
            pacname = os.path.splitext(build_descr)[0]
    apihost = urlsplit(apiurl)[1]
    if not build_root:
        build_root = config['build-root']
        if build_root == orig_build_root:
            # ENV var was not set
            build_root = config['api_host_options'][apiurl].get('build-root', build_root)
        try:
            build_root = build_root % {'repo': repo, 'arch': arch,
                         'project': prj, 'package': pacname, 'apihost': apihost}
        except:
            pass

    cache_dir = config['packagecachedir'] % {'apihost': apihost}

    extra_pkgs = []
    if not opts.extra_pkgs:
        extra_pkgs = config['extra-pkgs']
    elif opts.extra_pkgs != ['']:
        extra_pkgs = opts.extra_pkgs

    if xp:
        extra_pkgs += xp

    prefer_pkgs = {}
    build_descr_data = open(build_descr).read()

    # XXX: dirty hack but there's no api to provide custom defines
    if opts.without:
        s = ''
        for i in opts.without:
            s += "%%define _without_%s 1\n" % i
        build_descr_data = s + build_descr_data
    if opts._with:
        s = ''
        for i in opts._with:
            s += "%%define _with_%s 1\n" % i
        build_descr_data = s + build_descr_data
    if opts.define:
        s = ''
        for i in opts.define:
            s += "%%define %s\n" % i
        build_descr_data = s + build_descr_data

    cpiodata = None
    servicefile = os.path.join(os.path.dirname(build_descr), "_service")
    if not os.path.isfile(servicefile):
        servicefile = None
    else:
        print('Using local _service file')
    buildenvfile = os.path.join(os.path.dirname(build_descr), "_buildenv." + repo + "." + arch)
    if not os.path.isfile(buildenvfile):
        buildenvfile = os.path.join(os.path.dirname(build_descr), "_buildenv")
        if not os.path.isfile(buildenvfile):
            buildenvfile = None
        else:
            print('Using local buildenv file: %s' % os.path.basename(buildenvfile))
    if buildenvfile or servicefile:
        from .util import cpio
        if not cpiodata:
            cpiodata = cpio.CpioWrite()

    if opts.prefer_pkgs:
        print('Scanning the following dirs for local packages: %s' % ', '.join(opts.prefer_pkgs))
        from .util import cpio
        if not cpiodata:
            cpiodata = cpio.CpioWrite()
        prefer_pkgs = get_prefer_pkgs(opts.prefer_pkgs, arch, build_type, cpiodata)

    if cpiodata:
        cpiodata.add(os.path.basename(build_descr), build_descr_data)
        # buildenv must come last for compatibility reasons...
        if buildenvfile:
            cpiodata.add("buildenv", open(buildenvfile).read())
        if servicefile:
            cpiodata.add("_service", open(servicefile).read())
        build_descr_data = cpiodata.get()

    # special handling for overlay and rsync-src/dest
    specialcmdopts = []
    if opts.rsyncsrc or opts.rsyncdest :
        if not opts.rsyncsrc or not opts.rsyncdest:
            raise oscerr.WrongOptions('When using --rsync-{src,dest} both parameters have to be specified.')
        myrsyncsrc = os.path.abspath(os.path.expanduser(os.path.expandvars(opts.rsyncsrc)))
        if not os.path.isdir(myrsyncsrc):
            raise oscerr.WrongOptions('--rsync-src %s is not a valid directory!' % opts.rsyncsrc)
        # can't check the destination - it's in the target chroot ;) - but we can check for sanity
        myrsyncdest = os.path.expandvars(opts.rsyncdest)
        if not os.path.isabs(myrsyncdest):
            raise oscerr.WrongOptions('--rsync-dest %s is not an absolute path (starting with \'/\')!' % opts.rsyncdest)
        specialcmdopts = ['--rsync-src='+myrsyncsrc, '--rsync-dest='+myrsyncdest]
    if opts.overlay:
        myoverlay = os.path.abspath(os.path.expanduser(os.path.expandvars(opts.overlay)))
        if not os.path.isdir(myoverlay):
            raise oscerr.WrongOptions('--overlay %s is not a valid directory!' % opts.overlay)
        specialcmdopts += ['--overlay='+myoverlay]

    bi_file = None
    bc_file = None
    bi_filename = '_buildinfo-%s-%s.xml' % (repo, arch)
    bc_filename = '_buildconfig-%s-%s' % (repo, arch)
    if is_package_dir('.') and os.access(osc.core.store, os.W_OK):
        bi_filename = os.path.join(os.getcwd(), osc.core.store, bi_filename)
        bc_filename = os.path.join(os.getcwd(), osc.core.store, bc_filename)
    elif not os.access('.', os.W_OK):
        bi_file = NamedTemporaryFile(prefix=bi_filename)
        bi_filename = bi_file.name
        bc_file = NamedTemporaryFile(prefix=bc_filename)
        bc_filename = bc_file.name
    else:
        bi_filename = os.path.abspath(bi_filename)
        bc_filename = os.path.abspath(bc_filename)

    try:
        if opts.noinit:
            if not os.path.isfile(bi_filename):
                raise oscerr.WrongOptions('--noinit is not possible, no local buildinfo file')
            print('Use local \'%s\' file as buildinfo' % bi_filename)
            if not os.path.isfile(bc_filename):
                raise oscerr.WrongOptions('--noinit is not possible, no local buildconfig file')
            print('Use local \'%s\' file as buildconfig' % bc_filename)
        elif opts.offline:
            if not os.path.isfile(bi_filename):
                raise oscerr.WrongOptions('--offline is not possible, no local buildinfo file')
            print('Use local \'%s\' file as buildinfo' % bi_filename)
            if not os.path.isfile(bc_filename):
                raise oscerr.WrongOptions('--offline is not possible, no local buildconfig file')
        else:
            print('Getting buildinfo from server and store to %s' % bi_filename)
            bi_text = ''.join(get_buildinfo(apiurl,
                                            prj,
                                            pac,
                                            repo,
                                            arch,
                                            specfile=build_descr_data,
                                            addlist=extra_pkgs))
            if not bi_file:
                bi_file = open(bi_filename, 'w')
            # maybe we should check for errors before saving the file
            bi_file.write(bi_text)
            bi_file.flush()
            print('Getting buildconfig from server and store to %s' % bc_filename)
            bc = get_buildconfig(apiurl, prj, repo)
            if not bc_file:
                bc_file = open(bc_filename, 'w')
            bc_file.write(bc)
            bc_file.flush()
    except HTTPError as e:
        if e.code == 404:
            # check what caused the 404
            if meta_exists(metatype='prj', path_args=(quote_plus(prj), ),
                           template_args=None, create_new=False, apiurl=apiurl):
                pkg_meta_e = None
                try:
                    # take care, not to run into double trouble.
                    pkg_meta_e = meta_exists(metatype='pkg', path_args=(quote_plus(prj),
                                        quote_plus(pac)), template_args=None, create_new=False,
                                        apiurl=apiurl)
                except:
                    pass

                if pkg_meta_e:
                    print('ERROR: Either wrong repo/arch as parameter or a parse error of .spec/.dsc/.kiwi file due to syntax error', file=sys.stderr)
                else:
                    print('The package \'%s\' does not exist - please ' \
                                        'rerun with \'--local-package\'' % pac, file=sys.stderr)
            else:
                print('The project \'%s\' does not exist - please ' \
                                    'rerun with \'--alternative-project <alternative_project>\'' % prj, file=sys.stderr)
            sys.exit(1)
        else:
            raise

    bi = Buildinfo(bi_filename, apiurl, build_type, list(prefer_pkgs.keys()))

    if bi.debuginfo and not (opts.disable_debuginfo or '--debug' in buildargs):
        buildargs.append('--debug')

    if opts.release:
        bi.release = opts.release

    if bi.release:
        buildargs.append('--release=%s' % bi.release)

    # real arch of this machine
    # vs.
    # arch we are supposed to build for
    if bi.hostarch is not None:
        if hostarch != bi.hostarch and bi.hostarch not in can_also_build.get(hostarch, []):
            print('Error: hostarch \'%s\' is required.' % (bi.hostarch), file=sys.stderr)
            return 1
    elif hostarch != bi.buildarch:
        if bi.buildarch not in can_also_build.get(hostarch, []):
            # OBSOLETE: qemu_can_build should not be needed anymore since OBS 2.3
            if vm_type != "emulator" and bi.buildarch not in qemu_can_build:
                print('Error: hostarch \'%s\' cannot build \'%s\'.' % (hostarch, bi.buildarch), file=sys.stderr)
                return 1
            print('WARNING: assuming a build on hostarch \'%s\' for \'%s\' via QEMU.' % (hostarch, bi.buildarch), file=sys.stderr)

    rpmlist_prefers = []
    if prefer_pkgs:
        print('Evaluating preferred packages')
        for name, path in prefer_pkgs.items():
            if bi.has_dep(name):
                # Remove a preferred package from the buildinfo so the fetcher
                # doesn't try to download it; instead it goes into a list that
                # is appended to the rpmlist later. This also ensures these
                # packages are not verified.
                bi.remove_dep(name)
                rpmlist_prefers.append((name, path))
                print(' - %s (%s)' % (name, path))

    print('Updating cache of required packages')

    urllist = []
    if not opts.download_api_only:
        # transform 'url1, url2, url3' form into a list
        if 'urllist' in config:
            if isinstance(config['urllist'], str):
                re_clist = re.compile('[, ]+')
                urllist = [ i.strip() for i in re_clist.split(config['urllist'].strip()) ]
            else:
                urllist = config['urllist']

        # OBS 1.5 and before has no downloadurl defined in buildinfo
        if bi.downloadurl:
            urllist.append(bi.downloadurl + '/%(extproject)s/%(extrepository)s/%(arch)s/%(filename)s')
    if opts.disable_cpio_bulk_download:
        urllist.append( '%(apiurl)s/build/%(project)s/%(repository)s/%(repoarch)s/%(repopackage)s/%(repofilename)s' )

    fetcher = Fetcher(cache_dir,
                      urllist = urllist,
                      api_host_options = config['api_host_options'],
                      offline = opts.noinit or opts.offline,
                      http_debug = config['http_debug'],
                      enable_cpio = not opts.disable_cpio_bulk_download,
                      cookiejar=cookiejar)

    if not opts.trust_all_projects:
        # implicitly trust the project we are building for
        check_trusted_projects(apiurl, [ i for i in bi.projects.keys() if not i == prj ])

    imagefile = ''
    imagesource = ''
    imagebins = []
    if (not config['no_preinstallimage'] and not opts.nopreinstallimage and
        bi.preinstallimage and
        not opts.noinit and not opts.offline and
        (opts.clean or (not os.path.exists(build_root + "/installed-pkg") and
                        not os.path.exists(build_root + "/.build/init_buildsystem.data")))):
        (imagefile, imagesource, imagebins) = get_preinstall_image(apiurl, arch, cache_dir, bi.preinstallimage)
        if imagefile:
            # remove binaries from build deps which are included in preinstall image
            for i in bi.deps:
                if i.name in imagebins:
                    bi.remove_dep(i.name)

    # now update the package cache
    fetcher.run(bi)

    old_pkg_dir = None
    if opts.oldpackages:
        old_pkg_dir = opts.oldpackages
        if not old_pkg_dir.startswith('/') and not opts.offline:
            data = [ prj, pacname, repo, arch]
            if old_pkg_dir == '_link':
                p = osc.core.findpacs(os.curdir)[0]
                if not p.islink():
                    raise oscerr.WrongOptions('package is not a link')
                data[0] = p.linkinfo.project
                data[1] = p.linkinfo.package
                repos = osc.core.get_repositories_of_project(apiurl, data[0])
                # hack for links to e.g. Factory
                if not data[2] in repos and 'standard' in repos:
                    data[2] = 'standard'
            elif old_pkg_dir != '' and old_pkg_dir != '_self':
                a = old_pkg_dir.split('/')
                for i in range(0, len(a)):
                    data[i] = a[i]

            destdir = os.path.join(cache_dir, data[0], data[2], data[3])
            old_pkg_dir = None
            try:
                print("Downloading previous build from %s ..." % '/'.join(data))
                binaries = get_binarylist(apiurl, data[0], data[2], data[3], package=data[1], verbose=True)
            except Exception as e:
                print("Error: failed to get binaries: %s" % str(e))
                binaries = []

            if binaries:
                class mytmpdir:
                    """ temporary directory that removes itself"""
                    def __init__(self, *args, **kwargs):
                        self.name = mkdtemp(*args, **kwargs)
                    _rmtree = staticmethod(shutil.rmtree)
                    def cleanup(self):
                        self._rmtree(self.name)
                    def __del__(self):
                        self.cleanup()
                    def __exit__(self, exc_type, exc_value, traceback):
                        self.cleanup()
                    def __str__(self):
                        return self.name

                old_pkg_dir = mytmpdir(prefix='.build.oldpackages', dir=os.path.abspath(os.curdir))
                if not os.path.exists(destdir):
                    os.makedirs(destdir)
            for i in binaries:
                fname = os.path.join(destdir, i.name)
                os.symlink(fname, os.path.join(str(old_pkg_dir), i.name))
                if os.path.exists(fname):
                    st = os.stat(fname)
                    if st.st_mtime == i.mtime and st.st_size == i.size:
                        continue
                get_binary_file(apiurl,
                                data[0],
                                data[2], data[3],
                                i.name,
                                package = data[1],
                                target_filename = fname,
                                target_mtime = i.mtime,
                                progress_meter = True)

        if old_pkg_dir is not None:
            buildargs.append('--oldpackages=%s' % old_pkg_dir)

    # Make packages from buildinfo available as repos for kiwi
    if build_type == 'kiwi':
        if os.path.exists('repos'):
            shutil.rmtree('repos')
        os.mkdir('repos')
        # iterate over a copy: removing from bi.deps while iterating would skip entries
        for i in bi.deps[:]:
            if not i.extproject:
                bi.deps.remove(i)
                continue
            # project
            pdir = str(i.extproject).replace(':/', ':')
            # repo
            rdir = str(i.extrepository).replace(':/', ':')
            # arch
            adir = i.repoarch
            # project/repo
            prdir = "repos/"+pdir+"/"+rdir
            # project/repo/arch
            pradir = prdir+"/"+adir
            # source fullfilename
            sffn = i.fullfilename
            filename = sffn.split("/")[-1]
            # target fullfilename
            tffn = pradir+"/"+filename
            if not os.path.exists(os.path.join(pradir)):
                os.makedirs(os.path.join(pradir))
            if not os.path.exists(tffn):
                print("Using package: "+sffn)
                if opts.linksources:
                    os.link(sffn, tffn)
                else:
                    os.symlink(sffn, tffn)
            if prefer_pkgs:
                for name, path in prefer_pkgs.items():
                    if name == filename:
                        print("Using prefered package: " + path + "/" + filename)
                        os.unlink(tffn)
                        if opts.linksources:
                            os.link(path + "/" + filename, tffn)
                        else:
                            os.symlink(path + "/" + filename, tffn)
        # Is a obsrepositories tag used?
        try:
            tree = ET.parse(build_descr)
        except:
            print('could not parse the kiwi file:', file=sys.stderr)
            print(open(build_descr).read(), file=sys.stderr)
            sys.exit(1)
        root = tree.getroot()
        # product
        for xml in root.findall('instsource'):
            if xml.find('instrepo').find('source').get('path') == 'obsrepositories:/':
                print("obsrepositories:/ for product builds is not yet supported in osc!")
                sys.exit(1)
        # appliance
        expand_obsrepos = None
        for xml in root.findall('repository'):
            if xml.find('source').get('path') == 'obsrepositories:/':
                expand_obsrepos = True
        if expand_obsrepos:
            buildargs.append('--kiwi-parameter')
            buildargs.append('--ignore-repos')
            for xml in root.findall('repository'):
                if xml.find('source').get('path') == 'obsrepositories:/':
                    for path in bi.pathes:
                        if not os.path.isdir("repos/"+path):
                            continue
                        buildargs.append('--kiwi-parameter')
                        buildargs.append('--add-repo')
                        buildargs.append('--kiwi-parameter')
                        buildargs.append("repos/"+path)
                        buildargs.append('--kiwi-parameter')
                        buildargs.append('--add-repotype')
                        buildargs.append('--kiwi-parameter')
                        buildargs.append('rpm-md')
                        if xml.get('priority'):
                            buildargs.append('--kiwi-parameter')
                            buildargs.append('--add-repoprio='+xml.get('priority'))
                else:
                    m = re.match(r"obs://[^/]+/([^/]+)/(\S+)", xml.find('source').get('path'))
                    if not m:
                        # short path without obs instance name
                        m = re.match(r"obs://([^/]+)/(.+)", xml.find('source').get('path'))
                    project = m.group(1).replace(":", ":/")
                    repo = m.group(2)
                    buildargs.append('--kiwi-parameter')
                    buildargs.append('--add-repo')
                    buildargs.append('--kiwi-parameter')
                    buildargs.append("repos/"+project+"/"+repo)
                    buildargs.append('--kiwi-parameter')
                    buildargs.append('--add-repotype')
                    buildargs.append('--kiwi-parameter')
                    buildargs.append('rpm-md')
                    if xml.get('priority'):
                        buildargs.append('--kiwi-parameter')
                        buildargs.append('--add-repopriority='+xml.get('priority'))

    if vm_type == "xen" or vm_type == "kvm" or vm_type == "lxc":
        print('Skipping verification of package signatures due to secure VM build')
    elif bi.pacsuffix == 'rpm':
        if opts.no_verify:
            print('Skipping verification of package signatures')
        else:
            print('Verifying integrity of cached packages')
            verify_pacs(bi)
    elif bi.pacsuffix == 'deb':
        if opts.no_verify or opts.noinit:
            print('Skipping verification of package signatures')
        else:
            print('WARNING: deb packages get not verified, they can compromise your system !')
    else:
        print('WARNING: unknown packages get not verified, they can compromise your system !')

    for i in bi.deps:
        if i.hdrmd5:
            from .util import packagequery
            hdrmd5 = packagequery.PackageQuery.queryhdrmd5(i.fullfilename)
            if not hdrmd5:
                print("Error: cannot get hdrmd5 for %s" % i.fullfilename)
                sys.exit(1)
            if hdrmd5 != i.hdrmd5:
                print("Error: hdrmd5 mismatch for %s: %s != %s" % (i.fullfilename, hdrmd5, i.hdrmd5))
                sys.exit(1)

    print('Writing build configuration')

    if build_type == 'kiwi':
        rpmlist = [ '%s %s\n' % (i.name, i.fullfilename) for i in bi.deps if not i.noinstall ]
    else:
        rpmlist = [ '%s %s\n' % (i.name, i.fullfilename) for i in bi.deps ]
    for i in imagebins:
        rpmlist.append('%s preinstallimage\n' % i)
    rpmlist += [ '%s %s\n' % (i[0], i[1]) for i in rpmlist_prefers ]

    if imagefile:
        rpmlist.append('preinstallimage: %s\n' % imagefile)
    if imagesource:
        rpmlist.append('preinstallimagesource: %s\n' % imagesource)

    rpmlist.append('preinstall: ' + ' '.join(bi.preinstall_list) + '\n')
    rpmlist.append('vminstall: ' + ' '.join(bi.vminstall_list) + '\n')
    rpmlist.append('runscripts: ' + ' '.join(bi.runscripts_list) + '\n')
    if build_type != 'kiwi' and bi.noinstall_list:
        rpmlist.append('noinstall: ' + ' '.join(bi.noinstall_list) + '\n')
    if build_type != 'kiwi' and bi.installonly_list:
        rpmlist.append('installonly: ' + ' '.join(bi.installonly_list) + '\n')

    rpmlist_file = NamedTemporaryFile(prefix='rpmlist.')
    rpmlist_filename = rpmlist_file.name
    rpmlist_file.writelines(rpmlist)
    rpmlist_file.flush()

    subst = { 'repo': repo, 'arch': arch, 'project' : prj, 'package' : pacname }
    vm_options = []
    # XXX check if build-device present
    my_build_device = ''
    if config['build-device']:
        my_build_device = config['build-device'] % subst
    else:
        # obs worker uses /root here but that collides with the
        # /root directory if the build root was used without vm
        # before
        my_build_device = build_root + '/img'

    need_root = True
    if vm_type:
        if config['build-swap']:
            my_build_swap = config['build-swap'] % subst
        else:
            my_build_swap = build_root + '/swap'

        vm_options = [ '--vm-type=%s' % vm_type ]
        if vm_telnet:
            vm_options += [ '--vm-telnet=' + vm_telnet ]
        if config['build-memory']:
            vm_options += [ '--memory=' + config['build-memory'] ]
        if vm_type != 'lxc':
            vm_options += [ '--vm-disk=' + my_build_device ]
            vm_options += [ '--vm-swap=' + my_build_swap ]
            vm_options += [ '--logfile=%s/.build.log' % build_root ]
            if vm_type == 'kvm':
                if os.access(build_root, os.W_OK) and os.access('/dev/kvm', os.W_OK):
                    # so let's hope there's also an fstab entry
                    need_root = False
                if config['build-kernel']:
                    vm_options += [ '--vm-kernel=' + config['build-kernel'] ]
                if config['build-initrd']:
                    vm_options += [ '--vm-initrd=' + config['build-initrd'] ]

            build_root += '/.mount'

        if config['build-vmdisk-rootsize']:
            vm_options += [ '--vmdisk-rootsize=' + config['build-vmdisk-rootsize'] ]
        if config['build-vmdisk-swapsize']:
            vm_options += [ '--vmdisk-swapsize=' + config['build-vmdisk-swapsize'] ]
        if config['build-vmdisk-filesystem']:
            vm_options += [ '--vmdisk-filesystem=' + config['build-vmdisk-filesystem'] ]
        if config['build-vm-user']:
            vm_options += [ '--vm-user=' + config['build-vm-user'] ]


    if opts.preload:
        print("Preload done for selected repo/arch.")
        sys.exit(0)

    print('Running build')
    cmd = [ config['build-cmd'], '--root='+build_root,
                    '--rpmlist='+rpmlist_filename,
                    '--dist='+bc_filename,
                    '--arch='+bi.buildarch ]
    cmd += specialcmdopts + vm_options + buildargs
    cmd += [ build_descr ]

    if need_root:
        sucmd = config['su-wrapper'].split()
        if sucmd[0] == 'su':
            if sucmd[-1] == '-c':
                sucmd.pop()
            cmd = sucmd + ['-s', cmd[0], 'root', '--' ] + cmd[1:]
        else:
            cmd = sucmd + cmd

    # change personality, if needed
    if hostarch != bi.buildarch and bi.buildarch in change_personality:
        cmd = [ change_personality[bi.buildarch] ] + cmd

    try:
        rc = run_external(cmd[0], *cmd[1:])
        if rc:
            print()
            print('The buildroot was:', build_root)
            sys.exit(rc)
    except KeyboardInterrupt as i:
        print("keyboard interrupt, killing build ...")
        cmd.append('--kill')
        run_external(cmd[0], *cmd[1:])
        raise i

    pacdir = os.path.join(build_root, '.build.packages')
    if os.path.islink(pacdir):
        pacdir = os.readlink(pacdir)
        pacdir = os.path.join(build_root, pacdir)

    if os.path.exists(pacdir):
        (s_built, b_built) = get_built_files(pacdir, bi.buildtype)

        print()
        if s_built: print(s_built)
        print()
        print(b_built)

        if opts.keep_pkgs:
            for i in b_built.splitlines() + s_built.splitlines():
                shutil.copy2(i, os.path.join(opts.keep_pkgs, os.path.basename(i)))

    if bi_file:
        bi_file.close()
    if bc_file:
        bc_file.close()
    rpmlist_file.close()
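
For orientation, a sketch of what the rpmlist file handed to the build command
via --rpmlist contains (hypothetical package names, cache paths elided): one
"name path" line per dependency from the buildinfo, followed by the directive
lines assembled above.

# gcc /var/tmp/osbuild-packagecache/.../gcc.rpm
# bash /var/tmp/osbuild-packagecache/.../bash.rpm
# preinstall: aaa_base bash coreutils
# vminstall: util-linux
# runscripts: aaa_base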
Beispiel #53
0
        print 'WARNING: unknown packages are not verified, they can compromise your system!'

    print 'Writing build configuration'

    rpmlist = [ '%s %s\n' % (i.name, i.fullfilename) for i in bi.deps if not i.noinstall ]
    rpmlist += [ '%s %s\n' % (i[0], i[1]) for i in rpmlist_prefers ]

    rpmlist.append('preinstall: ' + ' '.join(bi.preinstall_list) + '\n')
    rpmlist.append('vminstall: ' + ' '.join(bi.vminstall_list) + '\n')
    rpmlist.append('cbinstall: ' + ' '.join(bi.cbinstall_list) + '\n')
    rpmlist.append('cbpreinstall: ' + ' '.join(bi.cbpreinstall_list) + '\n')
    rpmlist.append('runscripts: ' + ' '.join(bi.runscripts_list) + '\n')

    rpmlist_file = NamedTemporaryFile(prefix='rpmlist.')
    rpmlist_filename = rpmlist_file.name
    rpmlist_file.writelines(rpmlist)
    rpmlist_file.flush()

    subst = { 'repo': repo, 'arch': arch, 'project' : prj, 'package' : pacname }
    vm_options = []
    # XXX check if build-device present
    my_build_device = ''
    if config['build-device']:
        my_build_device = config['build-device'] % subst
    else:
        # obs worker uses /root here but that collides with the
        # /root directory if the build root was used without vm
        # before
        my_build_device = build_root + '/img'

    need_root = True
class Analyzer:
    def __init__(self, sam_path, use_temp=True, sortn=False):
        print "Name: %s"%sam_path
        self.samdirname = os.path.splitext(sam_path)[0]
        self.patientid = sam_path.split(os.sep)[-1][:-4]
        print 'dirname: %s'%self.samdirname
        self.pro_reads = self.samdirname + '_pro_reads'
        self.pro_counts = self.samdirname + '_pro_counts'
        self.gag_counts = self.samdirname + '_gag_counts'
        self.pro_pair_counts = self.samdirname + '_pro_pair_counts'
        self.pro_plts = self.samdirname + '_pro_plts'
        self.gag_plts = self.samdirname + '_gag_plts'
        self.sig_pro_plts = self.samdirname + '_sig_pro_plts'
        self.sig_gag_plts = self.samdirname + '_sig_gag_plts'
        self.pileup_name = self.samdirname + '_pileup'

        bam_extension = '_sorted'
        self.use_temp = use_temp

        if use_temp:
            if sam_path.endswith('.sam'):
                print "make bam...",
                bam_content = pysam.view('-bS', sam_path)

                # write BAM to a temp file
                self.temp_file = NamedTemporaryFile(delete=False)
                self.temp_filename = self.temp_file.name
                self.temp_file.writelines(bam_content)
                self.temp_file.close()

                # sort BAM file
                print "sort...",
                pysam.sort(self.temp_file.name, 
                           self.temp_file.name+bam_extension)
                print "index...",
                pysam.index('%s%s.bam' % (self.temp_file.name, bam_extension))
                print "make sam!"

                self.samfile = pysam.Samfile(self.temp_file.name
                                             + bam_extension +'.bam', 'rb')
            else:
                self.use_temp = False
                if sortn:
                    sorted_path = sam_path + '_nsorted'
                    if not os.path.exists(sorted_path):
                        print "sorting by query name"
                        pysam.sort('-n', sam_path, sorted_path)
                    self.samfile = pysam.Samfile(sorted_path, 'rb')
                else:
                    self.samfile = pysam.Samfile(sam_path, 'rb')
        else:
            print 'storing bam files'
            if sam_path.endswith('.sam'):
                print "make bam...",
                bam_content = pysam.view('-bS', sam_path)

                # write BAM to a temp file
                self.bam_file_name = self.samdirname+'.bam'
                self.bam_file = open(self.bam_file_name, 'w+')
                self.bam_file.writelines(bam_content)
                self.bam_file.close()

                # sort BAM file
                print "sort...",
                pysam.sort(self.bam_file_name, self.bam_file_name+bam_extension)
                print "index...",
                pysam.index('%s%s.bam' % (self.bam_file_name, bam_extension))
                print "make sam!"

                self.samfile = pysam.Samfile(self.bam_file_name
                                             + bam_extension +'.bam', 'rb')
            else:
                if sortn:
                    sorted_path = sam_path + '_nsorted'
                    if not os.path.exists('%s.bam'%sorted_path):
                        print "sorting by query name..."
                        pysam.sort('-n', sam_path, sorted_path)
                    self.samfilen = pysam.Samfile('%s.bam'%sorted_path, 'rb')
                self.samfile = pysam.Samfile(sam_path, 'rb')

    def __del__(self):
        if self.use_temp:
            os.unlink(self.temp_filename)
            os.unlink(self.temp_filename + '_sorted.bam')
            os.unlink(self.temp_filename + '_sorted.bam.bai')

    
    def sam_stats(self):
        mapped = self.samfile.mapped
        unmapped = self.samfile.unmapped
        total = float(mapped + unmapped)

        print 'filename, mapped, unmapped, percent mapped'
        print '%s, %d, %d, %.2f%% map'%(self.samfile.filename, mapped, unmapped,
                                        100 * mapped/total)


    def sam_coverage(self):
        # This process doesn't work properly because samtools limits the max 
        # read depth to 8000 (or so) reads. The pysam committers said it's 
        # samtools, not pysam, that's the problem.
        pileup_iter = self.samfile.pileup('CONSENSUS_B_GAG_POL', GAG_START, PRO_END, maxdepth=1e6)
        return [p.n for p in pileup_iter]

    def trim_read(self, read, start, end, codon=True):
        """
        M    BAM_CMATCH         0
        I    BAM_CINS           1
        D    BAM_CDEL           2
        N    BAM_CREF_SKIP      3
        S    BAM_CSOFT_CLIP     4
        H    BAM_CHARD_CLIP     5
        P    BAM_CPAD           6
        =    BAM_CEQUAL         7
        X    BAM_CDIFF          8
        """
    
        if read.pos > end or read.aend < start:
            if codon: 
                return '', 0
            else:
                return ''
    
        aligned_seq = ''
        read_pos = 0
        for code, n in read.cigar:
            if code == 7:
                raise Exception(KeyError, "Exact match?")
            if code == 0:
                aligned_seq += read.seq[read_pos:read_pos + n]
            if code == 1:
                pass
            if code == 2:
                aligned_seq += 'N' * n
                read_pos -= n
            if code == 3:
                raise Exception(KeyError, "This shouldn't happen...")
            if code == 4:
                pass
            if code == 5:
                pass
            read_pos += n
    
        trimmed_seq = aligned_seq
        l_offset = start - read.pos
        r_offset = read.pos + len(aligned_seq) - end
        frame_offset = 0
        if l_offset > 0:
            trimmed_seq = trimmed_seq[l_offset:]
        if r_offset > 0:
            trimmed_seq = trimmed_seq[0:len(trimmed_seq) - r_offset]
        if not codon:
            return trimmed_seq

        if l_offset < 0:
            frame_offset = (start - read.pos) % 3
        return trimmed_seq, frame_offset
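
    # Worked example (illustrative): a read with cigar [(0, 5), (2, 2), (0, 3)]
    # yields seq[0:5] + 'NN' + seq[5:8] -- matches are copied through, deletions
    # are padded with 'N' so reference coordinates stay aligned, and
    # insertions/soft clips advance the read pointer only.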


    def translate_read(self, read, start, end):
        trimmed_read, offset = self.trim_read(read, start, end, codon=True)
        prot_seq = translate(trimmed_read[offset:])
    
        prot_start = (read.pos + offset - start) / 3
        if prot_start < 0:
            prot_start = 0
    
        return prot_seq, prot_start

#------------------------------------------------------------------------------

    def _dNdS_sites(self, codon):
        syn = 0
        non = 0
        alphabet = 'ACGT'
        aa = translate(codon)
        if len(codon) < 3: return syn, non

        for i in range(3):
            for mut in alphabet:
                if mut == codon[i]: continue
                mut_codon = codon[:i] + mut + codon[i+1:]

                syn_flag = (aa == translate(mut_codon))
                syn += syn_flag
                non += (not syn_flag)
        
        syn /= 3.
        non /= 3.
        assert syn + non == 3
        return syn, non
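
    # Worked example: for the codon 'TTT' (Phe) only the third-position change
    # T->C is synonymous (TTC is still Phe), so of the nine possible point
    # mutations one is synonymous: syn = 1/3 and non = 8/3, summing to 3.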

    def dNdS(self, reference, start, end):
        ps, pn = 0, 0
        n_codon = (end - start) / 3
        for i in range(start, end, 3):
            ref_codon = reference[i-start:i-start+3]
            ref_aa = translate(ref_codon)
            s_i, n_i = self._dNdS_sites(ref_codon)
            if s_i == 0: continue
            
            m_i = 0
            inner_s, inner_n = 0, 0
            reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end)
            for read in reads:
                trimmed_read = self.trim_read(read, i, i+3, codon=False)
                if len(trimmed_read) < 3: continue

                m_i += 1
                cur_pos = read.pos - i
                if cur_pos < 0: 
                    cur_pos = 0

                sij, nij = 0, 0
                for j, nt in enumerate(trimmed_read):
                    if nt == ref_codon[j]: continue
                    mut_codon = ref_codon[:j] + nt + ref_codon[j+1:]
                    if translate(mut_codon) == ref_aa:
                        sij += 1
                    else:
                        nij += 1
            
                inner_s += sij / s_i
                inner_n += nij / n_i
            
            if m_i:  # guard: a codon with no covering reads would divide by zero
                ps += inner_s / m_i
                pn += inner_n / m_i

        ps /= float(n_codon)
        pn /= float(n_codon)

        ds = -.75 * np.log(1 - 4*ps/3)
        dn = -.75 * np.log(1 - 4*pn/3)
        print ds/dn

#------------------------------------------------------------------------------

    def nucleotide_counts(self, reference, start, end):
        reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end)

        mutations = []
        for nt in reference:
            mutations.append(dict(zip(  ('ref','A','C','G','T','N'), 
                                        (nt, 0, 0, 0, 0, 0))))

        for read in reads:
            trimmed_read = self.trim_read(read, start, end, codon=False)
            if trimmed_read == '': continue
            cur_pos = read.pos - start
            if cur_pos < 0:
                cur_pos = 0 
            for nt in trimmed_read:
                if nt not in ('A', 'C', 'G', 'T', 'N'):
                    pass
                else:
                    mutations[cur_pos][nt] = mutations[cur_pos].get(nt, 0) + 1
                cur_pos += 1

        return mutations

    def protein_counts(self, reference, start, end):
        reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end)
        
        mutations = []
        for aa in reference:
            d = dict(zip(   ('ACDEFGHIKLMNPQRSTVWYX*'),
                            np.zeros(22)))
            d['ref'] = aa
            mutations.append(d)

        for read in reads:
            trans_read, trans_start = self.translate_read(read, start, end)
            if trans_read == '': continue
            cur_pos = trans_start

            for aa in trans_read:
                mutations[cur_pos][aa] = mutations[cur_pos].get(aa, 0) + 1
                cur_pos += 1

        return mutations

    def export_reads(self, reference, start, end):
        reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end)

        with open(self.pro_reads, 'w') as outf:
            for read in reads:
                trans_read, trans_start = self.translate_read(read, start, end)
                if trans_read == '': continue
                outf.write('%d,%s\n'%(trans_start, trans_read))

    def export_sequences2(self, reference, start, end):
        def write_se(read, start, end, outf):
            L = (end - start) / 3
            read1, start1 = self.translate_read(read, start, end)
            if read1 == '': return
            len1 = len(read1)
            seq = '.'*start1 + read1 + '.'*(L-len1-start1)
            outf.write("%s\n"%seq)

        def write_pe(read, mate, start, end, outf):
            L = (end - start)/3
            if read.qname != mate.qname:
                write_se(read, start, end, outf)
                write_se(mate, start, end, outf)
                return

            read1, s1 = self.translate_read(read, start, end)
            read2, s2 = self.translate_read(mate, start, end)
            if read1 == '' and read2 == '': return 
            if s1 > s2:
                read1, read2 = read2, read1
                s1, s2 = s2, s1

            len1, len2 = len(read1), len(read2)
            if s2 >= s1 + len1:
                if s2 > L: s2 = L
                seq = '.'*s1 + read1 + '.'*(s2-len1-s1) + read2 + '.'*(L-len2-s2)
            else:
                seq = '.'*s1 + read1[:s2-s1] + read2
                seq += '.'*(L-len(seq))
            outf.write("%s\n"%seq)

        mate1, mate2 = None, None
        with open(self.pro_reads, 'w') as outf:
            for read in self.samfilen:
                if not read.is_proper_pair or not read.is_paired:
                    write_se(read, start, end, outf)
                elif read.is_proper_pair and read.is_read1:
                    if mate1 is not None: write_se(mate1, start, end, outf)
                    mate1 = read
                elif read.is_proper_pair and read.is_read2:
                    mate2 = read
                    if mate1 and mate2:
                        write_pe(mate1, mate2, start, end, outf)
                    else:
                        write_se(mate2, start, end, outf)
                    mate1, mate2 = None, None


    def export_sequences(self, reference, start, end):
        reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end, until_eof=1)
        L = (end - start) / 3

        count = 0
        with open(self.pro_reads, 'w') as outf:
            for read in reads:
                # incorporate paired reads
                if read.is_proper_pair and read.is_read1:
                    pointer = self.samfile.tell() # save current position
                    try: mate = self.samfile.mate(read)
                    except ValueError: continue
                    finally:
                        self.samfile.seek(pointer)
                        read1, start1 = self.translate_read(read, start, end)
                        read2, start2 = self.translate_read(mate, start, end)
                        
                        if start2 < start1:
                            read1, read2 = read2, read1
                            start1, start2 = start2, start1

                        len1 = len(read1)
                        len2 = len(read2)
                            
                        seq = '.'*start1 + read1 + '.'*(start2-len1-start1) +\
                              read2 + '.'*(L-len2-start2)
                    
                    outf.write('%s\n'%seq)
                    count += 1
                    if count%1000==0: print count

                elif not read.is_proper_pair:
                    read1, start1 = self.translate_read(read, start, end)
                    if read1 == '': continue
                    len1 = len(read1)
                    seq = '.'*start1 + read1 + '.'*(L-len1-start1)
                    
                    outf.write('%s\n'%seq)
                    count += 1
                    if count%1000==0: print count

    @staticmethod
    def cwr(iterable, r):
        # combinations_with_replacement (itertools 2.7 generator)
        pool = tuple(iterable)
        n = len(pool)
        for indices in product(range(n), repeat=r):
            if sorted(indices) == list(indices):
                yield tuple(pool[i] for i in indices)

    def protein_pair_counts(self, reference, start, end):
        # THIS SUCKS AND IS WAY TOO SLOW
        reads = self.samfile.fetch('CONSENSUS_B_GAG_POL', start, end)

        mutations = []
        all_aas = 'ACDEFGHIKLMNPQRSTVWYX*'
        possible_combos = [''.join(aas) for aas in product(all_aas, repeat=2)]
        for aa_pair in combinations(reference, 2):
            d = dict(zip(possible_combos, [0]*len(possible_combos)))
            d['ref'] = ''.join(aa_pair)
            mutations.append(d)

        for read in reads:
            # If its read one and part of a pair, try to get bivariates from 
            # pair
            if read.is_proper_pair and read.is_read1:
                pointer = self.samfile.tell()
                try:
                    mate = self.samfile.mate(read)
                except ValueError:
                    continue
                finally:
                    self.samfile.seek(pointer)

                    read1, start1 = self.translate_read(read, start, end)
                    read2, start2 = self.translate_read(mate, start, end)
                    if read1 == '' or read2 == '':
                        pass  # 'continue' is not allowed inside a finally clause here

                    # Ensure read1 starts before read2
                    if start2 < start1:
                        swpread, swpstart = read2, start2
                        read2, start2 = read1, start1
                        read1, start1 = swpread, swpstart

                    cur_pos = len([j for i in range(start1) 
                                     for j in range(len(reference)-i)])\
                                     + start2

                    for i in range(len(read1)):
                        for j in range(min(i+1, start2), len(read2)):
                            pair = read1[i] + read2[j]
                            mutations[cur_pos][pair] = mutations[cur_pos].get(
                                                        pair, 0) + 1
                            cur_pos += 1
                        cur_pos += len(reference) - i + start2 
            
            # Regardless of which read it is, we also want the bivariates from
            # the single read alone. The mate of this read gets its turn when
            # it fails the is_proper_pair/is_read1 check above; this also
            # catches reads that are unpaired.
            read1, start1 = self.translate_read(read, start, end)
            cur_pos = len([j for i in range(start1)
                             for j in range(len(reference)-i)])

            for i in range(len(read1)):
                for j in range(i+1, len(read1)):
                    pair = read1[i] + read1[j]
                    mutations[cur_pos][pair] = mutations[cur_pos].get(pair,
                                                                        0) + 1
                    cur_pos += 1
                cur_pos += len(reference) - i
            
        return mutations


#------------------------------------------------------------------------------


    def get_local_codon(self, mutations, pos, mut=None):
        codon_pos = pos%3
        codon_seq = mutations[pos - codon_pos:pos + 3 - codon_pos]
        codon_seq = ''.join(map(lambda _x: _x['ref'], codon_seq))
        if mut:
            codon_seq = codon_seq[:codon_pos]+ mut + codon_seq[codon_pos + 1:]
        return translate(codon_seq)
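
    # Hedged illustration of get_local_codon with hypothetical inputs: if the
    # 'ref' entries of mutations[0:3] spell A, T, G, then pos=1 sits at codon
    # offset 1 inside 'ATG', so:
    #   get_local_codon(mutations, 1)           -> translate('ATG') == 'M'
    #   get_local_codon(mutations, 1, mut='G')  -> translate('AGG') == 'R'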

    def export_nucleotide_frequencies(self, mutations, outfilename, start,
                                      threshold=None):
        outputfile = open(outfilename, 'w')
        writer = csv.writer(outputfile)

        sig_freqs = []

        N = len(mutations)
        for pos in range(N):
            pos_info = mutations[pos].copy()
            pos_ref = pos_info.pop('ref')
            pos_nts = pos_info.items()
            
            # synonymous-mutation info (only used by the commented-out
            # threshold branch below)
            ref_codon = self.get_local_codon(mutations, pos)
            
            total = float(sum(pos_info.values()))
            if total == 0: total = 1.
            for nt, count in pos_nts:
                freq = count/total
                wt_flag = int(nt == pos_ref)
                # row: position, nucleotide, count, frequency, wild-type flag
                #if not threshold:
                writer.writerow([pos, nt, count, "%.8f"%freq, wt_flag])
                #else:
                #    if wt_flag == False and freq > threshold:
                #        mut_codon = self.get_local_codon(mutations, pos, mut=nt)
                #        syn_flag = int(mut_codon == ref_codon)
                #        writer.writerow([pos+start, nt, count, "%.8f"%freq, 
                #                        'ref:'+pos_ref,
                #                        'aa:%s->%s'%(ref_codon, mut_codon)])
                #        sig_freqs.append({'pos': pos, 'freq':freq,
                #                          'mut_nt':nt, 'mut_aa':mut_codon,
                #                          'ref_nt':pos_ref, 'ref_aa':ref_codon})
        outputfile.close()
        return sig_freqs

    def export_amino_acid_frequencies(self, mutations, outfilename, start):
        print "exporting amino acids to %s"%outfilename
        outputfile = open(outfilename, 'wb')
        writer = csv.writer(outputfile)
        bad_codons = set(['*', 'X'])

        #sig_freqs = []
        N = len(mutations)
        for pos in range(N):
            pos_info = mutations[pos].copy()
            pos_ref = pos_info.pop('ref')
            pos_aas = pos_info.items()

            total = float(sum(pos_info.values()))
            if total == 0: total = 1.
            for aa, count in pos_aas:
                freq = count/total
                wt_flag = int(aa == pos_ref)
                writer.writerow([pos, aa, count, "%.8f"%freq, wt_flag])
        outputfile.close()

    def export_amino_acid_pair_frequencies(self, mutations, outfilename, start):
        # start is unused; accepted for signature parity with the call sites
        # and the other exporters
        print "exporting amino acid pair counts to %s"%outfilename
        outputfile = open(outfilename, 'wb')
        writer = csv.writer(outputfile)

        N = len(mutations)
        for pos in range(N):
            pos_info = mutations[pos].copy()
            pos_ref = pos_info.pop('ref')
            pos_aas = pos_info.items()

            total = float(sum(pos_info.values()))
            if total < 1: total = 1.
            for aa, count in pos_aas:
                freq = count/total
                wt_flag = int(aa == pos_ref)
                writer.writerow([pos, aa, count, "%.8f"%freq, wt_flag])
        outputfile.close()

    def export_protease_reads(self):
        self.export_reads(PRO_AA_REFERENCE, PRO_START, PRO_END)

    def export_protease_sequences(self):
        self.export_sequences2(PRO_AA_REFERENCE, PRO_START, PRO_END)

#------------------------------------------------------------------------------

    def plot_nucleotide_frequencies(self, outfile, mutations, start, end):
        import matplotlib.pyplot as plt
        N = len(mutations)
        nt_index = dict(zip("ACGTN", range(5)))
        nt_counts = [[] for _ind in nt_index]

        for pos in range(N):
            pos_info = mutations[pos].copy()  # copy so the caller's counts survive
            pos_ref = pos_info.pop('ref')
            pos_nts = pos_info.items()
            
            total = float(sum(pos_info.values()))
            if total == 0: total = 1.
            for nt, count in pos_nts:
                freq = count/total
                nt_counts[nt_index[nt]].append(freq)
        nt_counts = np.array(nt_counts)

        N_plts = np.round((end-start)/150.)
        for z in range(int(N_plts)):
            fig = plt.figure(figsize=(18.7,10.5))
            ax = fig.add_subplot(111)
            plt.title('Nucleotide Frequencies of Patient %s'%self.patientid)

            # Integer window edges for this page of the plot
            l_edge = int(z*N/N_plts)
            r_edge = int((z+1)*N/N_plts)
            x = np.arange(N)[l_edge:r_edge]
            w = 1
            c = nt_counts[:, l_edge:r_edge]

            # Stack the five series: each one is drawn on top of the
            # cumulative sum of those below it via bottom=
            plt_a = plt.bar(x, c[0], w, color='g', align='center', alpha=0.5)
            plt_c = plt.bar(x, c[1], w, color='b', align='center', alpha=0.5,
                            bottom=c[:1].sum(0))
            plt_g = plt.bar(x, c[2], w, color='Gold', align='center', alpha=0.5,
                            bottom=c[:2].sum(0))
            plt_t = plt.bar(x, c[3], w, color='r', align='center', alpha=0.5,
                            bottom=c[:3].sum(0))
            plt_n = plt.bar(x, c[4], w, color='k', align='center', alpha=0.5,
                            bottom=c[:4].sum(0))

            plt.legend( (plt_a[0], plt_c[0], plt_g[0], plt_t[0], plt_n[0]), 
                        list('ACGTN'), bbox_to_anchor=(1,1), loc=2)

            plt.ylabel('Nucleotide Frequency')
            plt.xlabel('Gag-Pol sequence position')
            plt.xlim([l_edge -1, r_edge])
            plt.ylim([0, 1])


            locs, labels = plt.xticks()
            plt.xticks(locs, map(lambda x: "%i"%(x+start), locs))
            plt.xlim([l_edge -1, r_edge])

            #mng = plt.get_current_fig_manager()
            #mng.window.maximize()
            fig.savefig(outfile+'%i.png'%z, orientation='landscape')
            #plt.show()
       
        #for nt, ind in nt_index.items():
        #    plt.subplot(511+int(ind))
        #    g = plt.bar(range(N), nt_counts[ind],
        #                linewidth=0, align='center')
        #    lines.append(g)
        #    legend_nts.append(nt)
        #    plt.ylabel('%s Frequency'%nt)
        #    plt.axis([0,50, 0,0.2])
        #plt.xlabel('Gag-Pol Sequence Position')
        #plt.show()
        #fig.savefig('testfig.png', dpi=200)


    def plot_significant_mutations(self, outfile, sig_muts, start, end):
        import matplotlib.pyplot as plt
        x = np.arange(start, end)

        fig = plt.figure(figsize=(18.7,10.5))
        ax1 = fig.add_subplot(111)

        color_table = zip('ACGTN', ('g','b','Gold','r','k'))
        for color_nt, color in color_table:
            y = np.zeros(end-start)
            syn = []
            for mut in sig_muts:
                pos, freq, nt = (mut['pos'], mut['freq'], mut['mut_nt'])
                syn_mut = (mut['ref_aa'] == mut['mut_aa'])
                if nt == color_nt:
                    y[pos] = freq
                    syn.append((pos, syn_mut))

            b = ax1.bar(x, y, width=1, align='center', alpha=0.5, color=color,
                            label=color_nt)
            # Hatch the bars of non-synonymous mutations so they stand out
            for i, s in syn:
                if not s:
                    b[i].set_hatch('/')

        ax1.legend(bbox_to_anchor=(1,1), loc=2)
        plt.xlim(start,end)

        plt.title('Significant Nucleotide Mutations of Patient %s'
                    %self.patientid)
        plt.xlabel('Gag-Pol sequence position')
        plt.ylabel('Mutational Frequency')
        fig.savefig(outfile + '.png', orientation='landscape')
#        plt.show()


    def analyze_protease_nucleotides(self):
        counts = self.nucleotide_counts(PRO_NT_REFERENCE, PRO_START, PRO_END)
        sig_muts = self.export_nucleotide_frequencies(counts, 
                                                      self.pro_counts+'_nt', 
                                                      PRO_START)#,threshold=.05)
        #self.plot_nucleotide_frequencies(self.pro_plts, counts, 
        #                                 PRO_START, PRO_END)
        #self.plot_significant_mutations(self.sig_pro_plts, sig_muts, 
        #                                PRO_START, PRO_END)

    def analyze_gag_nucleotides(self):
        counts = self.nucleotide_counts(GAG_NT_REFERENCE, GAG_START, GAG_END)
        sig_muts = self.export_nucleotide_frequencies(counts, 
                                                      self.gag_counts+'_nt',
                                                      GAG_START, threshold=0.05)
        #self.plot_nucleotide_frequencies(self.gag_plts, counts, 
        #                                 GAG_START, GAG_END)
        #self.plot_significant_mutations(self.sig_gag_plts, sig_muts,
        #                                GAG_START, GAG_END)

    def analyze_genome_nucleotides(self):
        self.analyze_gag_nucleotides()
        self.analyze_protease_nucleotides()

    def analyze_protease_amino_acids(self):
        counts = self.protein_counts(PRO_AA_REFERENCE, PRO_START, PRO_END)
        pair_counts = self.protein_pair_counts(PRO_AA_REFERENCE, PRO_START, 
                                               PRO_END)
        self.export_amino_acid_frequencies(counts, self.pro_counts+'_aa', 
                                           PRO_START)
        self.export_amino_acid_pair_frequencies(pair_counts, self.pro_pair_counts+'_aa', 
                                                PRO_START)

    def analyze_protease_amino_acid_pairs(self):
        print "analyzing"
        pair_counts = self.protein_pair_counts(PRO_AA_REFERENCE, PRO_START,
                                               PRO_END)
        self.export_amino_acid_pair_frequencies(pair_counts, self.pro_pair_counts+'_aa',
                                                PRO_START)
        print "done analyzing"


    def analyze_gag_amino_acids(self):
        counts = self.protein_counts(GAG_AA_REFERENCE, GAG_START, GAG_END)
        self.export_amino_acid_frequencies(counts, self.gag_counts+'_aa',
                                           GAG_START)

    def analyze_all(self):
        pro_aa = self.pro_counts+'_aa'
        pro_nt = self.pro_counts+'_nt'
        gag_aa = self.gag_counts+'_aa'
        gag_nt = self.gag_counts+'_nt'
        #pro
        counts = self.nucleotide_counts(PRO_NT_REFERENCE, PRO_START, PRO_END)
        self.export_nucleotide_frequencies(counts, pro_nt, PRO_START)
        counts = self.protein_counts(PRO_AA_REFERENCE, PRO_START, PRO_END)
        self.export_amino_acid_frequencies(counts, pro_aa, PRO_START)
        #gag
        counts = self.nucleotide_counts(GAG_NT_REFERENCE, GAG_START, GAG_END)
        self.export_nucleotide_frequencies(counts, gag_nt, GAG_START)
        counts = self.protein_counts(GAG_AA_REFERENCE, GAG_START, GAG_END)
        self.export_amino_acid_frequencies(counts, gag_aa, GAG_START)

    def analyze_genome_amino_acids(self):
        self.analyze_gag_amino_acids()
        self.analyze_protease_amino_acids()

    def dNdS_pro(self):
        self.dNdS(PRO_NT_REFERENCE, PRO_START, PRO_END)

    def dNdS_gag(self):
        self.dNdS(GAG_NT_REFERENCE, GAG_START, GAG_END)

# Determine the Wine prefix, falling back to the default ~/.wine
try:
    wineprefix = os.environ['WINEPREFIX']
except KeyError:
    wineprefix = os.path.join(os.environ['HOME'], '.wine')
winetemp = os.path.join(wineprefix, 'drive_c', 'windows', 'temp')
f = NamedTemporaryFile(prefix='winecolors', suffix='.reg', dir=winetemp, mode='w+')
f.write("""REGEDIT4

[HKEY_CURRENT_USER\Control Panel]

[HKEY_CURRENT_USER\Control Panel\Colors]
""")

# Alphabetize list (purely so that user.reg is easy to read; Wine doesn't care)
color_pairs = sorted(color_pairs)
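
# Each entry in color_pairs should already be a complete REGEDIT4 assignment,
# e.g. (hypothetical value) "ActiveTitle"="0 0 128" -- an RGB triple string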

# Append list to file, with newlines
f.writelines(line + '\n' for line in color_pairs)
f.flush()

# Debugging
if debug_mode:
    print '---- [' + f.name + '] ----'
    f.seek(0)
    for line in f:
        print line,
    print '--------\n'

# Import values into Wine registry using regedit command
print('Using regedit to import colors into registry...\n')
os.system('regedit ' + f.name)
# TODO: Check if this worked correctly.
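
# A hedged sketch of one way to close out the TODO above: subprocess.call
# returns the exit status directly, so success can be checked (this assumes
# regedit is on PATH, exactly as the os.system call above does):
#
#   import subprocess
#   ret = subprocess.call(['regedit', f.name])
#   if ret != 0:
#       print('regedit exited with status %d' % ret)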