Example #1
 def roundtrip(self, dtype, x):
     f = NamedTemporaryFile(suffix='.tif')
     fname = f.name
     f.close()
     imsave(fname, x)
     y = imread(fname)
     assert_array_equal(x, y)
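The close-then-reuse-the-name idiom above exists because, on some platforms (notably Windows), an open NamedTemporaryFile cannot be reopened by name. A hedged alternative sketch using tempfile.mkstemp, which creates the file and returns its descriptor in one step, avoiding the brief window where the freed name could be claimed by another process:

import os
from tempfile import mkstemp

# Sketch only: mkstemp returns an already-open descriptor plus the path.
fd, fname = mkstemp(suffix='.tif')
os.close(fd)
try:
    pass  # imsave(fname, x); y = imread(fname) would go here
finally:
    os.remove(fname)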
Example #2
 def _get_tmp_file(self):
     ''' Creates a tmp file with the file data '''
     data = self.request.files['xml_file'][0]['body']
     tmp_file = NamedTemporaryFile(delete=False)
     tmp_file.write(data)
     tmp_file.close()
     return tmp_file.name
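Because the file is created with delete=False, the caller of _get_tmp_file owns cleanup; a minimal sketch of a consumer (the handler instance is hypothetical):

import os

tmp_path = handler._get_tmp_file()  # hypothetical handler instance
try:
    with open(tmp_path, 'rb') as f:
        xml_data = f.read()
finally:
    os.remove(tmp_path)  # delete=False means nothing else removes it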
Example #3
    def compile_inline(self,data,ext):
        """
        Compile inline css. Have to compile to a file, because some css compilers
        may not output to stdout, but we know they all output to a file. It's a
        little hackish, but you shouldn't be compiling in production anyway,
        right?
        """
        compiler = settings.COMPILER_FORMATS[ext]
        try:
            bin = compiler['binary_path']
        except KeyError:
            raise Exception("Path to CSS compiler must be included in COMPILER_FORMATS")
        
        tmp_file = NamedTemporaryFile(mode='w',suffix=ext)
        tmp_file.write(dedent(data))
        tmp_file.flush()
        path, ext = os.path.splitext(tmp_file.name)
        tmp_css = ''.join((path,'.css'))
        
        self.compile(path,compiler)
        with open(tmp_css, 'r') as css_file:
            data = css_file.read()
        
        # cleanup
        tmp_file.close()
        os.remove(tmp_css)

        return data  
Example #4
def test_BeagleOrderMulti():

    from vsm.util.corpustools import random_corpus
    from vsm.model.beagleenvironment import BeagleEnvironment

    c = random_corpus(1000, 50, 0, 20, context_type='sentence')

    e = BeagleEnvironment(c, n_cols=100)
    e.train()

    m = BeagleOrderMulti(c, e.matrix)
    m.train(4)

    from tempfile import NamedTemporaryFile
    import os

    try:
        tmp = NamedTemporaryFile(delete=False, suffix='.npz')
        m.save(tmp.name)
        tmp.close()
        m1 = BeagleOrderMulti.load(tmp.name)
        assert (m.matrix == m1.matrix).all()
    
    finally:
        os.remove(tmp.name)

    return m.matrix
Example #5
	def run(self):
		active_view = self.window.active_view()
		text = "\n\n".join(getSelectedText(active_view)).strip()

		tf = NamedTemporaryFile(mode="w", delete=False)
		try:
			tf.write(text)
			tf.close()

			res = subprocess.check_output(["m4", tf.name],
			                              stderr=subprocess.STDOUT,
			                              cwd=os.path.dirname(os.path.abspath(active_view.file_name())))
			res = res.decode('utf-8').replace('\r', '').strip()

			panel_name = "m4expand.results"
			panel = self.window.create_output_panel(panel_name)
			self.window.run_command("show_panel", {"panel": "output." + panel_name})

			panel.set_read_only(False)
			panel.set_syntax_file(active_view.settings().get("syntax"))
			panel.run_command("append", {"characters": res})
			panel.set_read_only(True)
		except Exception as e:
			print("M4Expand - An error occurred: ", e)
		finally:
			os.unlink(tf.name)
Example #6
def get_body():
    body = ""

    # Create a temporary file
    body_buffer_file = NamedTemporaryFile(delete=False)
    body_buffer_file_path = body_buffer_file.name
    body_buffer_file.close()

    # Set the default editor
    editor = 'nano'
    if os.name == 'nt':
        editor = 'notepad'

    raw_input('Press Enter to start writing the body of the mail')
    try:
        subprocess.call([editor, body_buffer_file_path])
    except OSError:
        # No suitable text editor found
        # Let the user edit the buffer file himself
        print "Enter the mail body in the file located at '" + body_buffer_file_path + "'"
        raw_input("Press Enter when done!")

    body_buffer_file = open(body_buffer_file_path)
    body = body_buffer_file.read()
    body_buffer_file.close()
    try:
        os.remove(body_buffer_file_path)
    except OSError:
        # Unable to remove the temporary file
        # Stop the exception from propagating further,
        # since removing it is not essential to the working of the program
        pass

    return body
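One hedged refinement to the editor-selection step above: honour the user's EDITOR environment variable before falling back to a platform default.

import os

# Sketch: prefer $EDITOR when set; otherwise keep the original defaults.
editor = os.environ.get('EDITOR') or ('notepad' if os.name == 'nt' else 'nano')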
Example #7
def apicalls(target, **kwargs):
    """
    """
    if not target:
        raise Exception("Invalid target for apicalls()")

    output_file = NamedTemporaryFile()
    kwargs.update({"output_file" : output_file})
    cmd = _dtrace_command_line(target, **kwargs)

    # Generate dtrace probes for analysis
    definitions = os.path.abspath(os.path.join(__file__, "../../core/data/signatures.yml"))
    probes_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "probes.d")
    generate_probes(definitions, probes_file, overwrite=True)

    # The dtrace script will take care of timeout itself, so we just launch
    # it asynchronously
    with open(os.devnull, "w") as null:
        _ = Popen(cmd, stdout=null, stderr=null, cwd=current_directory())

    with open('/Users/cloudmark/yield.txt', 'w+') as f:
        for entry in filelines(output_file):
            value = entry.strip()
            if "## apicalls.d done ##" in value:
                break
            if len(value) == 0:
                continue
            f.write(str(_parse_entry(value)))
            f.flush()
            import time
            time.sleep(1)
            yield _parse_entry(value)
    output_file.close()
    os.remove(probes_file)
Example #8
 def testLogfile(self):
     """Test logging into a logfile"""
     f = NamedTemporaryFile(delete=False)
     filename = f.name
     try:
         set_log_level("error")  # avoid using the console logger
         f.write(":-P\n")
         f.close()
         start_logfile(f.name, "devinfo")
         log = getLogger("prosoda.test.integration.test_logger")
         log.debug("Should not be in logfile! :-( ")
         log.info("Should be in logfile :-) ")
         log.devinfo("Should be in logfile :-) ")
         log.warning("Should really be in logfile :-D ")
         stop_logfile(f.name)
         contents = open(f.name).read()
         self.assertNotIn(":-(", contents)
         self.assertNotIn(":-P", contents)
         self.assertIn(":-)", contents)
         self.assertIn(":-D", contents)
         # Make sure no colour codes are leaked into the logfile
         self.assertNotIn("\033", contents)
     finally:
         set_log_level("debug")
         unlink(filename)
Example #9
def ipconnections(target, **kwargs):
    """Returns a list of ip connections made by the target.

    A connection is a named tuple with the following properties:
    host (string), host_port (int), remote_port (string), protocol (string),
    timestamp (int).
    """
    if not target:
        raise Exception("Invalid target for ipconnections()")

    output_file = NamedTemporaryFile()
    cmd = ["sudo", "/usr/sbin/dtrace", "-C"]
    if "timeout" in kwargs:
        cmd += ["-DANALYSIS_TIMEOUT=%d" % kwargs["timeout"]]
    cmd += ["-s", path_for_script("ipconnections.d")]
    cmd += ["-o", output_file.name]
    if "args" in kwargs:
        line = "%s %s" % (sanitize_path(target), " ".join(kwargs["args"]))
        cmd += ["-c", line]
    else:
        cmd += ["-c", sanitize_path(target)]

    # The dtrace script will take care of timeout itself, so we just launch
    # it asynchronously
    with open(os.devnull, "w") as f:
        handler = Popen(cmd, stdout=f, stderr=f)

    for entry in filelines(output_file):
        if "## ipconnections.d done ##" in entry.strip():
            break
        yield _parse_single_entry(entry.strip())
    output_file.close()
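Because ipconnections is a generator, entries can be consumed as dtrace emits them; a minimal consumption sketch (the target path and keyword values are illustrative):

for conn in ipconnections("/bin/ls", timeout=30):
    print("%s:%s via %s" % (conn.host, conn.host_port, conn.protocol))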
Example #10
        def odt_subreport(name=None, obj=None):
            if not aeroo_ooo:
                return _("Error! Subreports not available!")
            report_xml_ids = ir_obj.search(cr, uid, [('report_name', '=', name)], context=context)
            if report_xml_ids:
                service = netsvc.Service._services['report.%s' % name]
                report_xml = ir_obj.browse(cr, uid, report_xml_ids[0], context=context)
                data = {'model': obj._table_name, 'id': obj.id, 'report_type': 'aeroo', 'in_format': 'oo-odt'}
                ### Get new printing object ###
                sub_aeroo_print = AerooPrint()
                service.active_prints[sub_aeroo_print.id] = sub_aeroo_print
                context['print_id'] = sub_aeroo_print.id
                ###############################
                sub_aeroo_print.start_time = time.time()
                report, output = service.create_aeroo_report(cr, uid, \
                                                             [obj.id], data, report_xml, context=context,
                                                             output='odt')  # change for OpenERP 6.0 - Service class usage

                ### Delete printing object ###
                AerooPrint.print_ids.remove(sub_aeroo_print.id)
                del service.active_prints[sub_aeroo_print.id]
                ##############################
                temp_file = NamedTemporaryFile(suffix='.odt', prefix='aeroo-report-', delete=False)
                try:
                    temp_file.write(report)
                finally:
                    temp_file.close()
                # self.oo_subreports[print_id].append(temp_file.name)
                # aeroo_print.subreports.append(temp_file.name)
                self.active_prints[aeroo_print.id].subreports.append(temp_file.name)
                return "<insert_doc('%s')>" % temp_file.name
            return None
Example #11
def tmpfile(stream, mode=None):
    """Context manager that writes a :class:`Stream` object to a named
    temporary file and yield it's filename. Cleanup deletes from the temporary
    file from disk.

    Args:
        stream (Stream): Stream object to write to disk as temporary file.
        mode (int, optional): File mode to set on temporary file.

    Returns:
        str: Temporoary file name
    """
    tmp = NamedTemporaryFile(delete=False)

    if mode is not None:
        oldmask = os.umask(0)

        try:
            os.chmod(tmp.name, mode)
        finally:
            os.umask(oldmask)

    for data in stream:
        tmp.write(to_bytes(data))

    tmp.close()

    yield tmp.name

    os.remove(tmp.name)
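Since tmpfile is a one-shot generator, it is presumably meant to be wrapped with contextlib.contextmanager before use; a minimal usage sketch, with a plain list of byte chunks standing in for a real Stream:

from contextlib import contextmanager

tmpfile_cm = contextmanager(tmpfile)

with tmpfile_cm([b"chunk-1", b"chunk-2"], mode=0o600) as name:
    print(name)  # path exists here; removed after a normal exit

Note that because there is no try/finally around the yield, an exception raised inside the with-block skips the os.remove cleanup and leaks the file.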
Example #12
    def render_to_temporary_file(self, template_name, mode='w+b', bufsize=-1,
                                 suffix='.html', prefix='tmp', dir=None,
                                 delete=True):
        template = self.resolve_template(template_name)

        context = self.resolve_context(self.context_data)

        content = smart_str(template.render(context))
        content = make_absolute_paths(content)

        try:
            tempfile = NamedTemporaryFile(mode=mode, bufsize=bufsize,
                                      suffix=suffix, prefix=prefix,
                                      dir=dir, delete=delete)
        except TypeError:
            tempfile = NamedTemporaryFile(mode=mode, buffering=bufsize,
                                      suffix=suffix, prefix=prefix,
                                      dir=dir, delete=delete)

        try:
            tempfile.write(content)
            tempfile.flush()
            return tempfile
        except TypeError:
            tempfile.write(bytes(content, 'UTF-8'))
            tempfile.flush()
            return tempfile
        except:
            # Clean-up tempfile if an Exception is raised.
            tempfile.close()
            raise
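The first try/except above absorbs the signature change between Python 2 (bufsize) and Python 3 (buffering). An equivalent hedged sketch using an explicit version check instead of catching TypeError:

import sys
from tempfile import NamedTemporaryFile

# Assumes the only incompatibility is the bufsize/buffering rename.
if sys.version_info[0] >= 3:
    tmp = NamedTemporaryFile(mode='w+b', buffering=-1, suffix='.html')
else:
    tmp = NamedTemporaryFile(mode='w+b', bufsize=-1, suffix='.html')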
Example #13
    def run_solver(self, conflicts, election, deletion_handler, outfile=None):
        if not conflicts:
            return [], 0

        self.deletion_handler = deletion_handler

        instance = self.generate_instance(conflicts, election)

        f = NamedTemporaryFile(delete=False)
        f.write(instance.encode(code))
        f.close()

        process = Popen([self.cmd, f.name], stdout=PIPE)
        out, err = process.communicate()

        conflict_variables, optimum = self.parse_instance(out)

        if outfile:
            candidates = election[0]
            votes = election[1]
            votecounts = election[2]

            votemap = self.delete_votes(votes, votecounts, conflict_variables)
            votesum = sum(votemap.values())

            write_map(candidates, votesum, votemap, open(outfile, "w"))

        remove(f.name)
        return conflict_variables, optimum
Example #14
    def __enter__(self):
        # Ensure that we have not re-entered
        if self.temp_path is not None or self.service is not None:
            raise Exception('Cannot use multiple nested with blocks on same Youtube object!')

        flow = flow_from_clientsecrets(
            self.client_secrets_path,
            scope=YOUTUBE_UPLOAD_SCOPE,
            message=MISSING_CLIENT_SECRETS_MESSAGE)

        temp_file = NamedTemporaryFile(delete=False)
        self.temp_path = temp_file.name
        temp_file.close()

        storage = Storage(self.temp_path)
        credentials = storage.get()

        if credentials is None or credentials.invalid:
            credentials = run_flow(
                flow, storage, argparser.parse_args(list())
            )

        self.service = build(YOUTUBE_API_SERVICE_NAME, YOUTUBE_API_VERSION,
            http=credentials.authorize(httplib2.Http()))

        return self
Example #15
def test_cmyk():
    ref = imread(os.path.join(data_dir, 'color.png'))

    img = Image.open(os.path.join(data_dir, 'color.png'))
    img = img.convert('CMYK')

    f = NamedTemporaryFile(suffix='.jpg')
    fname = f.name
    f.close()
    img.save(fname)
    try:
        img.close()
    except AttributeError:  # `close` not available on PIL
        pass

    new = imread(fname)

    ref_lab = rgb2lab(ref)
    new_lab = rgb2lab(new)

    for i in range(3):
        newi = np.ascontiguousarray(new_lab[:, :, i])
        refi = np.ascontiguousarray(ref_lab[:, :, i])
        sim = ssim(refi, newi, dynamic_range=refi.max() - refi.min())
        assert sim > 0.99
Example #16
    def _create_empty_image(self, image_width, image_height):

        # Check pycairo capabilities
        if not (cairo.HAS_IMAGE_SURFACE and cairo.HAS_PNG_FUNCTIONS):
            raise HTTPBadRequest("cairo was not compiled with ImageSurface and PNG support")

        # Create a new cairo surface
        surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, int(image_width), int(image_height))

        ctx = cairo.Context(surface)

        text = "No imagery available for requested coordinates."

        x_bearing, y_bearing, width, height, x_advance, y_advance = ctx.text_extents(text)

        ctx.move_to((image_width / 2) - (width / 2), (image_height / 2) + (height / 2))
        ctx.set_source_rgba(0, 0, 0, 0.85)
        ctx.show_text(text)

        temp_datadir = self.config.get("main", "temp.datadir")
        temp_url = self.config.get("main", "temp.url")
        file = NamedTemporaryFile(suffix=".png", dir=temp_datadir, delete=False)
        surface.write_to_png(file)
        file.close()

        return {"file": "%s/%s" % (temp_url, file.name.split("/")[-1])}
Example #17
    def run(self, uid, aid, publish=True):
        aid = int(aid)
        audiobook = Audiobook.objects.get(id=aid)
        self.set_status(aid, status.ENCODING)

        user = User.objects.get(id=uid)

        try:
            os.makedirs(BUILD_PATH)
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        out_file = NamedTemporaryFile(delete=False, prefix='%d-' % aid, suffix='.%s' % self.ext, dir=BUILD_PATH)
        out_file.close()
        self.encode(audiobook.source_file.path, out_file.name)
        self.set_status(aid, status.TAGGING)
        self.set_tags(audiobook, out_file.name)
        self.set_status(aid, status.SENDING)

        if publish:
            self.put(user, audiobook, out_file.name)
            self.published(aid)
        else:
            self.set_status(aid, None)

        self.save(audiobook, out_file.name)
Example #18
def command_update(args):
    def write_to(out):
        config.output(out)
    library = KeyLibrary(args.key_directory)
    with open(args.config_file) as fd:
        config = SedgeEngine(library, fd, not args.no_verify, url=args.config_file)
    if args.output_file == '-':
        write_to(ConfigOutput(sys.stdout))
        return
    if not check_or_confirm_overwrite(args.output_file):
        print("Aborting.", file=sys.stderr)
        sys.exit(1)

    tmpf = NamedTemporaryFile(mode='w', dir=os.path.dirname(args.output_file), delete=False)
    try:
        tmpf.file.write('''\
# :sedge:
#
# this configuration generated from `sedge' file:
# %s
#
# do not edit this file manually, edit the source file and re-run `sedge'
#

''' % (args.config_file))
        write_to(ConfigOutput(tmpf.file))
        tmpf.close()
        if args.verbose:
            diff_config_changes(args.output_file, tmpf.name)
        os.rename(tmpf.name, args.output_file)
    except:
        os.unlink(tmpf.name)
        raise
Example #19
    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  As an alternative, serialized
        # objects are written to a file and loaded through textFile().
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        serializer.dump_stream(c, tempFile)
        tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)
Example #20
def test_toy_corpus():

    keats = ('She dwells with Beauty - Beauty that must die;\n\n'
             'And Joy, whose hand is ever at his lips\n\n' 
             'Bidding adieu; and aching Pleasure nigh,\n\n'
             'Turning to poison while the bee-mouth sips:\n\n'
             'Ay, in the very temple of Delight\n\n'
             'Veil\'d Melancholy has her sovran shrine,\n\n'
             'Though seen of none save him whose strenuous tongue\n\n'
             'Can burst Joy\'s grape against his palate fine;\n\n'
             'His soul shall taste the sadness of her might,\n\n'
             'And be among her cloudy trophies hung.')

    assert toy_corpus(keats)
    assert toy_corpus(keats, nltk_stop=True)
    assert toy_corpus(keats, stop_freq=1)
    assert toy_corpus(keats, add_stop=['and', 'with'])
    assert toy_corpus(keats, nltk_stop=True,
                      stop_freq=1, add_stop=['ay'])

    import os
    from tempfile import NamedTemporaryFile as NFT

    tmp = NFT(delete=False)
    tmp.write(keats)
    tmp.close()

    c = toy_corpus(tmp.name, is_filename=True, 
                   nltk_stop=True, add_stop=['ay'])
    
    assert c
    os.remove(tmp.name)

    return c
Example #21
class TestTrio(object):
    """Test class for testing how the individual class behave"""
    
    def setup_class(self):
        """Setup a standard trio."""
        trio_lines = ['#Standard trio\n', 
                    '#FamilyID\tSampleID\tFather\tMother\tSex\tPhenotype\n', 
                    'healthyParentsAffectedSon\tproband\tfather\tmother\t1\t2\n',
                    'healthyParentsAffectedSon\tmother\t0\t0\t2\t1\n', 
                    'healthyParentsAffectedSon\tfather\t0\t0\t1\t1\n'
                    ]
        self.trio_file = NamedTemporaryFile(mode='w+t', delete=False, suffix='.vcf')
        self.trio_file.writelines(trio_lines)
        self.trio_file.seek(0)
        self.trio_file.close()
        
    
    def test_standard_trio(self):
        """Test if the file is parsed in a correct way."""
        family_parser = parser.FamilyParser(open(self.trio_file.name, 'r'))
        assert family_parser.header == [
                                    'family_id', 
                                    'sample_id', 
                                    'father_id', 
                                    'mother_id', 
                                    'sex', 
                                    'phenotype'
                                    ]
        assert 'healthyParentsAffectedSon' in family_parser.families
        assert set(['proband', 'mother', 'father']) == set(family_parser.families['healthyParentsAffectedSon'].individuals.keys())
        assert set(['proband', 'mother', 'father']) == set(family_parser.families['healthyParentsAffectedSon'].trios[0])
Example #22
    def execute(self, context):
        hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
        logging.info("Extracting data from Hive")
        logging.info(self.sql)

        if self.bulk_load:
            tmpfile = NamedTemporaryFile()
            hive.to_csv(self.sql, tmpfile.name, delimiter='\t',
                lineterminator='\n', output_header=False)
        else:
            results = hive.get_records(self.sql)

        mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
        if self.mysql_preoperator:
            logging.info("Running MySQL preoperator")
            mysql.run(self.mysql_preoperator)

        logging.info("Inserting rows into MySQL")

        if self.bulk_load:
            mysql.bulk_load(table=self.mysql_table, tmp_file=tmpfile.name)
            tmpfile.close()
        else:
            mysql.insert_rows(table=self.mysql_table, rows=results)

        if self.mysql_postoperator:
            logging.info("Running MySQL postoperator")
            mysql.run(self.mysql_postoperator)

        logging.info("Done.")
Example #23
 def roundtrip(self, dtype, x, suffix):
     f = NamedTemporaryFile(suffix='.' + suffix)
     fname = f.name
     f.close()
     sio.imsave(fname, x)
     y = sio.imread(fname)
     assert_array_equal(y, x)
Example #24
def testFileAnnotationSpeed(author_testimg_generated, gatewaywrapper):
    """ Tests speed of loading file annotations. See PR: 4176 """
    try:
        f = NamedTemporaryFile()
        f.write("testFileAnnotationSpeed text")
        ns = TESTANN_NS
        image = author_testimg_generated

        # use the same file to create many file annotations
        for i in range(20):
            fileAnn = gatewaywrapper.gateway.createFileAnnfromLocalFile(f.name, mimetype="text/plain", ns=ns)
            image.linkAnnotation(fileAnn)
    finally:
        f.close()

    now = time.time()
    for ann in image.listAnnotations():
        if ann._obj.__class__ == omero.model.FileAnnotationI:
            # mimic behaviour of templates which call multiple times
            print ann.getId()
            print ann.getFileName()
            print ann.getFileName()
            print ann.getFileSize()
            print ann.getFileSize()
    print time.time() - now
Example #25
def SubmitPlay(request):    
    import jinja2
    from tempfile import NamedTemporaryFile
    import os

   
    html = ''

    inventory = """
    [current]
    {{ public_ip_address }}
    """
    # for (name,value) in request.GET:
        
        # if name=='Servers':
            # html=html+str(value+'\n')
   

    html = request.GET['Servers']
    inventory_template = jinja2.Template(inventory)
    rendered_inventory = inventory_template.render({
        'public_ip_address': html    
        # and the rest of our variables
    })

    # Create a temporary file and write the template string to it
    hosts = NamedTemporaryFile(delete=False)
    hosts.write(rendered_inventory)
    hosts.close()
    
    print(hosts.name)

    import ansiblepythonapi as myPlay
    args=['/home/ec2-user/playss/AnsiblePlus/test.yml']
    # args.append('-i')
    # args.append(hosts.name)
    message=myPlay.main(args)

    objects=[]
    for runner_results in myPlay.message:              
        values=[]
        for (host, value) in runner_results.get('dark', {}).iteritems():
            try:
                values.append(host)
                values.append(value['failed'])
                values.append(value['msg'])    
                objects.append(values)
            except KeyError:
                pass
        for (host, value) in runner_results.get('contacted', {}).iteritems():
            try:
                values.append(host)
                values.append(value['failed'])
                values.append(value['msg'])    
                objects.append(values)
            except KeyError:
                pass
        # for msg in pb.stats.output():   
    context=Context({'Summary':objects})
    return render(request, 'AnsibleResponce.html',context)
Example #26
    def test_img(self):
        
        phpbd_pwd = randstr(4)
        temp_file = NamedTemporaryFile()
        temp_file.close()
        temp_imgpathname = '%s.gif' % temp_file.name 
        temp_path, temp_filename = os.path.split(temp_imgpathname)
        
        temp_outputdir = mkdtemp()
        
        status, output = getstatusoutput(conf['env_create_backdoorable_img'] % temp_imgpathname)
        self.assertEqual(0, status)        
        
        self.assertEqual(self._res(':generate.img %s %s'  % (phpbd_pwd, temp_imgpathname)), [os.path.join('bd_output',temp_filename), 'bd_output/.htaccess'])
        self.assertTrue(os.path.isdir('bd_output'))
        shutil.rmtree('bd_output')
        
        self.assertRegexpMatches(self._warn(':generate.img %s /tmp/sdalkj'  % (phpbd_pwd)), modules.generate.img.WARN_IMG_NOT_FOUND)
        self.assertRegexpMatches(self._warn(':generate.img %s %s /tmp/ksdajhjksda/kjdha'  % (phpbd_pwd, temp_imgpathname)), modules.generate.img.WARN_DIR_CREAT)
        self.assertRegexpMatches(self._warn(':generate.img [@>!?] %s %s3'  % (temp_imgpathname, temp_outputdir)), core.backdoor.WARN_CHARS)

        self.assertEqual(self._res(':generate.img %s %s %s'  % (phpbd_pwd, temp_imgpathname, temp_outputdir)), [os.path.join(temp_outputdir,temp_filename), os.path.join(temp_outputdir, '.htaccess')])


        # No output expected 
        self.assertEqual(self._outp(':generate.img %s %s %s'  % (phpbd_pwd, temp_imgpathname, temp_outputdir+'2')), '')

        self.__class__._env_chmod(temp_outputdir, '777', currentuser=True)
        self.__class__._env_cp(os.path.join(temp_outputdir, '.htaccess'), '.htaccess')

        self.__test_new_bd( os.path.join(temp_outputdir,temp_filename), temp_filename, phpbd_pwd)
Example #27
def hmmscan(fasta, database_path, ncpus=10):

    F = NamedTemporaryFile()
    F.write(fasta)
    F.flush()
    OUT = NamedTemporaryFile()
    cmd = '%s --cpu %s -o /dev/null -Z 190000 --tblout %s %s %s' %(HMMSCAN, ncpus, OUT.name, database_path, F.name)
    #print cmd
    sts = subprocess.call(cmd, shell=True)
    byquery = defaultdict(list)

    if sts == 0:
        for line in OUT:
            #['#', '---', 'full', 'sequence', '----', '---', 'best', '1', 'domain', '----', '---', 'domain', 'number', 'estimation', '----']
            #['#', 'target', 'name', 'accession', 'query', 'name', 'accession', 'E-value', 'score', 'bias', 'E-value', 'score', 'bias', 'exp', 'reg', 'clu', 'ov', 'env', 'dom', 'rep', 'inc', 'description', 'of', 'target']
            #['#-------------------', '----------', '--------------------', '----------', '---------', '------', '-----', '---------', '------', '-----', '---', '---', '---', '---', '---', '---', '---', '---', '---------------------']
            #['delNOG20504', '-', '553220', '-', '1.3e-116', '382.9', '6.2', '3.4e-116', '381.6', '6.2', '1.6', '1', '1', '0', '1', '1', '1', '1', '-']
            if line.startswith('#'): continue
            fields = line.split() # output is not tab delimited! Should I trust this split?
            hit, _, query, _ , evalue, score, bias, devalue, dscore, dbias = fields[0:10]
            evalue, score, bias, devalue, dscore, dbias = map(float, [evalue, score, bias, devalue, dscore, dbias])
            byquery[query].append([hit, evalue, score])
            
    OUT.close()
    F.close()
    return byquery
Example #28
    def _build_and_catch_errors(self, build_func, options_bytes, source=None):
        try:
            return build_func()
        except _cl.RuntimeError as e:
            msg = e.what
            if options_bytes:
                msg = msg + "\n(options: %s)" % options_bytes.decode("utf-8")

            if source is not None:
                from tempfile import NamedTemporaryFile
                srcfile = NamedTemporaryFile(mode="wt", delete=False, suffix=".cl")
                try:
                    srcfile.write(source)
                finally:
                    srcfile.close()

                msg = msg + "\n(source saved as %s)" % srcfile.name

            code = e.code
            routine = e.routine

            err = _cl.RuntimeError(
                    _cl.Error._ErrorRecord(
                        msg=msg,
                        code=code,
                        routine=routine))

        # Python 3.2 outputs the whole list of currently active exceptions
        # This serves to remove one (redundant) level from that nesting.
        raise err
Example #29
def run_via_pbs(args, pbs):
    assert(pbs in ('condor',))  # for now

    # TODO: RF to support multiple backends, parameters, etc, for now -- just condor, no options
    f = NamedTemporaryFile('w', prefix='datalad-%s-' % pbs, suffix='.submit', delete=False)
    try:
        pwd = getpwd()
        logs = f.name.replace('.submit', '.log')
        exe = args[0]
        # TODO: we might need better way to join them, escaping spaces etc.  There must be a stock helper
        #exe_args = ' '.join(map(repr, args[1:])) if len(args) > 1 else ''
        exe_args = ' '.join(args[1:]) if len(args) > 1 else ''
        f.write("""\
Executable = %(exe)s
Initialdir = %(pwd)s
Output = %(logs)s
Error = %(logs)s
getenv = True

arguments = %(exe_args)s
queue
""" % locals())
        f.close()
        Runner().run(['condor_submit', f.name])
        lgr.info("Scheduled execution via %s.  Logs will be stored under %s" % (pbs, logs))
    finally:
        os.unlink(f.name)
Example #30
    def _spawnAsBatch(self, processProtocol, executable, args, env,
                      path, usePTY):
        """A cheat that routes around the impedance mismatch between
        twisted and cmd.exe with respect to escaping quotes"""

        tf = NamedTemporaryFile(dir='.', suffix=".bat", delete=False)
        # echo off hides this cheat from the log files.
        tf.write("@echo off\n")
        if isinstance(self.command, basestring):
            tf.write(self.command)
        else:
            tf.write(win32_batch_quote(self.command))
        tf.close()

        argv = os.environ['COMSPEC'].split()  # allow %COMSPEC% to have args
        if '/c' not in argv:
            argv += ['/c']
        argv += [tf.name]

        def unlink_temp(result):
            os.unlink(tf.name)
            return result
        self.deferred.addBoth(unlink_temp)

        return reactor.spawnProcess(processProtocol, executable, argv, env,
                                    path, usePTY=usePTY)
Example #31
class Editor(object):
    """
    This class is used to use an editor over the connection, and then return the edited file contents.  The
    files are temporary files, and are deleted when we've finished with them, and the editors are run in a
    restricted mode so they can only edit that file.

    To use:
    callback = self.handle_this_edited_value   # signature of void callback(str)
    editor can be "vim" or "nano", default of "nano"
    editorObj = Editor(connection, editor, callback)
    editorObj.launch(initial_contents)

    The calling tasklet can then yield as the user input loop will trigger the original handler on further input
    after the callback has been called to set the value of whatever variable being edited.
    """
    editors = {
        "nano": ["nano", "-R"],
        "vim": ["vim", "-Z"],
    }

    def __init__(self, connection, editor, callback=None):
        from HavokMud.startup import server_instance
        self.connection = connection
        self.server = server_instance
        self.editor = editor
        self.channel = stackless.channel()
        self.callback_channel = stackless.channel()
        self.file_ = None
        self.editor_callback = partial(self.editor_callback_wrapper, callback)

    def launch(self, initial_contents=None):
        command = self.editors.get(self.editor, None)
        if not command:
            raise NotImplementedError("Editor %s is not configured" %
                                      self.editor)
        self.file_ = NamedTemporaryFile("w+", delete=False)
        if initial_contents:
            self.file_.write(initial_contents)
            self.file_.flush()
            self.file_.seek(0)
        # append the path as a list element; += with a bare string would
        # splice in single characters and mutate the shared editors list
        command = command + [self.file_.name]
        self.connection.handler = ExternalHandler(self.connection, command,
                                                  self.channel,
                                                  self.handler_callback,
                                                  self.editor_callback)

    def handler_callback(self):
        self.channel.receive()
        self.file_.seek(0)
        contents = self.file_.read()
        filename = self.file_.name
        self.file_.close()
        os.unlink(filename)
        self.callback_channel.send(contents)

    def default_editor_callback(self):
        contents = self.callback_channel.receive()
        return contents

    def editor_callback_wrapper(self, callback):
        if callback and hasattr(callback, "__call__"):
            callback(self.default_editor_callback())
Example #32
    def predict(self, aa_seq, command=None, options=None, **kwargs):
        """
        Overwrites ACleavageSitePrediction.predict

        :param aa_seq: A list of or a single :class:`~Fred2.Core.Peptide.Peptide` or :class:`~Fred2.Core.Protein.Protein` object
        :type aa_seq: list(:class:`~Fred2.Core.Peptide.Peptide`/:class:`~Fred2.Core.Protein.Protein`) or :class:`~Fred2.Core.Peptide.Peptide`/:class:`~Fred2.Core.Protein.Protein`
        :param str command: The path to an alternative binary (can be used if binary is not globally executable)
        :param str options: A string of additional options directly passed to the external tool
        :return: A :class:`~Fred2.Core.CleavageSitePredictionResult` object
        :rtype: :class:`~Fred2.Core.CleavageSitePredictionResult`
        """
        if not self.is_in_path() and "path" not in kwargs:
            raise RuntimeError("{name} {version} could not be found in PATH".format(name=self.name,
                                                                                    version=self.version))
        external_version = self.get_external_version(path=command)
        if self.version != external_version and external_version is not None:
            raise RuntimeError("Internal version {internal_version} does "
                               "not match external version {external_version}".format(internal_version=self.version,
                                                                                      external_version=external_version))

        if isinstance(aa_seq, Peptide) or isinstance(aa_seq, Protein):
            pep_seqs = {str(aa_seq): aa_seq}
        else:
            pep_seqs = {}
            for p in aa_seq:
                if not isinstance(p, Peptide) and not isinstance(p, Protein):
                    raise ValueError("Input is not of type Protein or Peptide")
                pep_seqs[str(p)] = p

        tmp_out = NamedTemporaryFile(delete=False)
        tmp_file = NamedTemporaryFile(delete=False)
        self.prepare_input(pep_seqs.iterkeys(), tmp_file)
        tmp_file.close()

        # allow custom executable specification
        if command is not None:
            exe = self.command.split()[0]
            _command = self.command.replace(exe, command)
        else:
            _command = self.command

        try:
            stdo = None
            stde = None
            cmd = _command.format(input=tmp_file.name, options="" if options is None else options, out=tmp_out.name)
            p = subprocess.Popen(cmd, shell=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            stdo, stde = p.communicate()
            stdr = p.returncode
            if stdr > 0:
                raise RuntimeError("Unsuccessful execution of " + cmd + " (EXIT!=0) with error: " + stde)
        except Exception as e:
            raise RuntimeError(e)

        result = self.parse_external_result(tmp_out)

        df_result = CleavageSitePredictionResult.from_dict(result)
        df_result.index = pandas.MultiIndex.from_tuples([tuple((i,j)) for i, j in df_result.index],
                                                        names=['ID', 'Pos'])
        os.remove(tmp_file.name)
        tmp_out.close()
        os.remove(tmp_out.name)

        return df_result
Example #33
def download_database(data_root=None, clobber=False):
    """
    Download a SQLite database containing the limb darkening coefficients
    computed by `Claret & Bloemen (2011)
    <http://adsabs.harvard.edu/abs/2011A%26A...529A..75C>`_. The table is
    available online on `Vizier
    <http://vizier.cfa.harvard.edu/viz-bin/VizieR?-source=J/A+A/529/A75>`_.
    Using the ASCII data table, the SQLite database was generated with the
    following Python commands:

    .. code-block:: python

        import sqlite3
        import numpy as np

        with sqlite3.connect("ldcoeffs.db") as conn:
            c = conn.cursor()
            c.execute("CREATE TABLE IF NOT EXISTS claret11 "
                    "(teff REAL, logg REAL, feh REAL, veloc REAL, mu1 REAL, "
                    "mu2 REAL)")
            data = np.loadtxt("claret11.txt", skiprows=59, delimiter="|",
                            usecols=range(1, 7))
            c.executemany("INSERT INTO claret11 (logg,teff,feh,veloc,mu1,mu2) "
                        "VALUES (?,?,?,?,?,?)", data)

    """
    # Figure out the local filename for the database.
    if data_root is None:
        data_root = KPLR_ROOT
    filename = os.path.join(data_root, DB_FILENAME)

    if not clobber and os.path.exists(filename):
        return filename

    # Make sure that the target directory exists.
    try:
        os.makedirs(data_root)
    except os.error:
        pass

    # MAGIC: specify the URL for the remote file.
    url = "http://bbq.dfm.io/~dfm/ldcoeffs.db"

    # Fetch the database from the server.
    logging.info("Downloading file from: '{0}'".format(url))
    r = urllib2.Request(url)
    handler = urllib2.urlopen(r)
    code = handler.getcode()
    if int(code) != 200:
        raise RuntimeError(
            "Couldn't download file from {0}. Returned: {1}".format(url, code))

    # Save the contents of the file.
    logging.info("Saving file to: '{0}'".format(filename))

    # Atomically write to disk.
    # http://stackoverflow.com/questions/2333872/ \
    #        atomic-writing-to-file-with-python
    f = NamedTemporaryFile("wb", delete=False)
    f.write(handler.read())
    f.flush()
    os.fsync(f.fileno())
    f.close()
    shutil.move(f.name, filename)

    return filename
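The write-flush-fsync-close-move sequence above is the standard atomic-write recipe: readers of the destination path never observe a half-written file. A self-contained sketch of the same pattern (atomic_write is an illustrative name, not from the source):

import os
import shutil
from tempfile import NamedTemporaryFile

def atomic_write(dest, payload):
    # Stage the bytes in a temp file, ideally on the same filesystem as
    # dest, so the final move is a single atomic rename on POSIX.
    tmp = NamedTemporaryFile("wb", delete=False,
                             dir=os.path.dirname(dest) or ".")
    try:
        tmp.write(payload)
        tmp.flush()
        os.fsync(tmp.fileno())  # force bytes to disk before renaming
        tmp.close()
        shutil.move(tmp.name, dest)
    except Exception:
        tmp.close()
        os.unlink(tmp.name)
        raise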
Example #34
def mktemp():
    f = NamedTemporaryFile(delete=False)
    f.close()
    loc = path(f.name).abspath()
    loc.remove()
    return loc
Example #35
class TestPasswordFile(unittest.TestCase):
    def setUp(self):
        self.config_file = NamedTemporaryFile(mode="w", delete=False)
        self.config_file.close()
        self.pwd_file = NamedTemporaryFile(mode="w", delete=False)
        self.pwd_file.close()

    def tearDown(self):
        os.remove(self.config_file.name)
        self.config_file = None
        os.remove(self.pwd_file.name)
        self.pwd_file = None

    def _set_file(self, file_name, value):
        with open(file_name, "w") as f:
            f.write(value)

    def test_given_no_pwd_file_expect_empty_credentials_list(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       "TABPY_TRANSFER_PROTOCOL = http")

        app = TabPyApp(self.config_file.name)
        self.assertDictEqual(
            app.credentials,
            {},
            "Expected no credentials with no password file provided",
        )

    def test_given_empty_pwd_file_expect_app_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        self._set_file(self.pwd_file.name, "# just a comment")

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_missing_pwd_file_expect_app_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       "TABPY_PWD_FILE = foo")

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_one_password_in_pwd_file_expect_one_credentials_entry(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        login = "******"
        pwd = "someting@something_else"
        self._set_file(self.pwd_file.name, "# passwords\n"
                       "\n"
                       f"{login} {pwd}")

        app = TabPyApp(self.config_file.name)

        self.assertEqual(len(app.credentials), 1)
        self.assertIn(login, app.credentials)
        self.assertEqual(app.credentials[login], pwd)

    def test_given_username_but_no_password_expect_parsing_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        login = "******"
        pwd = ""
        self._set_file(self.pwd_file.name, "# passwords\n"
                       "\n"
                       f"{login} {pwd}")

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_duplicate_usernames_expect_parsing_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        login = "******"
        pwd = "hashedpw"
        self._set_file(self.pwd_file.name, "# passwords\n"
                       "\n"
                       f"{login} {pwd}\n{login} {pwd}")

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_one_line_with_too_many_params_expect_app_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        self._set_file(
            self.pwd_file.name,
            "# passwords\n"
            "user1 pwd1\n"
            "user_2 pwd#2"
            "user1 pwd@3",
        )

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_different_cases_in_pwd_file_expect_app_fails(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")

        self._set_file(
            self.pwd_file.name,
            "# passwords\n"
            "user1 pwd1\n"
            "user_2 pwd#2"
            "UseR1 pwd@3",
        )

        with self.assertRaises(RuntimeError) as cm:
            TabPyApp(self.config_file.name)
            ex = cm.exception
            self.assertEqual(
                f"Failed to read password file {self.pwd_file.name}",
                ex.args[0])

    def test_given_multiple_credentials_expect_all_parsed(self):
        self._set_file(self.config_file.name, "[TabPy]\n"
                       f"TABPY_PWD_FILE = {self.pwd_file.name}")
        creds = {"user_1": "pwd_1", "user@2": "pwd@2", "user#3": "pwd#3"}

        pwd_file_context = ""
        for login in creds:
            pwd_file_context += f"{login} {creds[login]}\n"

        self._set_file(self.pwd_file.name, pwd_file_context)
        app = TabPyApp(self.config_file.name)

        self.assertCountEqual(creds, app.credentials)
        for login in creds:
            self.assertIn(login, app.credentials)
            self.assertEqual(creds[login], app.credentials[login])
Example #36
def map_visualize(df_house):
	# location data for CMU (Pittsburgh)
	cmu_name = 'CMU'
	cmu_lat = 40.444229
	cmu_lng = -79.943367


	# load restaurant data
	df_restaurant = pd.read_csv('restaurant/restaurant.csv')
	# print(df_restaurant.head())

	
	# create map
	cmu_map = folium.Map(location=[cmu_lat, cmu_lng], zoom_start=14)
	
	
	folium.Marker([cmu_lat, cmu_lng], popup=cmu_name, color='red').add_to(cmu_map)
	for i in range(df_house.shape[0]):
		# add pop-up text to the apartment on the map
		Name = df_house.loc[i, 'Street']
		Region = df_house.loc[i, 'Region']
		Price = df_house.loc[i, 'Price']
		Bedrooms = df_house.loc[i, 'Bedrooms']
		Bathrooms = df_house.loc[i, 'Bathrooms']
		Floorspace = df_house.loc[i, 'Floorspace']
		Pet_friendly = df_house.loc[i, 'Pet_friendly']
		Furnished = df_house.loc[i, 'Furnished']
		trans_time_driv = df_house.loc[i, 'trans_time_driv']
		trans_time_walk = df_house.loc[i, 'trans_time_walk']
		trans_time_bike = df_house.loc[i, 'trans_time_bike']
		apt_lat = df_house.loc[i, 'lat']
		apt_lng = df_house.loc[i, 'lng']
		info = """{}, {}\nPrice: {}\nBedrooms: {}\nBathrooms: {}\nFloorspace: {}\nPet friendly: {}\nFurnished: {}\nNumber of restaurants: {}\nAverage star of restaurants: {:.2f}\n
				""".format(df_house.loc[i,'Street'],
						   df_house.loc[i,'Region'],
						   df_house.loc[i,'Price'],
						   df_house.loc[i,'Bedrooms'],
						   df_house.loc[i,'Bathrooms'],
						   df_house.loc[i,'Floorspace'],
						   df_house.loc[i,'Pet_friendly'],
						   df_house.loc[i,'Furnished'],
						   df_house.loc[i,'restaurant_num'],
						   df_house.loc[i,'restaurant_star'])
		folium.Marker([apt_lat, apt_lng], popup=info).add_to(cmu_map)
	
		# instantiate a feature group for the restaurants in the dataframe
		incidents = folium.map.FeatureGroup()
		
		# add each restaurant to the feature group
		for lat, lng in zip(df_restaurant.latitude, df_restaurant.longitude):
		    if abs(lat - apt_lat) < 0.005 and abs(lng - apt_lng) < 0.0025:
		        incidents.add_child(
		            folium.CircleMarker(
		                [lat, lng],
		                radius=5, # define how big the circle markers to be
		                color='yellow',
		                fill=True,
		                fill_color='blue',
		                fill_opacity=0.6
		            )
		        ) 
		        
		# add restaurants to map
		cmu_map.add_child(incidents)

	# save the visualization into the temp file and render it
	tmp = NamedTemporaryFile(mode='w', delete=False)
	cmu_map.save(tmp.name)
	tmp.close()
	with open(tmp.name) as f:
	    folium_map_html = f.read()
	
	os.unlink(tmp.name) # delete tmp file, so no garbage remained after program ends
	run_html_server(folium_map_html)
Example #37
    def _run_program(self, bin, fastafile, params=None):
        """
        Run MEME and predict motifs from a FASTA file.

        Parameters
        ----------
        bin : str
            Command used to run the tool.

        fastafile : str
            Name of the FASTA input file.

        params : dict, optional
            Optional parameters. For some of the tools required parameters
            are passed using this dictionary.

        Returns
        -------
        motifs : list of Motif instances
            The predicted motifs.

        stdout : str
            Standard out of the tool.

        stderr : str
            Standard error of the tool.
        """
        default_params = {"width": 10, "single": False, "number": 10}
        if params is not None:
            default_params.update(params)

        tmp = NamedTemporaryFile(dir=self.tmpdir)

        strand = "-revcomp"
        width = default_params["width"]
        number = default_params["number"]

        cmd = [
            bin,
            fastafile,
            "-text",
            "-dna",
            "-nostatus",
            "-mod",
            "zoops",
            "-nmotifs",
            "%s" % number,
            "-w",
            "%s" % width,
            "-maxsize",
            "10000000",
        ]
        if not default_params["single"]:
            cmd.append(strand)

        # Fix to run in Docker
        env = os.environ.copy()
        env["OMPI_MCA_plm_rsh_agent"] = "sh"

        p = Popen(cmd, bufsize=1, stderr=PIPE, stdout=PIPE, env=env)
        stdout, stderr = p.communicate()

        motifs = self.parse(io.StringIO(stdout.decode()))

        # Delete temporary files
        tmp.close()

        return motifs, stdout, stderr
Example #38
def test_cli_trained_model_can_be_saved(tmpdir):
    cmd = None
    lang = 'nl'
    output_dir = str(tmpdir)
    train_file = NamedTemporaryFile('wb', dir=output_dir, delete=False)
    train_corpus = [
        {
            "id": "identifier_0",
            "paragraphs": [
                {
                    "raw": "Jan houdt van Marie.\n",
                    "sentences": [
                        {
                            "tokens": [
                                {
                                    "id": 0,
                                    "dep": "nsubj",
                                    "head": 1,
                                    "tag": "NOUN",
                                    "orth": "Jan",
                                    "ner": "B-PER"
                                },
                                {
                                    "id": 1,
                                    "dep": "ROOT",
                                    "head": 0,
                                    "tag": "VERB",
                                    "orth": "houdt",
                                    "ner": "O"
                                },
                                {
                                    "id": 2,
                                    "dep": "case",
                                    "head": 1,
                                    "tag": "ADP",
                                    "orth": "van",
                                    "ner": "O"
                                },
                                {
                                    "id": 3,
                                    "dep": "obj",
                                    "head": -2,
                                    "tag": "NOUN",
                                    "orth": "Marie",
                                    "ner": "B-PER"
                                },
                                {
                                    "id": 4,
                                    "dep": "punct",
                                    "head": -3,
                                    "tag": "PUNCT",
                                    "orth": ".",
                                    "ner": "O"
                                },
                                {
                                    "id": 5,
                                    "dep": "",
                                    "head": -1,
                                    "tag": "SPACE",
                                    "orth": "\n",
                                    "ner": "O"
                                }
                            ],
                            "brackets": []
                        }
                    ]
                }
            ]
        }
    ]

    train_file.write(json.dumps(train_corpus).encode('utf-8'))
    train_file.close()
    train_data = train_file.name
    dev_data = train_data

    # spacy train -n 1 -g -1 nl output_nl training_corpus.json training \
    # corpus.json
    train(cmd, lang, output_dir, train_data, dev_data, n_iter=1)

    assert True
Example #39
class Runner(object):
    def __init__(self,
                 hostnames,
                 playbook,
                 private_key_file,
                 run_data,
                 become_pass,
                 verbosity=0):

        self.run_data = run_data
        self.options = Options()
        self.options.private_key_file = private_key_file
        self.options.verbosity = verbosity
        self.options.connection = 'ssh'  # Need a connection type "smart" or "ssh"
        self.options.become = True
        self.options.become_method = 'sudo'
        self.options.become_user = '******'

        # Set global verbosity
        self.display = Display()
        self.display.verbosity = self.options.verbosity
        # Executor appears to have its own
        # verbosity object/setting as well
        playbook_executor.verbosity = self.options.verbosity

        # Become Pass Needed if not logging in as user root
        passwords = {'become_pass': become_pass}

        # Gets data from YAML/JSON files
        self.loader = DataLoader()
        self.loader.set_vault_password(os.environ['VAULT_PASS'])

        # All the variables from all the various places
        self.variable_manager = VariableManager()
        self.variable_manager.extra_vars = self.run_data

        # Parse hosts, I haven't found a good way to
        # pass hosts in without using a parsed template :(
        # (Maybe you know how?)
        self.hosts = NamedTemporaryFile(delete=False)
        self.hosts.write("""[run_hosts]
                            %s
                                """ % hostnames)
        self.hosts.close()

        # This was my attempt to pass in hosts directly.
        #
        # Also Note: In py2.7, "isinstance(foo, str)" is valid for
        #            latin chars only. Luckily, hostnames are
        #            ascii-only, which overlaps latin charset
        ## if isinstance(hostnames, str):
        ##     hostnames = {"customers": {"hosts": [hostnames]}}

        # Set inventory, using most of above objects
        self.inventory = Inventory(loader=self.loader,
                                   variable_manager=self.variable_manager,
                                   host_list=self.hosts.name)
        self.variable_manager.set_inventory(self.inventory)

        # Playbook to run. Assumes it is
        # local to this python file
        pb_dir = os.path.dirname(__file__)
        playbook = "%s/%s" % (pb_dir, playbook)

        # Setup playbook executor, but don't run until run() called
        self.pbex = playbook_executor.PlaybookExecutor(
            playbooks=[playbook],
            inventory=self.inventory,
            variable_manager=self.variable_manager,
            loader=self.loader,
            options=self.options,
            passwords=passwords)

    def run(self):
        # Results of PlaybookExecutor
        self.pbex.run()
        stats = self.pbex._tqm._stats

        # Test if success for record_logs
        run_success = True
        hosts = sorted(stats.processed.keys())
        for h in hosts:
            t = stats.summarize(h)
            if t['unreachable'] > 0 or t['failures'] > 0:
                run_success = False

        # Dirty hack to send callback to save logs with data we want
        # Note that function "record_logs" is one I created and put into
        # the playbook callback file
        self.pbex._tqm.send_callback('record_logs',
                                     user_id=self.run_data['user_id'],
                                     success=run_success)

        # Remove created temporary files
        os.remove(self.hosts.name)

        return stats
Example #40
class AbstractWrapper(object):
    '''
        abstract solver wrapper
    '''
    def __init__(self):
        '''
            Constructor
        '''
        #program_name = os.path.basename(sys.argv[0])
        program_version = "v%s" % __version__
        program_build_date = str(__updated__)
        program_version_message = "%%(prog)s %s (%s)" % (program_version,
                                                         program_build_date)
        program_shortdesc = __import__("__main__").__doc__.split("\n")[1]
        program_license = '''%s
    
          Created by %s on %s.
          Copyright 2014 - AClib. All rights reserved.
          
          Licensed under the GPLv2
          http://www.gnu.org/licenses/gpl-2.0.html
          
          Distributed on an "AS IS" basis without warranties
          or conditions of any kind, either express or implied.
        
          USAGE
        ''' % (program_shortdesc, str(__authors__), str(__date__))
        #self.parser = ArgumentParser(description=program_license, formatter_class=RawDescriptionHelpFormatter, add_help=False)
        self.parser = OArgumentParser()
        self.args = None

        self.RESULT_MAPPING = {'SUCCESS': "SAT"}
        self._watcher_file = None
        self._solver_file = None

        self._instance = ""
        self._specifics = ""
        self._cutoff = 0.0
        self._runlength = 0
        self._seed = 0
        self._config_dict = {}

        self._exit_code = None

        self._runsolver = None
        self._mem_limit = 2048
        self._tmp_dir = None
        self._tmp_dir_algo = None

        self._crashed_if_non_zero_status = True

        self._subprocesses = []

        self._DEBUG = True
        self._DELAY2KILL = 2

        self._ta_status = "EXTERNALKILL"
        self._ta_runtime = 999999999.0
        self._ta_runlength = -1
        self._ta_quality = -1
        self._ta_exit_code = None
        self._ta_misc = ""

    def print_d(self, str_):
        if self._DEBUG:
            print(str_)

    def main(self, argv=None):
        ''' parse command line'''
        if argv is None:
            argv = sys.argv
        else:
            sys.argv.extend(argv)

        try:
            signal.signal(signal.SIGTERM, signalHandler)
            signal.signal(signal.SIGQUIT, signalHandler)
            signal.signal(signal.SIGINT, signalHandler)

            # Setup argument parser

            self.parser.add_argument(
                "--runsolver-path",
                dest="runsolver",
                default="./target_algorithm/runsolver/runsolver",
                help=
                "path to runsolver binary (if None, the runsolver is deactivated)"
            )
            self.parser.add_argument(
                "--temp-file-dir",
                dest="tmp_dir",
                default=None,
                help=
                "directory for temporary files (relative to -exec-dir in SMAC scenario). If 'NONE' use $TMPDIR if available, otherwise './'"
            )
            self.parser.add_argument(
                "--temp-file-dir-algo",
                dest="tmp_dir_algo",
                default=True,
                type=bool,
                help="create a directory for temporary files from target algo"
            )  #TODO: set default to False
            self.parser.add_argument("--mem-limit",
                                     dest="mem_limit",
                                     default=self._mem_limit,
                                     type=int,
                                     help="memory limit in MB")
            self.parser.add_argument(
                "--internal",
                dest="internal",
                default=False,
                type=bool,
                help="skip calling an external target algorithm")
            self.parser.add_argument(
                "--log",
                dest="log",
                default=True,
                type=bool,
                help=
                "logs all runs in \"target_algo_runs.json\" in --temp-file-dir"
            )
            self.parser.add_argument(
                "--ext-callstring",
                dest="ext_callstring",
                default=None,
                help="Command to get call string via external program;" +
                "your programm gets a file with" +
                "first line: instance name," + "second line: seed" +
                "further lines: paramter name, paramater value;" +
                "output: one line with callstring for target algorithm")
            self.parser.add_argument(
                "--ext-parsing",
                dest="ext_parsing",
                default=None,
                help=
                "Command to use an external program to parse the output of your target algorithm; "
                + "only parameter: name of output file; " +
                "output of your program: " +
                "status: SAT|UNSAT|TIMEOUT|CRASHED\n" +
                "quality: <integer>\n" + "misc: <string>")
            self.parser.add_argument("--help",
                                     dest="show_help",
                                     default=False,
                                     type=bool,
                                     help="shows help")

            # Process arguments
            self.args, target_args = self.parser.parse_cmd(sys.argv[1:])
            args = self.args

            if args.show_help:
                self.parser.print_help()
                self._ta_status = "ABORT"
                self._ta_misc = "help was requested..."
                self._exit_code = 1
                sys.exit(1)

            if args.runsolver != "None" and not os.path.isfile(
                    args.runsolver) and not args.internal:
                self._ta_status = "ABORT"
                self._ta_misc = "runsolver is missing - should have been at %s." % (
                    args.runsolver)
                self._exit_code = 1
                sys.exit(1)
            else:
                self._runsolver = args.runsolver
                self._mem_limit = args.mem_limit

            if args.tmp_dir is None:
                if "TMPDIR" in os.environ:
                    args.tmp_dir = os.environ["TMPDIR"]
                else:
                    args.tmp_dir = "."

            if not os.path.isdir(args.tmp_dir):
                self._ta_status = "ABORT"
                self._ta_misc = "temp directory is missing - should have been at %s." % (
                    args.tmp_dir)
                self._exit_code = 1
                sys.exit(1)
            else:
                self._tmp_dir = args.tmp_dir

            if len(target_args) < 5:
                self._ta_status = "ABORT"
                self._ta_misc = "some required TA parameters (instance, specifics, cutoff, runlength, seed) missing - was [%s]." % (
                    " ".join(target_args))
                self._exit_code = 1
                sys.exit(1)

            self._config_dict = self.build_parameter_dict(target_args)

            if args.tmp_dir_algo:
                try:
                    self._tmp_dir_algo = mkdtemp(dir="/tmp/")
                except OSError:
                    sys.stderr.write(
                        "Creating directory for temporary files failed\n")

            runargs = {
                "instance": self._instance,
                "specifics": self._specifics,
                "cutoff": self._cutoff,
                "runlength": self._runlength,
                "seed": self._seed,
                "tmp": self._tmp_dir_algo
            }

            if args.ext_callstring:
                target_cmd = self.get_command_line_args_ext(
                    runargs=runargs,
                    config=self._config_dict,
                    ext_call=args.ext_callstring)
            else:
                target_cmd = self.get_command_line_args(
                    runargs=runargs, config=self._config_dict)

            target_cmd = target_cmd.split(" ")
            target_cmd = [x for x in target_cmd if x != ""]  # list, so it can be reused under Python 3

            if not args.internal:
                self.call_target(target_cmd)
                self.read_runsolver_output()

            try:
                if "core" in os.listdir("."):
                    os.remove("core")
            except:
                traceback.print_exc()

            if args.ext_parsing:
                resultMap = self.process_results_ext(
                    self._solver_file, {"exit_code": self._ta_exit_code},
                    ext_call=args.ext_parsing)
            else:
                resultMap = self.process_results(
                    self._solver_file, {"exit_code": self._ta_exit_code})

            if ('status' in resultMap):
                self._ta_status = self.RESULT_MAPPING.get(
                    resultMap['status'], resultMap['status'])
            if ('runtime' in resultMap):
                self._ta_runtime = resultMap['runtime']
            if ('quality' in resultMap):
                self._ta_quality = resultMap['quality']
            if 'misc' in resultMap and not self._ta_misc:
                self._ta_misc = resultMap['misc']
            if 'misc' in resultMap and self._ta_misc:
                self._ta_misc += " - " + resultMap['misc']

            # if still no status was determined, something went wrong and output files should be kept
            if self._ta_status is "EXTERNALKILL":
                self._ta_status = "CRASHED"
            sys.exit()
        except (KeyboardInterrupt, SystemExit):
            self.cleanup()
            self.print_result_string()
            #traceback.print_exc()
            if self._ta_exit_code:
                sys.exit(self._ta_exit_code)
            elif self._exit_code:
                sys.exit(self._exit_code)
            else:
                sys.exit(0)

    def build_parameter_dict(self, arg_list):
        '''
            Reads all arguments which were not parsed by ArgumentParser,
            extracts all meta information
            and builds a mapping: parameter name -> parameter value
            Format Assumption: <instance> <specifics> <runtime cutoff> <runlength> <seed> <solver parameters>
            Args:
                list of all options not parsed by ArgumentParser
        '''
        self._instance = arg_list[0]
        self._specifics = arg_list[1]
        self._cutoff = int(float(arg_list[2]) +
                           1)  # runsolver only rounds down to integer
        self._ta_runtime = self._cutoff
        self._runlength = int(arg_list[3])
        self._seed = int(arg_list[4])

        params = arg_list[5:]
        if len(params) % 2 != 0:
            self._ta_status = "ABORT"
            self._ta_misc = "target algorithm parameter list MUST have even length - found %d arguments." % (
                len(params))
            self.print_d(" ".join(params))
            self._exit_code = 1
            sys.exit(1)

        return dict((name, value.strip("'"))
                    for name, value in zip(params[::2], params[1::2]))

    def call_target(self, target_cmd):
        '''
            extends the target algorithm command line call with the runsolver
            and executes it
            Args:
                list of target cmd (from getCommandLineArgs)
        '''
        random_id = random.randint(0, 1000000)
        self._watcher_file = NamedTemporaryFile(suffix=".log",
                                                prefix="watcher-%d-" %
                                                (random_id),
                                                dir=self._tmp_dir,
                                                delete=False)
        self._solver_file = NamedTemporaryFile(suffix=".log",
                                               prefix="solver-%d-" %
                                               (random_id),
                                               dir=self._tmp_dir,
                                               delete=False)

        runsolver_cmd = []
        if self._runsolver != "None":
            runsolver_cmd = [
                self._runsolver, "-M", self._mem_limit, "-C", self._cutoff,
                "-w", self._watcher_file.name, "-o", self._solver_file.name
            ]

        runsolver_cmd.extend(target_cmd)
        #for debugging
        self.print_d("Calling runsolver. Command-line:")
        self.print_d(" ".join(map(str, runsolver_cmd)))

        # run
        try:
            if self._runsolver != "None":
                if "\"" in runsolver_cmd:  # if there are quotes in the call, we cannot split it individual list elements. we have to call it via shell as a string; problematic solver: SparrowToRiss
                    runsolver_cmd = " ".join(map(str, runsolver_cmd))
                    io = Popen(runsolver_cmd,
                               shell=True,
                               preexec_fn=os.setpgrp,
                               universal_newlines=True)
                else:
                    io = Popen(map(str, runsolver_cmd),
                               shell=False,
                               preexec_fn=os.setpgrp,
                               universal_newlines=True)
            else:
                io = Popen(map(str, runsolver_cmd),
                           stdout=self._solver_file,
                           shell=False,
                           preexec_fn=os.setpgrp,
                           universal_newlines=True)
            self._subprocesses.append(io)
            io.wait()
            self._subprocesses.remove(io)
            if io.stdout:
                io.stdout.flush()
        except OSError:
            self._ta_status = "ABORT"
            self._ta_misc = "execution failed: %s" % (" ".join(
                map(str, runsolver_cmd)))
            self._exit_code = 1
            sys.exit(1)

        self._solver_file.seek(0)

    def float_regex(self):
        # raw string; exponent sign optional so plain "1e5" also matches
        return r'[+-]?\d+(?:\.\d+)?(?:[eE][+-]?\d+)?'

    def read_runsolver_output(self):
        '''
            reads self._watcher_file,
            extracts the runtime
            and records whether a memout or timeout occurred
        '''
        if self._runsolver == "None":
            self._ta_exit_code = 0
            return

        self.print_d("Reading runsolver output from %s" %
                     (self._watcher_file.name))
        data = str(self._watcher_file.read())

        if (re.search('runsolver_max_cpu_time_exceeded', data)
                or re.search('Maximum CPU time exceeded', data)):
            self._ta_status = "TIMEOUT"

        if (re.search('runsolver_max_memory_limit_exceeded', data)
                or re.search('Maximum VSize exceeded', data)):
            self._ta_status = "TIMEOUT"
            self._ta_misc = "memory limit was exceeded"

        cpu_pattern1 = re.compile('runsolver_cputime: (%s)' %
                                  (self.float_regex()))
        cpu_match1 = re.search(cpu_pattern1, data)

        cpu_pattern2 = re.compile('CPU time \\(s\\): (%s)' %
                                  (self.float_regex()))
        cpu_match2 = re.search(cpu_pattern2, data)

        if (cpu_match1):
            self._ta_runtime = float(cpu_match1.group(1))
        if (cpu_match2):
            self._ta_runtime = float(cpu_match2.group(1))

        exitcode_pattern = re.compile('Child status: ([0-9]+)')
        exitcode_match = re.search(exitcode_pattern, data)

        if (exitcode_match):
            self._ta_exit_code = int(exitcode_match.group(1))

    def print_result_string(self):

        if self.args and self.args.log:
            #if not os.path.isfile("target_algo_runs.csv"):
            #    with open("target_algo_runs.csv", "a") as fp:
            #        fp.write("instance,seed,status,performance,config,[misc]\n")
            with open("target_algo_runs.json", "a") as fp:
                out_dict = {
                    "instance": self._instance,
                    "seed": self._seed,
                    "status": self._ta_status,
                    "time": self._ta_runtime,
                    "config": self._config_dict,
                    "misc": self._ta_misc
                }
                json.dump(out_dict, fp)
                fp.write("\n")
                fp.flush()

        sys.stdout.write(
            "Result for ParamILS: %s, %s, %s, %s, %s" %
            (self._ta_status, str(self._ta_runtime), str(
                self._ta_runlength), str(self._ta_quality), str(self._seed)))
        if (len(self._ta_misc) > 0):
            sys.stdout.write(", %s" % (self._ta_misc))
        print('')
        sys.stdout.flush()

    def cleanup(self):
        '''
            cleanup if error occurred or external signal handled
        '''
        if (len(self._subprocesses) > 0):
            print("killing the target run!")
            try:
                for sub in self._subprocesses:
                    #sub.terminate()
                    Popen(["pkill", "-TERM", "-P", str(sub.pid)])
                    self.print_d("Wait %d seconds ..." % (self._DELAY2KILL))
                    time.sleep(self._DELAY2KILL)
                    if sub.returncode is None:  # still running
                        sub.kill()

                self.print_d(
                    "done... If anything in the subprocess tree fork'd a new process group, we may not have caught everything..."
                )
                self._ta_misc = "forced to exit by signal or keyboard interrupt."
                self._ta_runtime = self._cutoff
            except (OSError, KeyboardInterrupt, SystemExit):
                self._ta_misc = "forced to exit by multiple signals/interrupts."
                self._ta_runtime = self._cutoff

        if (self._ta_status is "ABORT" or self._ta_status is "CRASHED"):
            if (len(self._ta_misc) == 0):
                if self._ta_exit_code:
                    self._ta_misc = 'Problem with run. Exit code was %d.' % (
                        self._ta_exit_code)
                else:
                    self._ta_misc = 'Problem with run. Exit code was N/A.'

            if (self._watcher_file and self._solver_file):
                self._ta_misc = self._ta_misc + '; Preserving runsolver output at %s - preserving target algorithm output at %s' % (
                    self._watcher_file.name or "<none>", self._solver_file.name
                    or "<none>")

        try:
            if (self._watcher_file):
                self._watcher_file.close()
            if (self._solver_file):
                self._solver_file.close()

            if (self._ta_status is not "ABORT"
                    and self._ta_status is not "CRASHED"):
                os.remove(self._watcher_file.name)
                os.remove(self._solver_file.name)

            if self._tmp_dir_algo:
                shutil.rmtree(self._tmp_dir_algo)

        except (OSError, KeyboardInterrupt, SystemExit):
            self._ta_misc = "problems removing temporary files during cleanup."
        except AttributeError:
            pass  #in internal mode, these files are not generated

        if self._ta_status is "EXTERNALKILL":
            self._ta_status = "CRASHED"
            self._exit_code = 3

    def get_command_line_args(self, runargs, config):
        '''
        Returns the command call list containing arguments to execute the implementing subclass' solver.
        The default implementation raises a NotImplementedError; subclasses must
        override it (or supply --ext-callstring so get_command_line_args_ext is used instead).
    
        Args:
            runargs: a map of any non-configuration arguments required for the execution of the solver.
            config: a mapping from parameter name (with prefix) to parameter value.
        Returns:
            A command call list to execute a target algorithm.
        '''
        raise NotImplementedError()

    def get_command_line_args_ext(self, runargs, config, ext_call):
        '''
        When the command line call of the target algorithm is produced by something other than Python,
        override this method to return a command call list that executes whatever produces the command line.

        Args:
            runargs: a map of any non-configuration arguments required for the execution of the solver.
            config: a mapping from parameter name (with prefix) to parameter value.
            ext_call: string to call external program to get callstring of target algorithm
        Returns:
            A command call list to execute the command producing a single line of output containing the solver command string
        '''
        callstring_in = NamedTemporaryFile(suffix=".csv",
                                           prefix="callstring",
                                           dir=self._tmp_dir,
                                           delete=False)
        callstring_in.write("%s\n" % (runargs["instance"]))
        callstring_in.write("%d\n" % (runargs["seed"]))
        for name, value in config.items():
            callstring_in.write("%s,%s\n" % (name, value))
        callstring_in.flush()

        cmd = ext_call.split(" ")
        cmd.append(callstring_in.name)
        self.print_d(" ".join(cmd))
        try:
            io = Popen(cmd,
                       shell=False,
                       preexec_fn=os.setpgrp,
                       stdout=PIPE,
                       universal_newlines=True)
            self._subprocesses.append(io)
            out_, _ = io.communicate()
            self._subprocesses.remove(io)
        except OSError:
            self._ta_misc = "failed to run external program for output parsing : %s" % (
                " ".join(cmd))
            self._ta_runtime = self._cutoff
            self._exit_code = 2
            sys.exit(2)
        if not out_:
            self._ta_misc = "external program for output parsing yielded empty output: %s" % (
                " ".join(cmd))
            self._ta_runtime = self._cutoff
            self._exit_code = 2
            sys.exit(2)
        callstring_in.close()
        os.remove(callstring_in.name)
        self._instance = runargs["instance"]
        return out_.strip('\n\r\b')

    def process_results(self, filepointer, out_args):
        '''
        Parse a results file to extract the run's status (SUCCESS/CRASHED/etc) and other optional results.
    
        Args:
            filepointer: a pointer to the file containing the solver execution standard out.
            out_args : map of extra arguments, e.g. {"exit_code": <exit code of target algorithm>}
        Returns:
            A map containing the standard AClib run results. The current standard result map as of AClib 2.06 is:
            {
                "status" : <"SAT"/"UNSAT"/"TIMEOUT"/"CRASHED"/"ABORT">,
                "runtime" : <runtime of target algrithm>,
                "quality" : <a domain specific measure of the quality of the solution [optional]>,
                "misc" : <a (comma-less) string that will be associated with the run [optional]>
            }
            ATTENTION: The return values will overwrite the measured results of the runsolver (if runsolver was used). 
        '''
        raise NotImplementedError()

    def process_results_ext(self, filepointer, out_args, ext_call):
        '''
        Args:
            filepointer: a pointer to the file containing the solver execution standard out.
            out_args : map of extra arguments, e.g. {"exit_code": <exit code of target algorithm>}
        Returns:
            A map containing the standard AClib run results. The current standard result map as of AClib 2.06 is:
            {
                "status" : <"SAT"/"UNSAT"/"TIMEOUT"/"CRASHED"/"ABORT">,
                "quality" : <a domain specific measure of the quality of the solution [optional]>,
                "misc" : <a (comma-less) string that will be associated with the run [optional]>
            }
        '''

        cmd = ext_call.split(" ")
        cmd.append(filepointer.name)
        self.print_d(" ".join(cmd))
        try:
            io = Popen(cmd,
                       shell=False,
                       preexec_fn=os.setpgrp,
                       stdout=PIPE,
                       universal_newlines=True)
            self._subprocesses.append(io)
            out_, _ = io.communicate()
            self._subprocesses.remove(io)
        except OSError:
            self._ta_misc = "failed to run external program for output parsing"
            self._ta_runtime = self._cutoff
            self._exit_code = 2
            sys.exit(2)

        result_map = {}
        for line in out_.split("\n"):
            if line.startswith("status:"):
                result_map["status"] = line.split(":")[1].strip(" ")
            elif line.startswith("quality:"):
                result_map["quality"] = line.split(":")[1].strip(" ")
            elif line.startswith("misc:"):
                result_map["misc"] = line.split(":")[1]

        return result_map
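
# To make the wrapper above concrete, a subclass only needs to provide
# get_command_line_args and process_results. A minimal sketch; the
# "./mysolver" binary and its flags are hypothetical and only illustrate
# the two required overrides.
class MySolverWrapper(AbstractWrapper):

    def get_command_line_args(self, runargs, config):
        cmd = "./mysolver --inst %s --seed %d" % (runargs["instance"],
                                                  runargs["seed"])
        for name, value in config.items():
            cmd += " -%s %s" % (name.lstrip("-"), value)
        return cmd

    def process_results(self, filepointer, out_args):
        data = str(filepointer.read())  # file may be opened in binary mode
        if "s SATISFIABLE" in data:
            return {"status": "SAT"}
        if "s UNSATISFIABLE" in data:
            return {"status": "UNSAT"}
        # no recognizable answer: blame a crash on nonzero exit codes,
        # otherwise assume the cutoff was hit
        if out_args["exit_code"]:
            return {"status": "CRASHED"}
        return {"status": "TIMEOUT"}

if __name__ == "__main__":
    wrapper = MySolverWrapper()
    wrapper.main()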
Exemple #41
0
    def run(self):
        """Method which explicitely runs LAMMPS."""

        self.calls += 1

        # set LAMMPS command from environment variable
        if 'LAMMPS_COMMAND' in os.environ:
            lammps_cmd_line = shlex.split(os.environ['LAMMPS_COMMAND'])
            if len(lammps_cmd_line) == 0:
                self.clean()
                raise RuntimeError('The LAMMPS_COMMAND environment variable '
                                   'must not be empty')
            # want always an absolute path to LAMMPS binary when calling from self.dir
            lammps_cmd_line[0] = os.path.abspath(lammps_cmd_line[0])

        else:
            self.clean()
            raise RuntimeError(
                'Please set LAMMPS_COMMAND environment variable')
        if 'LAMMPS_OPTIONS' in os.environ:
            lammps_options = shlex.split(os.environ['LAMMPS_OPTIONS'])
        else:
            lammps_options = shlex.split('-echo log -screen none')

        # change into subdirectory for LAMMPS calculations
        cwd = os.getcwd()
        os.chdir(self.tmp_dir)

        # setup file names for LAMMPS calculation
        label = '%s%06d' % (self.label, self.calls)
        lammps_in = uns_mktemp(prefix='in_' + label, dir=self.tmp_dir)
        lammps_log = uns_mktemp(prefix='log_' + label, dir=self.tmp_dir)
        lammps_trj_fd = NamedTemporaryFile(prefix='trj_' + label,
                                           dir=self.tmp_dir,
                                           delete=(not self.keep_tmp_files))
        lammps_trj = lammps_trj_fd.name
        if self.no_data_file:
            lammps_data = None
        else:
            lammps_data_fd = NamedTemporaryFile(
                prefix='data_' + label,
                dir=self.tmp_dir,
                delete=(not self.keep_tmp_files))
            self.write_lammps_data(lammps_data=lammps_data_fd)
            lammps_data = lammps_data_fd.name
            lammps_data_fd.flush()

        # see to it that LAMMPS is started
        if not self._lmp_alive():
            # Attempt to (re)start lammps
            self._lmp_handle = Popen(lammps_cmd_line + lammps_options +
                                     ['-log', '/dev/stdout'],
                                     stdin=PIPE,
                                     stdout=PIPE)
        lmp_handle = self._lmp_handle

        # Create thread reading lammps stdout (for reference, if requested,
        # also create lammps_log, although it is never used)
        if self.keep_tmp_files:
            lammps_log_fd = open(lammps_log, 'w')
            fd = special_tee(lmp_handle.stdout, lammps_log_fd)
        else:
            fd = lmp_handle.stdout
        thr_read_log = Thread(target=self.read_lammps_log, args=(fd, ))
        thr_read_log.start()

        # write LAMMPS input (for reference, also create the file lammps_in,
        # although it is never used)
        if self.keep_tmp_files:
            lammps_in_fd = open(lammps_in, 'w')
            fd = special_tee(lmp_handle.stdin, lammps_in_fd)
        else:
            fd = lmp_handle.stdin
        self.write_lammps_in(lammps_in=fd,
                             lammps_trj=lammps_trj,
                             lammps_data=lammps_data)

        if self.keep_tmp_files:
            lammps_in_fd.close()

        # Wait for log output to be read (i.e., for LAMMPS to finish)
        # and close the log file if there is one
        thr_read_log.join()
        if self.keep_tmp_files:
            lammps_log_fd.close()

        if not self.keep_alive:
            self._lmp_end()

        exitcode = lmp_handle.poll()
        if exitcode:  # nonzero exit code; None means LAMMPS is still alive
            cwd = os.getcwd()
            raise RuntimeError('LAMMPS exited in %s with exit code: %d.' %\
                                   (cwd,exitcode))

        # A few sanity checks
        if len(self.thermo_content) == 0:
            raise RuntimeError('Failed to retrieve any thermo_style output')
        if int(self.thermo_content[-1]['atoms']) != len(self.atoms):
            # This obviously shouldn't happen, but if prism.fold_...() fails, it could
            raise RuntimeError('Atoms have gone missing')

        self.read_lammps_trj(lammps_trj=lammps_trj,
                             set_atoms=True,
                             set_cell=True)
        lammps_trj_fd.close()
        if not self.no_data_file:
            lammps_data_fd.close()

        os.chdir(cwd)
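
# run() above resolves the LAMMPS binary and its options from environment
# variables before doing anything else. A sketch of the expected setup;
# the binary path is a placeholder.
os.environ['LAMMPS_COMMAND'] = '/usr/local/bin/lmp_serial'
os.environ['LAMMPS_OPTIONS'] = '-echo log -screen none'  # same as the built-in default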
Exemple #42
0
def get_readable_fileobj(name_or_obj,
                         encoding=None,
                         cache=False,
                         show_progress=True,
                         remote_timeout=None):
    """
    Given a filename, pathlib.Path object or a readable file-like object, return a context
    manager that yields a readable file-like object.

    This supports passing filenames, URLs, and readable file-like objects,
    any of which can be compressed in gzip, bzip2 or lzma (xz) if the
    appropriate compression libraries are provided by the Python installation.

    Notes
    -----

    This function is a context manager, and should be used for example
    as::

        with get_readable_fileobj('file.dat') as f:
            contents = f.read()

    Parameters
    ----------
    name_or_obj : str or file-like object
        The filename of the file to access (if given as a string), or
        the file-like object to access.

        If a file-like object, it must be opened in binary mode.

    encoding : str, optional
        When `None` (default), returns a file-like object with a
        ``read`` method that returns `str` (``unicode``) objects, using
        `locale.getpreferredencoding` as an encoding.  This matches
        the default behavior of the built-in `open` when no ``mode``
        argument is provided.

        When ``'binary'``, returns a file-like object where its ``read``
        method returns `bytes` objects.

        When another string, it is the name of an encoding, and the
        file-like object's ``read`` method will return `str` (``unicode``)
        objects, decoded from binary using the given encoding.

    cache : bool, optional
        Whether to cache the contents of remote URLs.

    show_progress : bool, optional
        Whether to display a progress bar if the file is downloaded
        from a remote server.  Default is `True`.

    remote_timeout : float
        Timeout for remote requests in seconds (default is the configurable
        `astropy.utils.data.Conf.remote_timeout`, which is 3s by default)

    Returns
    -------
    file : readable file-like object
    """

    # close_fds is a list of file handles created by this function
    # that need to be closed.  We don't want to always just close the
    # returned file handle, because it may simply be the file handle
    # passed in.  In that case it is not the responsibility of this
    # function to close it: doing so could result in a "double close"
    # and an "invalid file descriptor" exception.
    PATH_TYPES = (str, pathlib.Path)

    close_fds = []
    delete_fds = []

    if remote_timeout is None:
        # use configfile default
        remote_timeout = conf.remote_timeout

    # Get a file object to the content
    if isinstance(name_or_obj, PATH_TYPES):
        # name_or_obj could be a Path object if pathlib is available
        name_or_obj = str(name_or_obj)

        is_url = _is_url(name_or_obj)
        if is_url:
            name_or_obj = download_file(name_or_obj,
                                        cache=cache,
                                        show_progress=show_progress,
                                        timeout=remote_timeout)
        fileobj = io.FileIO(name_or_obj, 'r')
        if is_url and not cache:
            delete_fds.append(fileobj)
        close_fds.append(fileobj)
    else:
        fileobj = name_or_obj

    # Check if the file object supports random access, and if not,
    # then wrap it in a BytesIO buffer.  It would be nicer to use a
    # BufferedReader to avoid loading the whole file into memory first,
    # but that is not compatible with streams or urllib2.urlopen
    # objects on Python 2.x.
    if not hasattr(fileobj, 'seek'):
        fileobj = io.BytesIO(fileobj.read())

    # Now read enough bytes to look at signature
    signature = fileobj.read(4)
    fileobj.seek(0)

    if signature[:3] == b'\x1f\x8b\x08':  # gzip
        import struct
        try:
            import gzip
            fileobj_new = gzip.GzipFile(fileobj=fileobj, mode='rb')
            fileobj_new.read(1)  # need to check that the file is really gzip
        except (OSError, EOFError, struct.error):  # invalid gzip file
            fileobj.seek(0)
            fileobj_new.close()
        else:
            fileobj_new.seek(0)
            fileobj = fileobj_new
    elif signature[:3] == b'BZh':  # bzip2
        try:
            import bz2
        except ImportError:
            for fd in close_fds:
                fd.close()
            raise ValueError(
                ".bz2 format files are not supported since the Python "
                "interpreter does not include the bz2 module")
        try:
            # bz2.BZ2File does not support file objects, only filenames, so we
            # need to write the data to a temporary file
            with NamedTemporaryFile("wb", delete=False) as tmp:
                tmp.write(fileobj.read())
                tmp.close()
                fileobj_new = bz2.BZ2File(tmp.name, mode='rb')
            fileobj_new.read(1)  # need to check that the file is really bzip2
        except OSError:  # invalid bzip2 file
            fileobj.seek(0)
            fileobj_new.close()
            # raise
        else:
            fileobj_new.seek(0)
            close_fds.append(fileobj_new)
            fileobj = fileobj_new
    elif signature[:3] == b'\xfd7z':  # xz
        try:
            import lzma
            fileobj_new = lzma.LZMAFile(fileobj, mode='rb')
            fileobj_new.read(1)  # need to check that the file is really xz
        except ImportError:
            for fd in close_fds:
                fd.close()
            raise ValueError(
                ".xz format files are not supported since the Python "
                "interpreter does not include the lzma module.")
        except (OSError, EOFError) as e:  # invalid xz file
            fileobj.seek(0)
            fileobj_new.close()
            # should we propagate this to the caller to signal bad content?
            # raise ValueError(e)
        else:
            fileobj_new.seek(0)
            fileobj = fileobj_new

    # By this point, we have a file, io.FileIO, gzip.GzipFile, bz2.BZ2File
    # or lzma.LZMAFile instance opened in binary mode (that is, read
    # returns bytes).  Now we need to, if requested, wrap it in a
    # io.TextIOWrapper so read will return unicode based on the
    # encoding parameter.

    needs_textio_wrapper = encoding != 'binary'

    if needs_textio_wrapper:
        # A bz2.BZ2File can not be wrapped by a TextIOWrapper,
        # so we decompress it to a temporary file and then
        # return a handle to that.
        try:
            import bz2
        except ImportError:
            pass
        else:
            if isinstance(fileobj, bz2.BZ2File):
                tmp = NamedTemporaryFile("wb", delete=False)
                data = fileobj.read()
                tmp.write(data)
                tmp.close()
                delete_fds.append(tmp)

                fileobj = io.FileIO(tmp.name, 'r')
                close_fds.append(fileobj)

        fileobj = io.BufferedReader(fileobj)
        fileobj = io.TextIOWrapper(fileobj, encoding=encoding)

        # Ensure that file is at the start - io.FileIO will for
        # example not always be at the start:
        # >>> import io
        # >>> f = open('test.fits', 'rb')
        # >>> f.read(4)
        # 'SIMP'
        # >>> f.seek(0)
        # >>> fileobj = io.FileIO(f.fileno())
        # >>> fileobj.tell()
        # 4096L

        fileobj.seek(0)

    try:
        yield fileobj
    finally:
        for fd in close_fds:
            fd.close()
        for fd in delete_fds:
            os.remove(fd.name)
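
# In the original source this generator is wrapped with
# contextlib.contextmanager (the decorator is not shown above), so it is
# used exactly as in the docstring; the filename is a placeholder.
with get_readable_fileobj('file.dat.gz', encoding='binary') as f:
    contents = f.read()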
Exemple #43
0
    def _run_program(self, bin, fastafile, savedir, params=None):

        default_params = {
            "width": 10,
            "background": "",
            "single": False,
            "number": 10
        }
        if params is not None:
            default_params.update(params)

        background = default_params['background']
        width = default_params['width']
        number = default_params['number']

        if not background:
            if default_params["organism"]:
                org = default_params["organism"]
                background = os.path.join(
                    self.config.get_bg_dir(),
                    "{}.{}.bg".format(org, "MotifSampler"))
            else:
                raise Exception, "No background specified for {}".format(
                    self.name)

        fastafile = os.path.abspath(fastafile)
        savedir = os.path.abspath(savedir)

        tmp = NamedTemporaryFile(dir=self.tmpdir)
        pwmfile = tmp.name

        tmp2 = NamedTemporaryFile(dir=self.tmpdir)
        outfile = tmp2.name

        strand = 1
        if default_params["single"]:
            strand = 0

        # TODO: test organism
        cmd = "%s -f %s -b %s -m %s -w %s -n %s -o %s -s %s > /dev/null 2>&1" % (
            bin, fastafile, background, pwmfile, width, number, outfile,
            strand)
        #print cmd
        #p = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
        #stdout, stderr = p.communicate()

        stdout, stderr = "", ""
        p = Popen(cmd, shell=True)
        p.wait()

        motifs = []
        #if os.path.exists(pwmfile):
        #    motifs = self.parse(open(pwmfile))
        if os.path.exists(outfile):
            motifs = self.parse_out(open(outfile))

        # remove temporary files
        tmp.close()
        tmp2.close()

        for motif in motifs:
            motif.id = "%s_%s" % (self.name, motif.id)

        return motifs, stdout, stderr
Exemple #44
0
def main():
    parser = argparse.ArgumentParser(
        description=
        'mkmodel.py: compute IBM-1 translation probabilities using eflomal, the efficient low-memory aligner'
    )
    parser.add_argument('-v',
                        '--verbose',
                        dest='verbose',
                        action="count",
                        default=0,
                        help='Enable verbose output')
    parser.add_argument('--debug',
                        dest='debug',
                        action='store_true',
                        help='Enable gdb debugging of eflomal binary')
    parser.add_argument('--no-lowercase',
                        dest='lowercase',
                        action='store_false',
                        default=True,
                        help='Do not lowercase input text')
    parser.add_argument('--overwrite',
                        dest='overwrite',
                        action='store_true',
                        help='Overwrite existing output files')
    parser.add_argument('--null-prior',
                        dest='null_prior',
                        default=0.2,
                        metavar='X',
                        type=float,
                        help='Prior probability of NULL alignment')
    parser.add_argument(
        '-m',
        '--model',
        dest='model',
        default=3,
        metavar='N',
        type=int,
        help='Model (1 = IBM1, 2 = IBM1+HMM, 3 = IBM1+HMM+fertility)')
    parser.add_argument('--source-prefix',
                        dest='source_prefix_len',
                        default=0,
                        metavar='N',
                        type=int,
                        help='Length of prefix for stemming (source)')
    parser.add_argument('--source-suffix',
                        dest='source_suffix_len',
                        default=0,
                        metavar='N',
                        type=int,
                        help='Length of suffix for stemming (source)')
    parser.add_argument('--target-prefix',
                        dest='target_prefix_len',
                        default=0,
                        metavar='N',
                        type=int,
                        help='Length of prefix for stemming (target)')
    parser.add_argument('--target-suffix',
                        dest='target_suffix_len',
                        default=0,
                        metavar='N',
                        type=int,
                        help='Length of suffix for stemming (target)')
    parser.add_argument('-l',
                        '--length',
                        dest='length',
                        default=1.0,
                        metavar='X',
                        type=float,
                        help='Relative number of sampling iterations')
    parser.add_argument('-1',
                        '--ibm1-iters',
                        dest='iters1',
                        default=None,
                        metavar='X',
                        type=int,
                        help='Number of IBM1 iterations (overrides --length)')
    parser.add_argument('-2',
                        '--hmm-iters',
                        dest='iters2',
                        default=None,
                        metavar='X',
                        type=int,
                        help='Number of HMM iterations (overrides --length)')
    parser.add_argument(
        '-3',
        '--fert-iters',
        dest='iters3',
        default=None,
        metavar='X',
        type=int,
        help='Number of HMM+fertility iterations (overrides --length)')
    parser.add_argument('--n-samplers',
                        dest='n_samplers',
                        default=3,
                        metavar='X',
                        type=int,
                        help='Number of independent samplers to run')
    parser.add_argument('-s',
                        '--source',
                        dest='source_filename',
                        type=str,
                        metavar='filename',
                        help='Source text filename',
                        required=True)
    parser.add_argument('-t',
                        '--target',
                        dest='target_filename',
                        type=str,
                        metavar='filename',
                        help='Target text filename',
                        required=True)
    parser.add_argument(
        '-f',
        '--forward-probabilities',
        dest='p_filename_fwd',
        type=str,
        metavar='filename',
        help=
        'Filename to write forward direction probabilities to, as pickle dump')
    parser.add_argument(
        '-r',
        '--reverse-probabilities',
        dest='p_filename_rev',
        type=str,
        metavar='filename',
        help=
        'Filename to write reverse direction probabilities to, as pickle dump')
    parser.add_argument(
        '-F',
        '--forward-probabilities-human',
        dest='p_filename_fwd_h',
        type=str,
        metavar='filename',
        help=
        'Filename to write forward direction probabilities to, as human readable dump'
    )
    parser.add_argument(
        '-R',
        '--reverse-probabilities-human',
        dest='p_filename_rev_h',
        type=str,
        metavar='filename',
        help=
        'Filename to write reverse direction probabilities to, as human readable dump'
    )

    args = parser.parse_args()

    logger = Logger(args.verbose)

    if args.p_filename_fwd is None and args.p_filename_rev is None:
        print('ERROR: no file to save probabilities (-f/-r), will do nothing.',
              file=sys.stderr,
              flush=True)
        sys.exit(1)

    for filename in (args.source_filename, args.target_filename):
        if not os.path.exists(filename):
            print('ERROR: input file %s does not exist!' % filename,
                  file=sys.stderr,
                  flush=True)
            sys.exit(1)

    for filename in (args.p_filename_fwd, args.p_filename_rev):
        if (not args.overwrite) and (filename is not None) \
                and os.path.exists(filename):
            print('ERROR: output file %s exists, will not overwrite!' % \
                    filename,
                  file=sys.stderr, flush=True)
            sys.exit(1)

    if args.verbose:
        print('Reading source text from %s...' % args.source_filename,
              file=sys.stderr,
              flush=True)
    with xopen(args.source_filename, 'r', encoding='utf-8') as f:
        src_sents, src_index = read_text(f, args.lowercase,
                                         args.source_prefix_len,
                                         args.source_suffix_len)
        n_src_sents = len(src_sents)
        src_voc_size = len(src_index)
        src_index = None
        srcf = NamedTemporaryFile('wb')
        write_text(srcf, tuple(src_sents), src_voc_size)
        src_sents = None

    if args.verbose:
        print('Reading target text from %s...' % args.target_filename,
              file=sys.stderr,
              flush=True)
    with xopen(args.target_filename, 'r', encoding='utf-8') as f:
        trg_sents, trg_index = read_text(f, args.lowercase,
                                         args.target_prefix_len,
                                         args.target_suffix_len)
        trg_voc_size = len(trg_index)
        n_trg_sents = len(trg_sents)
        trg_index = None
        trgf = NamedTemporaryFile('wb')
        write_text(trgf, tuple(trg_sents), trg_voc_size)
        trg_sents = None

    if n_src_sents != n_trg_sents:
        print('ERROR: number of sentences differ in input files (%d vs %d)' %
              (n_src_sents, n_trg_sents),
              file=sys.stderr,
              flush=True)
        sys.exit(1)

    iters = (args.iters1, args.iters2, args.iters3)
    if any(x is None for x in iters[:args.model]):
        iters = None

    if args.verbose:
        print('Aligning %d sentences...' % n_src_sents,
              file=sys.stderr,
              flush=True)

    fwd_alignment_file = NamedTemporaryFile('w')
    rev_alignment_file = NamedTemporaryFile('w')

    align(srcf.name,
          trgf.name,
          links_filename_fwd=fwd_alignment_file.name,
          links_filename_rev=rev_alignment_file.name,
          statistics_filename=None,
          scores_filename=None,
          model=args.model,
          n_iterations=iters,
          n_samplers=args.n_samplers,
          quiet=not args.verbose,
          rel_iterations=args.length,
          null_prior=args.null_prior,
          use_gdb=args.debug)

    srcf.close()
    trgf.close()

    # split and, if requested, lowercase tokens
    logger.info("Preprocessing sentences for probability estimation...")
    with xopen(args.source_filename, 'r',
               encoding='utf-8') as fsrc, xopen(args.target_filename,
                                                'r',
                                                encoding='utf-8') as ftgt:
        src_sents = preprocess(fsrc.readlines(), args.lowercase)
        trg_sents = preprocess(ftgt.readlines(), args.lowercase)

    # extract token --> index hash table
    logger.info("Extracting vocabulary...")
    voc_s = make_voc(src_sents)
    voc_t = make_voc(trg_sents)

    if args.p_filename_fwd is not None:
        logger.info("Estimating forward counts...")
        counts, s_counts = compute_counts_fwd(voc_s, voc_t, src_sents,
                                              trg_sents,
                                              fwd_alignment_file.name,
                                              args.lowercase)
        logger.info("Estimating forward probabilities...")
        p = compute_p(voc_s, voc_t, counts, s_counts)
        logger.info("Saving forward probabilities...")
        model = IBM1(p, voc_s, voc_t)
        save_p(model, args.p_filename_fwd)
        if args.p_filename_fwd_h is not None:
            with xopen(args.p_filename_fwd_h, "w") as f:
                model.dump(f)

    if args.p_filename_rev is not None:
        logger.info("Estimating reverse counts...")
        counts, t_counts = compute_counts_rev(voc_s, voc_t, src_sents,
                                              trg_sents,
                                              rev_alignment_file.name,
                                              args.lowercase)
        logger.info("Estimating reverse probabilities...")
        p = compute_p(voc_t, voc_s, counts, t_counts)
        logger.info("Saving reverse probabilities...")
        model = IBM1(p, voc_t, voc_s)
        save_p(model, args.p_filename_rev)
        if args.p_filename_rev_h is not None:
            with xopen(args.p_filename_rev_h, "w") as f:
                model.dump(f)

    fwd_alignment_file.close()
    rev_alignment_file.close()
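
# A hypothetical invocation matching the parser above (-s/-t are required,
# and at least one of -f/-r must be given or the script exits immediately):
#
#   python mkmodel.py -s corpus.src -t corpus.trg \
#       -f fwd_probs.pkl -r rev_probs.pkl -v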
Exemple #45
0
def write_mesos_cli_config(config):
    # text mode, so writing the JSON string also works under Python 3
    mesos_cli_config_file = NamedTemporaryFile(mode='w', delete=False)
    mesos_cli_config_file.write(json.dumps(config))
    mesos_cli_config_file.close()
    return mesos_cli_config_file.name
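
# A usage sketch for the helper above; the dict contents are placeholders,
# and the returned path can be handed to the mesos CLI as its config file.
config_path = write_mesos_cli_config({'master': 'zk://203.0.113.5:2181/mesos'})
Exemple #46
0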
def mms_load_data(trange=['2015-10-16', '2015-10-17'],
                  probe='1',
                  data_rate='srvy',
                  level='l2',
                  instrument='fgm',
                  datatype='',
                  varformat=None,
                  prefix='',
                  suffix='',
                  get_support_data=False,
                  time_clip=False,
                  no_update=False,
                  center_measurement=False,
                  available=False,
                  notplot=False):
    """
    This function loads MMS data into pyTplot variables
    """

    if not isinstance(probe, list): probe = [probe]
    if not isinstance(data_rate, list): data_rate = [data_rate]
    if not isinstance(level, list): level = [level]
    if not isinstance(datatype, list): datatype = [datatype]

    probe = [str(p) for p in probe]

    # allows the user to pass in trange as list of datetime objects
    if type(trange[0]) == datetime and type(trange[1]) == datetime:
        trange = [
            time_string(trange[0].timestamp()),
            time_string(trange[1].timestamp())
        ]

    # allows the user to pass in trange as a list of floats (unix times)
    if isinstance(trange[0], float):
        trange[0] = time_string(trange[0])
    if isinstance(trange[1], float):
        trange[1] = time_string(trange[1])

    start_date = parse(trange[0]).strftime(
        '%Y-%m-%d')  # need to request full day, then parse out later
    end_date = parse(time_string(time_double(trange[1]) - 0.1)).strftime(
        '%Y-%m-%d-%H-%M-%S'
    )  # subtract 0.1 sec to avoid getting data for the next day

    download_only = CONFIG['download_only']

    no_download = False
    if no_update or CONFIG['no_download']: no_download = True

    user = None
    if not no_download:
        sdc_session, user = mms_login_lasp()

    out_files = []
    available_files = []

    for prb in probe:
        for drate in data_rate:
            for lvl in level:
                for dtype in datatype:
                    if user is None:
                        url = 'https://lasp.colorado.edu/mms/sdc/public/files/api/v1/file_info/science?start_date=' + start_date + '&end_date=' + end_date + '&sc_id=mms' + prb + '&instrument_id=' + instrument + '&data_rate_mode=' + drate + '&data_level=' + lvl
                    else:
                        url = 'https://lasp.colorado.edu/mms/sdc/sitl/files/api/v1/file_info/science?start_date=' + start_date + '&end_date=' + end_date + '&sc_id=mms' + prb + '&instrument_id=' + instrument + '&data_rate_mode=' + drate + '&data_level=' + lvl

                    if dtype != '':
                        url = url + '&descriptor=' + dtype

                    if CONFIG['debug_mode']: logging.info('Fetching: ' + url)

                    if not no_download:
                        # query list of available files
                        try:
                            with warnings.catch_warnings():
                                warnings.simplefilter("ignore",
                                                      category=ResourceWarning)
                                http_json = sdc_session.get(
                                    url, verify=True).json()

                            if CONFIG['debug_mode']:
                                logging.info(
                                    'Filtering the results down to your trange'
                                )

                            files_in_interval = mms_files_in_interval(
                                http_json['files'], trange)

                            if available:
                                for file in files_in_interval:
                                    logging.info(
                                        file['file_name'] + ' (' + str(
                                            np.round(file['file_size'] /
                                                     (1024. * 1024),
                                                     decimals=1)) + ' MB)')
                                    available_files.append(file['file_name'])
                                continue

                            for file in files_in_interval:
                                file_date = parse(file['timetag'])
                                if dtype == '':
                                    out_dir = os.sep.join([
                                        CONFIG['local_data_dir'], 'mms' + prb,
                                        instrument, drate, lvl,
                                        file_date.strftime('%Y'),
                                        file_date.strftime('%m')
                                    ])
                                else:
                                    out_dir = os.sep.join([
                                        CONFIG['local_data_dir'], 'mms' + prb,
                                        instrument, drate, lvl, dtype,
                                        file_date.strftime('%Y'),
                                        file_date.strftime('%m')
                                    ])

                                if drate.lower() == 'brst':
                                    out_dir = os.sep.join(
                                        [out_dir,
                                         file_date.strftime('%d')])

                                out_file = os.sep.join(
                                    [out_dir, file['file_name']])

                                if CONFIG['debug_mode']:
                                    logging.info('File: ' + file['file_name'] +
                                                 ' / ' + file['timetag'])

                                if os.path.exists(out_file) and str(
                                        os.stat(out_file).st_size) == str(
                                            file['file_size']):
                                    if not download_only:
                                        logging.info('Loading ' + out_file)
                                    out_files.append(out_file)
                                    continue

                                if user is None:
                                    download_url = 'https://lasp.colorado.edu/mms/sdc/public/files/api/v1/download/science?file=' + file[
                                        'file_name']
                                else:
                                    download_url = 'https://lasp.colorado.edu/mms/sdc/sitl/files/api/v1/download/science?file=' + file[
                                        'file_name']

                                logging.info('Downloading ' +
                                             file['file_name'] + ' to ' +
                                             out_dir)

                                with warnings.catch_warnings():
                                    warnings.simplefilter(
                                        "ignore", category=ResourceWarning)
                                    fsrc = sdc_session.get(download_url,
                                                           stream=True,
                                                           verify=True)
                                ftmp = NamedTemporaryFile(delete=False)

                                with open(ftmp.name, 'wb') as f:
                                    copyfileobj(fsrc.raw, f)

                                if not os.path.exists(out_dir):
                                    os.makedirs(out_dir)

                                # if the download was successful, copy to data directory
                                copy(ftmp.name, out_file)
                                out_files.append(out_file)
                                fsrc.close()
                                ftmp.close()
                        except requests.exceptions.ConnectionError:
                            # No/bad internet connection; try loading the files locally
                            logging.error('No internet connection!')

                    if out_files == []:
                        if not download_only:
                            logging.info('Searching for local files...')
                        out_files = mms_get_local_files(
                            prb, instrument, drate, lvl, dtype, trange)

    if not no_download:
        sdc_session.close()

    if available:
        return available_files

    if not download_only:
        out_files = sorted(out_files)

        new_variables = cdf_to_tplot(out_files,
                                     varformat=varformat,
                                     merge=True,
                                     get_support_data=get_support_data,
                                     prefix=prefix,
                                     suffix=suffix,
                                     center_measurement=center_measurement,
                                     notplot=notplot)

        if notplot:
            return new_variables

        if new_variables == []:
            logging.warning('No data loaded.')
            return

        if time_clip:
            for new_var in new_variables:
                tclip(new_var, trange[0], trange[1], suffix='')

        return new_variables
    else:
        return out_files
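The loop above streams each file into a NamedTemporaryFile and only copies it into the data tree once the transfer completes, so an interrupted download never leaves a truncated file behind. A minimal, hedged sketch of that pattern in isolation (the URL and output path are placeholders, and temp-file cleanup is added here for tidiness):

import os
import requests
from shutil import copy, copyfileobj
from tempfile import NamedTemporaryFile

def download_to(url, out_file):
    with requests.Session() as session:
        fsrc = session.get(url, stream=True, verify=True)
        fsrc.raise_for_status()
        ftmp = NamedTemporaryFile(delete=False)
        try:
            with open(ftmp.name, 'wb') as f:
                copyfileobj(fsrc.raw, f)      # stream the body to disk
            out_dir = os.path.dirname(out_file)
            if out_dir and not os.path.exists(out_dir):
                os.makedirs(out_dir)
            copy(ftmp.name, out_file)         # publish only after a full write
        finally:
            fsrc.close()
            os.unlink(ftmp.name)
    return out_file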
Exemple #47
0
class MockHub(JupyterHub):
    """Hub with various mock bits"""

    db_file = None
    last_activity_interval = 2
    log_datefmt = '%M:%S'
    external_certs = None
    log_level = 10

    def __init__(self, *args, **kwargs):
        if 'internal_certs_location' in kwargs:
            cert_location = kwargs['internal_certs_location']
            kwargs['external_certs'] = ssl_setup(cert_location, 'hub-ca')
        super().__init__(*args, **kwargs)

    @default('subdomain_host')
    def _subdomain_host_default(self):
        return os.environ.get('JUPYTERHUB_TEST_SUBDOMAIN_HOST', '')

    @default('bind_url')
    def _default_bind_url(self):
        if self.subdomain_host:
            port = urlparse(self.subdomain_host).port
        else:
            port = random_port()
        return 'http://127.0.0.1:%i/@/space%%20word/' % (port, )

    @default('ip')
    def _ip_default(self):
        return '127.0.0.1'

    @default('port')
    def _port_default(self):
        if self.subdomain_host:
            port = urlparse(self.subdomain_host).port
            if port:
                return port
        return random_port()

    @default('authenticator_class')
    def _authenticator_class_default(self):
        return MockPAMAuthenticator

    @default('spawner_class')
    def _spawner_class_default(self):
        return MockSpawner

    def init_signal(self):
        pass

    def load_config_file(self, *args, **kwargs):
        pass

    def init_tornado_application(self):
        """Instantiate the tornado Application object"""
        super().init_tornado_application()
        # reconnect tornado_settings so that mocks can update the real thing
        self.tornado_settings = self.users.settings = self.tornado_application.settings

    def init_services(self):
        # explicitly expire services before reinitializing
        # this only happens in tests because re-initialize
        # does not occur in a real instance
        for service in self.db.query(orm.Service):
            self.db.expire(service)
        return super().init_services()

    test_clean_db = Bool(True)

    def init_db(self):
        """Ensure we start with a clean user list"""
        super().init_db()
        if self.test_clean_db:
            for user in self.db.query(orm.User):
                self.db.delete(user)
            for group in self.db.query(orm.Group):
                self.db.delete(group)
            self.db.commit()

    @gen.coroutine
    def initialize(self, argv=None):
        self.pid_file = NamedTemporaryFile(delete=False).name
        self.db_file = NamedTemporaryFile()
        self.db_url = os.getenv('JUPYTERHUB_TEST_DB_URL') or self.db_file.name
        yield super().initialize([])

        # add an initial user
        user = self.db.query(orm.User).filter(orm.User.name == 'user').first()
        if user is None:
            user = orm.User(name='user')
            self.db.add(user)
            self.db.commit()

    def stop(self):
        super().stop()

        # run cleanup in a background thread
        # to avoid multiple eventloops in the same thread errors from asyncio

        def cleanup():
            asyncio.set_event_loop(asyncio.new_event_loop())
            loop = IOLoop.current()
            loop.run_sync(self.cleanup)
            loop.close()

        pool = ThreadPoolExecutor(1)
        f = pool.submit(cleanup)
        # wait for cleanup to finish
        f.result()
        pool.shutdown()

        # ignore the call that will fire in atexit
        self.cleanup = lambda: None
        self.db_file.close()

    @gen.coroutine
    def login_user(self, name):
        """Login a user by name, returning her cookies."""
        base_url = public_url(self)
        external_ca = None
        if self.internal_ssl:
            external_ca = self.external_certs['files']['ca']
        r = yield async_requests.post(
            base_url + 'hub/login',
            data={
                'username': name,
                'password': name,
            },
            allow_redirects=False,
            verify=external_ca,
        )
        r.raise_for_status()
        assert r.cookies
        return r.cookies
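initialize() above uses two distinct temp-file idioms worth separating: NamedTemporaryFile(delete=False).name discards the object and keeps only a path (cleanup becomes the caller's job), while db_file keeps the object alive so the backing file lasts exactly until close() runs in stop(). A tiny illustrative sketch:

from tempfile import NamedTemporaryFile

pid_path = NamedTemporaryFile(delete=False).name  # file persists; unlink it yourself
db_file = NamedTemporaryFile()                    # file lives while the object is open
print(pid_path, db_file.name)
db_file.close()                                   # backing file is removed here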
Exemple #48
0
    def _run_program(self, bin, fastafile, savedir="", params=None):
        #try:
        #    from gimmemotifs.mp import pool
        #except:
        #    pass

        default_params = {
            "analysis": "small",
            "organism": "hg18",
            "single": False,
            "parallel": True
        }
        if params is not None:
            default_params.update(params)

        organism = default_params["organism"]
        weeder_organism = ""
        weeder_organisms = {
            "hg18": "HS",
            "hg19": "HS",
            "mm9": "MM",
            "rn4": "RN",
            "dm3": "DM",
            "fr2": "FR",
            "danRer6": "DR",
            "danRer7": "DR",
            "galGal3": "GG",
            "ce3": "CE",
            "anoGam1": "AG",
            "yeast": "SC",
            "sacCer2": "SC",
            "xenTro2": "XT",
            "xenTro3": "XT"
        }
        if organism in weeder_organisms:
            weeder_organism = weeder_organisms[organism]
        else:
            return []

        weeder = bin
        adviser = weeder.replace("weederTFBS", "adviser")

        weeder_dir = bin.replace("weederTFBS.out", "")
        if self.is_configured():
            weeder_dir = self.dir()

        freq_files = os.path.join(weeder_dir, "FreqFiles")
        if not os.path.exists(freq_files):
            raise ValueError("Can't find FreqFiles directory for Weeder")

        fastafile = os.path.abspath(fastafile)
        savedir = os.path.abspath(savedir)

        tmp = NamedTemporaryFile(dir=self.tmpdir)
        name = tmp.name
        tmp.close()
        shutil.copy(fastafile, name)
        fastafile = name

        current_path = os.getcwd()
        os.chdir(weeder_dir)

        coms = ((8, 2), (6, 1))

        strand = "-S"
        if default_params["single"]:
            strand = ""

        if default_params["analysis"] == "xl":
            coms = ((12, 4), (10, 3), (8, 2), (6, 1))
        elif default_params["analysis"] == "large":
            coms = ((10, 3), (8, 2), (6, 1))
        elif default_params["analysis"] == "medium":
            coms = ((10, 3), (8, 2), (6, 1))

        # TODO: test organism
        stdout = ""
        stderr = ""

        default_params["parallel"] = False
        if default_params["parallel"]:
            jobs = []
            #for (w,e) in coms:
            #    jobs.append(pool.apply_async(
            #        run_weeder_subset,
            #        (weeder, fastafile, w, e, weeder_organism, strand,)
            #        ))

            #for job in jobs:
            #    out,err = job.get()
            #    stdout += out
            #    stderr += err
        else:

            for (w, e) in coms:
                out, err = run_weeder_subset(weeder, fastafile, w, e,
                                             weeder_organism, strand)
                stdout += out
                stderr += err

        cmd = "%s %s" % (adviser, fastafile)
        #print cmd
        p = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
        out, err = p.communicate()
        stdout += out
        stderr += err

        os.chdir(current_path)

        motifs = []
        if os.path.exists(fastafile + ".wee"):
            f = open(fastafile + ".wee")
            motifs = self.parse(f)
            f.close()

        for ext in [".wee", ".html", ".mix", ""]:
            if os.path.exists(fastafile + ext):
                os.unlink(fastafile + ext)

        return motifs, stdout, stderr
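The NamedTemporaryFile near the top of _run_program is used only to reserve a unique path: the handle is closed right away (which deletes the file) and the FASTA input is then copied to the reserved name. A sketch of that idiom, with the caveat that it is slightly racy between close() and copy() (the input filename here is hypothetical):

import shutil
from tempfile import NamedTemporaryFile

tmp = NamedTemporaryFile()            # dir=self.tmpdir in the original
unique_path = tmp.name
tmp.close()                           # the file is deleted, but the name stays unique
shutil.copy("input.fa", unique_path)  # hypothetical input file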
Exemple #49
0
class MockHub(JupyterHub):
    """Hub with various mock bits"""

    db_file = None
    confirm_no_ssl = True
    
    def _subdomain_host_default(self):
        return os.environ.get('JUPYTERHUB_TEST_SUBDOMAIN_HOST', '')
    
    def _ip_default(self):
        return '127.0.0.1'
    
    def _authenticator_class_default(self):
        return MockPAMAuthenticator
    
    def _spawner_class_default(self):
        return MockSpawner
    
    def init_signal(self):
        pass
    
    def start(self, argv=None):
        self.db_file = NamedTemporaryFile()
        self.db_url = 'sqlite:///' + self.db_file.name
        
        evt = threading.Event()
        
        @gen.coroutine
        def _start_co():
            assert self.io_loop._running
            # put initialize in start for SQLAlchemy threading reasons
            yield super(MockHub, self).initialize(argv=argv)
            # add an initial user
            user = orm.User(name='user')
            self.db.add(user)
            self.db.commit()
            yield super(MockHub, self).start()
            yield self.hub.server.wait_up(http=True)
            self.io_loop.add_callback(evt.set)
        
        def _start():
            self.io_loop = IOLoop()
            self.io_loop.make_current()
            self.io_loop.add_callback(_start_co)
            self.io_loop.start()
        
        self._thread = threading.Thread(target=_start)
        self._thread.start()
        ready = evt.wait(timeout=10)
        assert ready
    
    def stop(self):
        super().stop()
        self._thread.join()
        IOLoop().run_sync(self.cleanup)
        # ignore the call that will fire in atexit
        self.cleanup = lambda : None
        self.db_file.close()
    
    def login_user(self, name):
        base_url = public_url(self)
        r = requests.post(base_url + 'hub/login',
            data={
                'username': name,
                'password': name,
            },
            allow_redirects=False,
        )
        assert r.cookies
        return r.cookies
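start() above hands the IOLoop to a background thread and then blocks on a threading.Event until the hub signals readiness, rather than sleeping and hoping. The handshake reduced to its core (illustrative only):

import threading

evt = threading.Event()

def worker():
    # ...bring services up here...
    evt.set()  # signal readiness to the waiting caller

threading.Thread(target=worker, daemon=True).start()
ready = evt.wait(timeout=10)  # returns False on timeout instead of hanging
assert ready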
Exemple #50
0
def main(relative_root):
    git_version = run_shell_command("git version")
    if not git_version:
        print(
            "ERROR: Failed to run git. Make sure it is installed and in your PATH."
        )
        return False

    file_conflicts = run_shell_command(
        "git diff --name-only --diff-filter=U").split("\n")
    icon_conflicts = [
        path for path in file_conflicts if path.endswith(".dmi")
    ]

    for i in range(0, len(icon_conflicts)):
        print("[{}]: {}".format(i, icon_conflicts[i]))
    selection = input(
        "Choose icon files you want to fix (example: 1,3-5,12):\n")
    selection = selection.replace(" ", "")
    selection = selection.split(",")

    #shamelessly copied from mapmerger cli
    valid_indices = list()
    for m in selection:
        index_range = m.split("-")
        if len(index_range) == 1:
            index = string_to_num(index_range[0])
            if index >= 0 and index < len(icon_conflicts):
                valid_indices.append(index)
        elif len(index_range) == 2:
            index0 = string_to_num(index_range[0])
            index1 = string_to_num(index_range[1])
            if index0 >= 0 and index0 <= index1 and index1 < len(
                    icon_conflicts):
                valid_indices.extend(range(index0, index1 + 1))

    if not len(valid_indices):
        print("No icons selected, exiting.")
        sys.exit()

    print("Attempting to fix the following icon files:")
    for i in valid_indices:
        print(icon_conflicts[i])
    input("Press Enter to start.")

    for i in valid_indices:
        path = icon_conflicts[i]
        print("{}: {}".format("Merging", path))

        common_ancestor_hash = run_shell_command(
            "git merge-base ORIG_HEAD master").strip()

        ours_icon = NamedTemporaryFile(delete=False)
        theirs_icon = NamedTemporaryFile(delete=False)
        base_icon = NamedTemporaryFile(delete=False)

        ours_icon.write(
            run_shell_command_binary("git show ORIG_HEAD:{}".format(path)))
        theirs_icon.write(
            run_shell_command_binary("git show master:{}".format(path)))
        base_icon.write(
            run_shell_command_binary("git show {}:{}".format(
                common_ancestor_hash, path)))

        # So it being "open" doesn't prevent other programs from using it
        ours_icon.close()
        theirs_icon.close()
        base_icon.close()

        merge_command = "java -jar {} merge {} {} {} {}".format(
            relative_root + dmitool_path, base_icon.name, ours_icon.name,
            theirs_icon.name, relative_root + path + ".fixed")

        print(merge_command)
        print(run_shell_command(merge_command))
        os.remove(ours_icon.name)
        os.remove(theirs_icon.name)
        os.remove(base_icon.name)
        print(".")
Exemple #51
0
    def export(self,
               out_f=None,
               format='mp3',
               codec=None,
               bitrate=None,
               parameters=None,
               tags=None,
               id3v2_version='4'):
        """
        Export an AudioSegment to a file with given options

        out_f (string):
            Path to destination audio file

        format (string)
            Format for destination audio file.
            ('mp3', 'wav', 'ogg' or other ffmpeg/avconv supported files)

        codec (string)
            Codec used for encoding the destination file.

        bitrate (string)
            Bitrate used when encoding the destination file. (64, 92, 128, 256, 312k...)
            Each codec accepts different bitrate arguments, so take a look at the
            ffmpeg documentation for details (bitrate usually shown as -b, -ba or
            -a:b).

        parameters (string)
            Additional ffmpeg/avconv parameters

        tags (dict)
            Set metadata information to destination files
            usually used as tags. ({title='Song Title', artist='Song Artist'})

        id3v2_version (string)
            Set ID3v2 version for tags. (default: '4')
        """
        id3v2_allowed_versions = ['3', '4']

        out_f = _fd_or_path_or_tempfile(out_f, 'wb+')
        out_f.seek(0)

        # for wav output we can just write the data directly to out_f
        if format == "wav":
            data = out_f
        else:
            data = NamedTemporaryFile(mode="wb", delete=False)

        wave_data = wave.open(data, 'wb')
        wave_data.setnchannels(self.channels)
        wave_data.setsampwidth(self.sample_width)
        wave_data.setframerate(self.frame_rate)
        # For some reason packing the wave header struct with
        # a float in python 2 doesn't throw an exception
        wave_data.setnframes(int(self.frame_count()))
        wave_data.writeframesraw(self._data)
        wave_data.close()

        # for wav files, we're done (wav data is written directly to out_f)
        if format == 'wav':
            return out_f

        output = NamedTemporaryFile(mode="w+b", delete=False)

        # build converter command to export
        conversion_command = [
            self.converter,
            '-y',  # always overwrite existing files
            "-f",
            "wav",
            "-i",
            data.name,  # input options (filename last)
        ]

        if codec is None:
            codec = self.DEFAULT_CODECS.get(format, None)

        if codec is not None:
            # force audio encoder
            conversion_command.extend(["-acodec", codec])

        if bitrate is not None:
            conversion_command.extend(["-b:a", bitrate])

        if parameters is not None:
            # extend arguments with arbitrary set
            conversion_command.extend(parameters)

        if tags is not None:
            if not isinstance(tags, dict):
                raise InvalidTag("Tags must be a dictionary.")
            else:
                # extend converter command with tags
                for key, value in tags.items():
                    conversion_command.extend(
                        ['-metadata', '{0}={1}'.format(key, value)])

                if format == 'mp3':
                    # set id3v2 tag version
                    if id3v2_version not in id3v2_allowed_versions:
                        raise InvalidID3TagVersion(
                            "id3v2_version not allowed, allowed versions: %s" %
                            id3v2_allowed_versions)
                    conversion_command.extend(
                        ["-id3v2_version", id3v2_version])

        conversion_command.extend([
            "-f",
            format,
            output.name,  # output options (filename last)
        ])

        # run the converter; route its stderr to /dev/null (opened for
        # writing, so the child can actually discard its output there)
        subprocess.call(
            conversion_command,
            stderr=open(os.devnull, 'w'))

        output.seek(0)
        out_f.write(output.read())

        data.close()
        output.close()

        os.unlink(data.name)
        os.unlink(output.name)

        out_f.seek(0)
        return out_f
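A hedged usage sketch for export() as documented above, assuming pydub is installed, ffmpeg/avconv is on PATH, and "input.wav" is a placeholder file:

from pydub import AudioSegment

song = AudioSegment.from_wav("input.wav")
song.export("output.mp3", format="mp3", bitrate="192k",
            tags={"title": "Song Title", "artist": "Song Artist"})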
Exemple #52
0
class Tournament(object):
    def __init__(self,
                 players: List[Player],
                 match_generator: MatchGenerator = RoundRobinMatches,
                 name: str = 'axelrod',
                 game: Game = None,
                 turns: int = 200,
                 repetitions: int = 10,
                 noise: float = 0,
                 with_morality: bool = True) -> None:
        """
        Parameters
        ----------
        players : list
            A list of axelrod.Player objects
        match_generator : class
            A class that must be descended from axelrod.MatchGenerator
        name : string
            A name for the tournament
        game : axelrod.Game
            The game object used to score the tournament
        turns : integer
            The number of turns per match
        repetitions : integer
            The number of times the round robin should be repeated
        noise : float
            The probability that a player's intended action should be flipped
        with_morality : boolean
            Whether morality metrics should be calculated
        """
        if game is None:
            self.game = Game()
        else:
            self.game = game
        self.name = name
        self.turns = turns
        self.noise = noise
        self.num_interactions = 0
        self.players = players
        self.repetitions = repetitions
        self.match_generator = match_generator(players, turns, self.game,
                                               self.repetitions, self.noise)
        self._with_morality = with_morality
        self._logger = logging.getLogger(__name__)

    def setup_output(self, filename=None, in_memory=False):
        """Open a CSV writer for tournament output."""
        if in_memory:
            self.interactions_dict = {}
            self.writer = None
        else:
            if filename:
                self.outputfile = open(filename, 'w')
            else:
                # Setup a temporary file
                self.outputfile = NamedTemporaryFile(mode='w')
                filename = self.outputfile.name
            self.writer = csv.writer(self.outputfile, lineterminator='\n')
            # Save filename for loading ResultSet later
            self.filename = filename

    def play(self,
             build_results: bool = True,
             filename: str = None,
             processes: int = None,
             progress_bar: bool = True,
             keep_interactions: bool = False,
             in_memory: bool = False) -> ResultSetFromFile:
        """
        Plays the tournament and passes the results to the ResultSet class

        Parameters
        ----------
        build_results : bool
            whether or not to build a results set
        filename : string
            name of output file
        processes : integer
            The number of processes to be used for parallel processing
        progress_bar : bool
            Whether or not to create a progress bar which will be updated
        keep_interactions : bool
            Whether or not to load the interactions in to memory
        in_memory : bool
            By default interactions are written to a file.
            If this is True they will be kept in memory.
            This is not advised for large tournaments.

        Returns
        -------
        axelrod.ResultSetFromFile
        """
        if progress_bar:
            self.progress_bar = tqdm.tqdm(total=len(self.match_generator),
                                          desc="Playing matches")

        if on_windows and (filename is None):  # pragma: no cover
            in_memory = True

        self.setup_output(filename, in_memory)

        if not build_results and not filename:
            warnings.warn("Tournament results will not be accessible since "
                          "build_results=False and no filename was supplied.")

        if (processes is None) or (on_windows):
            self._run_serial(progress_bar=progress_bar)
        else:
            self._run_parallel(processes=processes, progress_bar=progress_bar)

        if progress_bar:
            self.progress_bar.close()

        # Make sure that python has finished writing to disk
        if not in_memory:
            self.outputfile.flush()

        if build_results:
            return self._build_result_set(progress_bar=progress_bar,
                                          keep_interactions=keep_interactions,
                                          in_memory=in_memory)
        elif not in_memory:
            self.outputfile.close()

    def _build_result_set(self,
                          progress_bar: bool = True,
                          keep_interactions: bool = False,
                          in_memory: bool = False):
        """
        Build the result set (used by the play method)

        Returns
        -------
        axelrod.BigResultSet
        """
        if not in_memory:
            result_set = ResultSetFromFile(
                filename=self.filename,
                progress_bar=progress_bar,
                num_interactions=self.num_interactions,
                repetitions=self.repetitions,
                players=[str(p) for p in self.players],
                keep_interactions=keep_interactions,
                game=self.game)
            self.outputfile.close()
        else:
            result_set = ResultSet(players=[str(p) for p in self.players],
                                   interactions=self.interactions_dict,
                                   repetitions=self.repetitions,
                                   progress_bar=progress_bar,
                                   game=self.game)
        return result_set

    def _run_serial(self, progress_bar: bool = False) -> bool:
        """
        Run all matches in serial

        Parameters
        ----------

        progress_bar : bool
            Whether or not to update the tournament progress bar
        """
        chunks = self.match_generator.build_match_chunks()

        for chunk in chunks:
            results = self._play_matches(chunk)
            self._write_interactions(results)

            if progress_bar:
                self.progress_bar.update(1)

        return True

    def _write_interactions(self, results):
        """Write the interactions to file or to a dictionary"""
        if self.writer is not None:
            self._write_interactions_to_file(results)
        elif self.interactions_dict is not None:
            self._write_interactions_to_dict(results)

    def _write_interactions_to_file(self, results):
        """Write the interactions to csv."""
        for index_pair, interactions in results.items():
            for interaction in interactions:
                row = list(index_pair)
                row.append(str(self.players[index_pair[0]]))
                row.append(str(self.players[index_pair[1]]))
                history1 = "".join([i[0] for i in interaction])
                history2 = "".join([i[1] for i in interaction])
                row.append(history1)
                row.append(history2)
                self.writer.writerow(row)
                self.num_interactions += 1

    def _write_interactions_to_dict(self, results):
        """Write the interactions to memory"""
        for index_pair, interactions in results.items():
            for interaction in interactions:
                try:
                    self.interactions_dict[index_pair].append(interaction)
                except KeyError:
                    self.interactions_dict[index_pair] = [interaction]
                self.num_interactions += 1

    def _run_parallel(self,
                      processes: int = 2,
                      progress_bar: bool = False) -> bool:
        """
        Run all matches in parallel

        Parameters
        ----------

        progress_bar : bool
            Whether or not to update the tournament progress bar
        """
        # At first sight, it might seem simpler to use the multiprocessing Pool
        # Class rather than Processes and Queues. However, Pool can only accept
        # target functions which can be pickled and instance methods cannot.
        work_queue = Queue()
        done_queue = Queue()
        workers = self._n_workers(processes=processes)

        chunks = self.match_generator.build_match_chunks()
        for chunk in chunks:
            work_queue.put(chunk)

        self._start_workers(workers, work_queue, done_queue)
        self._process_done_queue(workers,
                                 done_queue,
                                 progress_bar=progress_bar)

        return True

    def _n_workers(self, processes: int = 2) -> int:
        """
        Determines the number of parallel processes to use.

        Returns
        -------
        integer
        """
        if (2 <= processes <= cpu_count()):
            n_workers = processes
        else:
            n_workers = cpu_count()
        return n_workers

    def _start_workers(self, workers: int, work_queue: Queue,
                       done_queue: Queue) -> bool:
        """
        Initiates the sub-processes to carry out parallel processing.

        Parameters
        ----------
        workers : integer
            The number of sub-processes to create
        work_queue : multiprocessing.Queue
            A queue containing an entry for each round robin to be processed
        done_queue : multiprocessing.Queue
            A queue containing the output dictionaries from each round robin
        """
        for worker in range(workers):
            process = Process(target=self._worker,
                              args=(work_queue, done_queue))
            work_queue.put('STOP')
            process.start()
        return True

    def _process_done_queue(self,
                            workers: int,
                            done_queue: Queue,
                            progress_bar: bool = False):
        """
        Retrieves the matches from the parallel sub-processes

        Parameters
        ----------
        workers : integer
            The number of sub-processes in existence
        done_queue : multiprocessing.Queue
            A queue containing the output dictionaries from each round robin
        progress_bar : bool
            Whether or not to update the tournament progress bar
        """
        stops = 0
        while stops < workers:
            results = done_queue.get()

            if results == 'STOP':
                stops += 1
            else:
                self._write_interactions(results)

                if progress_bar:
                    self.progress_bar.update(1)
        return True

    def _worker(self, work_queue: Queue, done_queue: Queue):
        """
        The work for each parallel sub-process to execute.

        Parameters
        ----------
        work_queue : multiprocessing.Queue
            A queue containing an entry for each round robin to be processed
        done_queue : multiprocessing.Queue
            A queue containing the output dictionaries from each round robin
        """
        for chunk in iter(work_queue.get, 'STOP'):
            interactions = self._play_matches(chunk)
            done_queue.put(interactions)
        done_queue.put('STOP')
        return True

    def _play_matches(self, chunk):
        """
        Play matches in a given chunk.

        Parameters
        ----------
        chunk : tuple (index pair, match_parameters, repetitions)
            match_parameters are also a tuple: (turns, game, noise)

        Returns
        -------
        interactions : dictionary
            Mapping player index pairs to results of matches:

                (0, 1) -> [('C', 'D'), ('D', 'C'),...]
        """
        interactions = defaultdict(list)
        index_pair, match_params, repetitions = chunk
        p1_index, p2_index = index_pair
        player1 = self.players[p1_index].clone()
        player2 = self.players[p2_index].clone()
        players = (player1, player2)
        params = [players]
        params.extend(match_params)
        match = Match(*params)
        for _ in range(repetitions):
            match.play()
            interactions[index_pair].append(match.result)
        return interactions
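A short usage sketch for the Tournament class above, assuming the axelrod package and a few of its built-in strategies:

import axelrod as axl

players = [axl.Cooperator(), axl.Defector(), axl.TitForTat()]
tournament = axl.Tournament(players, turns=10, repetitions=2)
results = tournament.play(progress_bar=False)
print(results.ranked_names)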
Exemple #53
0
    def test_filetransfer(self,
                          devid,
                          authtoken,
                          path="/etc/mender/mender.conf",
                          content_assertion=None):
        # download a file and check its content
        r = download_file(path, devid, authtoken)

        assert r.status_code == 200, r.json()
        if content_assertion:
            assert content_assertion in str(r.content)
        assert (
            r.headers.get("Content-Disposition") == 'attachment; filename="' +
            os.path.basename(path) + '"')
        assert r.headers.get("Content-Type") == "application/octet-stream"
        assert r.headers.get("X-Men-File-Gid") == "0"
        assert r.headers.get("X-Men-File-Uid") == "0"
        assert r.headers.get("X-Men-File-Mode") == "600"
        assert r.headers.get("X-Men-File-Path") == "/etc/mender/mender.conf"
        assert r.headers.get("X-Men-File-Size") != ""

        # wrong request, path is relative
        path = "relative/path"
        r = download_file(path, devid, authtoken)
        assert r.status_code == 400, r.json()
        assert r.json().get("error") == "bad request: path: must be absolute."

        # wrong request, no such file or directory
        path = "/does/not/exist"
        r = download_file(path, devid, authtoken)
        assert r.status_code == 400, r.json()
        assert "/does/not/exist: no such file or directory" in r.json().get(
            "error")

        try:
            # create a 40MB random file
            f = NamedTemporaryFile(delete=False)
            for i in range(40 * 1024):
                f.write(os.urandom(1024))
            f.close()

            # random uid and gid
            uid = random.randint(100, 200)
            gid = random.randint(100, 200)

            # upload the file
            r = upload_file(
                "/tmp/random.bin",
                open(f.name, "rb"),
                devid,
                authtoken,
                mode="600",
                uid=str(uid),
                gid=str(gid),
            )
            assert r.status_code == 201, r.json()

            # download the file
            path = "/tmp/random.bin"
            r = download_file(path, devid, authtoken)
            assert r.status_code == 200, r.json()
            assert (r.headers.get("Content-Disposition") ==
                    'attachment; filename="random.bin"')
            assert r.headers.get("Content-Type") == "application/octet-stream"
            assert r.headers.get("X-Men-File-Mode") == "600"
            assert r.headers.get("X-Men-File-Uid") == str(uid)
            assert r.headers.get("X-Men-File-Gid") == str(gid)
            assert r.headers.get("X-Men-File-Path") == "/tmp/random.bin"
            assert r.headers.get("X-Men-File-Size") == str(40 * 1024 * 1024)

            filename_download = f.name + ".download"
            with open(filename_download, "wb") as fw:
                fw.write(r.content)

            # verify the file is not corrupted
            assert md5sum(filename_download) == md5sum(f.name)
        finally:
            os.unlink(f.name)
            if os.path.isfile(f.name + ".download"):
                os.unlink(f.name + ".download")

        # wrong request, path is relative
        r = upload_file(
            "relative/path/dummy.txt",
            io.StringIO("dummy"),
            devid,
            authtoken,
            mode="600",
            uid="0",
            gid="0",
        )
        assert r.status_code == 400, r.json()
        assert r.json().get("error") == "bad request: path: must be absolute."

        # wrong request, cannot write the file
        r = upload_file(
            "/does/not/exist/dummy.txt",
            io.StringIO("dummy"),
            devid,
            authtoken,
            mode="600",
            uid="0",
            gid="0",
        )
        assert r.status_code == 400, r.json()
        assert "failed to create target file" in r.json().get("error")
Exemple #54
0
def download_file(url=None,
                  filename=None,
                  headers={},
                  username=None,
                  password=None,
                  verify=False,
                  session=None):
    '''
    Download a file and return its local path; this function is primarily meant to be called by the download function below
    
    Parameters:
        url: str
            Remote URL to download

        filename: str
            Local file name

        headers: dict
            Dictionary containing the headers to be passed to the requests get call

        username: str
            Username to be used in HTTP authentication

        password: str
            Password to be used in HTTP authentication

        verify: bool
            Flag indicating whether or not to verify the SSL/TLS certificate

        session: requests.Session object
            Requests session object that allows you to persist things like HTTP authentication through multiple calls

    Returns:
        String containing the local file name

    '''

    if session is None:
        session = requests.Session()

    if username is not None:
        session.auth = (username, password)

    # check if the file exists, and if so, set the last modification time in the header
    # this allows you to avoid re-downloading files that haven't changed
    if os.path.exists(filename):
        headers['If-Modified-Since'] = (datetime.datetime.utcfromtimestamp(
            os.path.getmtime(filename))).strftime('%a, %d %b %Y %H:%M:%S GMT')

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=ResourceWarning)
        fsrc = session.get(url, stream=True, verify=verify, headers=headers)

    # need to delete the If-Modified-Since header so it's not set in the dictionary in subsequent calls
    if headers.get('If-Modified-Since') is not None:
        del headers['If-Modified-Since']

    # the file hasn't changed
    if fsrc.status_code == 304:
        logging.info('File is current: ' + filename)
        fsrc.close()
        return filename

    # file not found
    if fsrc.status_code == 404:
        logging.error('Remote file not found: ' + url)
        fsrc.close()
        return None

    # authentication issues
    if fsrc.status_code == 401 or fsrc.status_code == 403:
        logging.error('Unauthorized: ' + url)
        fsrc.close()
        return None

    if fsrc.status_code == 200:
        logging.info('Downloading ' + url + ' to ' + filename)
    else:
        logging.error(fsrc.reason)
        fsrc.close()
        return None

    ftmp = NamedTemporaryFile(delete=False)

    with open(ftmp.name, 'wb') as f:
        copyfileobj(fsrc.raw, f)

    # make sure the directory exists
    if not os.path.exists(
            os.path.dirname(filename)) and os.path.dirname(filename) != '':
        os.makedirs(os.path.dirname(filename))

    # if the download was successful, copy to data directory
    copy(ftmp.name, filename)

    fsrc.close()
    ftmp.close()

    logging.info('Download complete: ' + filename)

    return filename
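A hedged usage sketch for download_file() as defined above; the URL and local filename are placeholders:

import requests

with requests.Session() as s:
    local = download_file(url='https://example.com/data/file.cdf',
                          filename='data/file.cdf',
                          session=s)
    print(local)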
Exemple #55
0
def ccopen(source, *args, **kargs):
    """Guess the identity of a particular log file and return an instance of it.

    Inputs:
        source - a single logfile, a list of logfiles (for a single job),
                 an input stream, or an URL pointing to a log file.
        *args, **kargs - arguments and keyword arguments passed to filetype

    Returns:
      one of ADF, DALTON, GAMESS, GAMESS UK, Gaussian, Jaguar, Molpro, MOPAC,
      NWChem, ORCA, Psi, QChem, CJSON or None (if it cannot figure it out or
      the file does not exist).
    """

    inputfile = None
    is_stream = False

    # Check if source is a link or contains links. Retrieve their content.
    # Try to open the logfile(s), using openlogfile, if the source is a string (filename)
    # or list of filenames. If it can be read, assume it is an open file object/stream.
    is_string = isinstance(source, str)
    is_url = True if is_string and URL_PATTERN.match(source) else False
    is_listofstrings = isinstance(source, list) and all(
        [isinstance(s, str) for s in source])
    if is_string or is_listofstrings:
        # Process links from list (download contents into temporary location)
        if is_listofstrings:
            filelist = []
            for filename in source:
                if not URL_PATTERN.match(filename):
                    filelist.append(filename)
                else:
                    try:
                        response = urlopen(filename)
                        tfile = NamedTemporaryFile(delete=False)
                        tfile.write(response.read())
                        # Close the file because Windows won't let open it second time
                        tfile.close()
                        filelist.append(tfile.name)
                        # Delete temporary file when the program finishes
                        atexit.register(os.remove, tfile.name)
                    except (ValueError, URLError) as error:
                        if not kargs.get('quiet', False):
                            (errno, strerror) = error.args
                        return None
            source = filelist

        if not is_url:
            try:
                inputfile = logfileparser.openlogfile(source)
            except IOError as error:
                if not kargs.get('quiet', False):
                    (errno, strerror) = error.args
                return None
        else:
            try:
                response = urlopen(source)
                is_stream = True

                # Retrieve filename from URL if possible
                filename = re.findall(r"\w+\.\w+", source.split('/')[-1])
                filename = filename[0] if filename else ""

                inputfile = logfileparser.openlogfile(filename,
                                                      object=response.read())
            except (ValueError, URLError) as error:
                if not kargs.get('quiet', False):
                    (errno, strerror) = error.args
                return None

    elif hasattr(source, "read"):
        inputfile = source
        is_stream = True

    # Streams are tricky since they don't have seek methods or seek won't work
    # by design even if it is present. We solve this now by reading in the
    # entire stream and using a StringIO buffer for parsing. This might be
    # problematic for very large streams. Slow streams might also be an issue if
    # the parsing is not instantaneous, but we'll deal with such edge cases
    # as they arise. Ideally, in the future we'll create a class dedicated to
    # dealing with these issues, supporting both files and streams.
    if is_stream:
        try:
            inputfile.seek(0, 0)
        except (AttributeError, IOError):
            contents = inputfile.read()
            try:
                inputfile = io.StringIO(contents)
            except:
                inputfile = io.StringIO(unicode(contents))
            inputfile.seek(0, 0)

    # If the input file is a CJSON file rather than a standard computational
    # chemistry logfile, don't try to guess the filetype.
    if kargs.get("cjson", False):
        filetype = cjsonreader.CJSON
    else:
        filetype = guess_filetype(inputfile)

    # Proceed to return an instance of the logfile parser only if the filetype
    # could be guessed. Need to make sure the input file is closed before creating
    # an instance, because parsers will handle opening/closing on their own.
    if filetype:
        # We're going to close and reopen below anyway, so this is just to avoid
        # the missing seek method for fileinput.FileInput. In the long run
        # we need to refactor to support various input types in a more
        # centralized fashion.
        if is_listofstrings:
            pass
        else:
            inputfile.seek(0, 0)
        if not is_stream:
            inputfile.close()
            return filetype(source, *args, **kargs)
        return filetype(inputfile, *args, **kargs)
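A minimal usage sketch for ccopen(), assuming a reasonably recent cclib install and a placeholder logfile path:

from cclib.io import ccopen

parser = ccopen("calc.log")  # placeholder path
if parser is not None:
    data = parser.parse()
    print(data.natom)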
Exemple #56
0
    def test_filetransfer_limits_upload(self, mender_device, devid, auth):
        """Tests the file transfer features with limits"""
        authtoken = auth.get_auth_token()
        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "Chroot": "/var/lib/mender/filetransfer"
                },
            },
            auth,
            devid,
        )
        mender_device.run("mkdir -p /var/lib/mender/filetransfer")

        logger.info(
            "-- testcase: File Transfer limits: file outside chroot; upload forbidden"
        )
        f = NamedTemporaryFile(delete=False)
        for i in range(40 * 1024):
            f.write(os.urandom(1024))
        f.close()
        r = upload_file(
            "/usr/random.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )

        assert r.status_code == 400, r.json()
        assert (r.json().get("error") ==
                "access denied: the target file path is outside chroot")

        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "Chroot": "/var/lib/mender/filetransfer",
                    "FollowSymLinks":
                    True,  # in the image /var/lib/mender is a symlink
                },
            },
            auth,
            devid,
        )

        logger.info(
            "-- testcase: File Transfer limits: file inside chroot; upload allowed"
        )
        f = NamedTemporaryFile(delete=False)
        f.write(os.urandom(16))
        f.close()
        # upload a file
        r = upload_file(
            "/var/lib/mender/filetransfer/random.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )

        assert r.status_code == 201

        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "MaxFileSize": 16378,
                    "FollowSymLinks": True
                },
            },
            auth,
            devid,
        )

        logger.info(
            "-- testcase: File Transfer limits: file size over the limit; upload forbidden"
        )
        f = NamedTemporaryFile(delete=False)
        for i in range(128 * 1024):
            f.write(b"ok")
        f.close()
        r = upload_file(
            "/tmp/random.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )

        assert r.status_code == 400, r.json()
        assert (
            r.json().get("error") ==
            "failed to write file chunk: transmitted bytes limit exhausted")

        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "FollowSymLinks": True,
                    "Counters": {
                        "MaxBytesRxPerMinute": 16784
                    },
                },
            },
            auth,
            devid,
        )

        logger.info(
            "-- testcase: File Transfer limits: transfers during last minute over the limit; upload forbidden"
        )
        f = NamedTemporaryFile(delete=False)
        for i in range(128 * 1024):
            f.write(b"ok")
        f.close()
        upload_file(
            "/tmp/random-0.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )
        upload_file(
            "/tmp/random-1.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )
        logger.info(
            "-- testcase: File Transfer limits: sleeping to gather the avg")

        time.sleep(
            32)  # wait for mender-connect to calculate the 1m exp moving avg
        mender_device.run("kill -USR1 `pidof mender-connect`"
                          )  # USR1 makes mender-connect print status

        r = upload_file(
            "/tmp/random-2.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )

        assert r.status_code == 400, r.json()
        assert r.json().get("error") == "transmitted bytes limit exhausted"

        logger.info(
            "-- testcase: File Transfer limits: transfers during last minute: test_filetransfer_limits_upload sleeping 64s to be able to transfer again"
        )
        # let's rest some more and increase the limit and try again
        time.sleep(64)
        mender_device.run("kill -USR1 `pidof mender-connect`"
                          )  # USR1 makes mender-connect print status

        logger.info(
            "-- testcase: File Transfer limits: transfers during last minute below the limit; upload allowed"
        )
        f = NamedTemporaryFile(delete=False)
        for i in range(64):
            f.write(b"ok")
        f.close()
        # upload a file
        r = upload_file(
            "/tmp/random-a.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
        )
        mender_device.run("kill -USR1 `pidof mender-connect`"
                          )  # USR1 makes mender-connect print status

        assert r.status_code == 201

        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "Chroot": "/var/lib/mender/filetransfer",
                    "FollowSymLinks":
                    True,  # in the image /var/lib/mender is a symlink
                    "PreserveMode": True,
                },
            },
            auth,
            devid,
        )

        logger.info("-- testcase: File Transfer limits: preserve modes;")
        f = NamedTemporaryFile(delete=False)
        f.write(os.urandom(16))
        f.close()
        r = upload_file(
            "/var/lib/mender/filetransfer/modes.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
            mode="4711",
        )
        modes_ls = mender_device.run(
            "ls -al /var/lib/mender/filetransfer/modes.bin")
        logger.info(
            "test_filetransfer_limits_upload ls -al /var/lib/mender/filetransfer/modes.bin:\n%s"
            % modes_ls)

        assert modes_ls.startswith("-rws--x--x")
        assert r.status_code == 201

        set_limits(
            mender_device,
            {
                "Enabled": True,
                "FileTransfer": {
                    "Chroot": "/var/lib/mender/filetransfer",
                    "FollowSymLinks":
                    True,  # in the image /var/lib/mender is a symlink
                    "PreserveOwner": True,
                    "PreserveGroup": True,
                },
            },
            auth,
            devid,
        )

        logger.info(
            "-- testcase: File Transfer limits: preserve owner and group;")
        f = NamedTemporaryFile(delete=False)
        f.write(os.urandom(16))
        f.close()
        gid = int(mender_device.run("cat /etc/group | tail -1| cut -f3 -d:"))
        uid = int(mender_device.run("cat /etc/passwd | tail -1| cut -f3 -d:"))
        logger.info("test_filetransfer_limits_upload gid/uid %d/%d", gid, uid)
        r = upload_file(
            "/var/lib/mender/filetransfer/ownergroup.bin",
            open(f.name, "rb"),
            devid,
            authtoken,
            uid=str(uid),
            gid=str(gid),
        )
        owner_group = mender_device.run(
            "ls -aln /var/lib/mender/filetransfer/ownergroup.bin | cut -f 3,4 -d' '"
        )

        assert owner_group == str(uid) + " " + str(gid) + "\n"
        assert r.status_code == 201
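The "-rws--x--x" string asserted after the PreserveMode upload earlier in this test can be sanity-checked with the standard library, which renders mode 4711 the same way ls does:

import stat

# 0o100000 marks a regular file; 0o4711 is setuid + rwx--x--x
print(stat.filemode(0o100000 | 0o4711))  # -> "-rws--x--x"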
Exemple #57
0
def main():

    # some needed vars that are defined in the file inverter_bench_defs.py for convenience
    input_glob = inverter_bench_defs.input_glob
    input_var1 = inverter_bench_defs.input_var1
    input_var2 = inverter_bench_defs.input_var2
    debug = inverter_bench_defs.debug
    stdout = inverter_bench_defs.stdout
    backup = inverter_bench_defs.backup
    executable = inverter_bench_defs.executable

    switch = 0
    # check if input-file is given at argv[1]
    if len(sys.argv) != 2:
        print("no input- or reference-file given. Performing all existing benchmarks!")
    else:
        input_name = sys.argv[1]
        print("\tbenchmarking \"" + executable + "\" using inputfile \"" +
              sys.argv[1] + "\"")
        switch = 1

    # performs as many tests as specified in the defs-file
    if (switch == 0):
        size1 = len(input_var1)
        size2 = len(input_var2)
    else:
        size1 = 1
        size2 = 1

    for iteration1 in range(size1):
        for iteration2 in range(size2):
            iteration = iteration1 * size2 + iteration2
            size = size1 * size2
            print('\tbenchmark %i of %i' % (iteration + 1, size))

            # select input-file if wished, otherwise create one
            if (switch == 1):
                inputfile = sys.argv[1:]
                args = ['./' + executable] + sys.argv[1:]
            else:
                # write the generated input to a temporary file
                f = NamedTemporaryFile(mode='w', delete=False)
                f.write(input_glob + input_var1[iteration1] +
                        input_var2[iteration2])
                f.close()
                args = ['./' + executable] + [f.name]

            # run the prog
            if (stdout):
                subject = Popen(args)
            else:
                subject = Popen(args, stdout=PIPE)

            subject.wait()
            if (switch == 0):
                if (backup):
                    # save inputfile to a different file
                    backupname = 'input_' + str(iteration)
                    shutil.copy(f.name, backupname)

            if subject.returncode == 0:
                print("\tProgram completed successfully")
            else:
                print("\tProgram terminated with exit code %i" %
                      subject.returncode)
                continue

    if (switch == 0):
        os.remove(f.name)
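The benchmark driver above follows a common write-input/run/clean-up shape; in miniature (the executable name is a placeholder):

import os
import subprocess
from tempfile import NamedTemporaryFile

f = NamedTemporaryFile(mode='w', delete=False)
f.write("generated input\n")
f.close()
try:
    subprocess.call(["./executable", f.name])  # hypothetical binary
finally:
    os.remove(f.name)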
Exemple #58
0
class BaseTaskRunner(LoggingMixin):
    """
    Runs Airflow task instances by invoking the `airflow tasks run` command with raw
    mode enabled in a subprocess.

    :param local_task_job: The local task job associated with running the
        associated task instance.
    :type local_task_job: airflow.jobs.local_task_job.LocalTaskJob
    """
    def __init__(self, local_task_job):
        # Pass task instance context into log handlers to setup the logger.
        super().__init__(local_task_job.task_instance)
        self._task_instance = local_task_job.task_instance

        popen_prepend = []
        if self._task_instance.run_as_user:
            self.run_as_user = self._task_instance.run_as_user
        else:
            try:
                self.run_as_user = conf.get('core', 'default_impersonation')
            except AirflowConfigException:
                self.run_as_user = None

        self._error_file = NamedTemporaryFile(delete=True)

        # Add sudo commands to change user if we need to. Needed to handle SubDagOperator
        # case using a SequentialExecutor.
        self.log.debug("Planning to run as the %s user", self.run_as_user)
        if self.run_as_user and (self.run_as_user != getuser()):
            # We want to include any environment variables now, as we won't
            # want to have to specify them in the sudo call - they would show
            # up in `ps` that way! And run commands now, as the other user
            # might not be able to run the cmds to get credentials
            cfg_path = tmp_configuration_copy(chmod=0o600,
                                              include_env=True,
                                              include_cmds=True)

            # Give ownership of file to user; only they can read and write
            subprocess.check_call([
                'sudo', 'chown', self.run_as_user, cfg_path,
                self._error_file.name
            ],
                                  close_fds=True)

            # propagate PYTHONPATH environment variable
            pythonpath_value = os.environ.get(PYTHONPATH_VAR, '')
            popen_prepend = ['sudo', '-E', '-H', '-u', self.run_as_user]

            if pythonpath_value:
                popen_prepend.append(f'{PYTHONPATH_VAR}={pythonpath_value}')

        else:
            # Always provide a copy of the configuration file settings.
            # Since we run as the same user and can pass environment
            # variables through, they don't need to be baked into the
            # config copy - the runner can read/execute them as needed.
            cfg_path = tmp_configuration_copy(chmod=0o600,
                                              include_env=False,
                                              include_cmds=False)

        self._cfg_path = cfg_path
        self._command = (popen_prepend + self._task_instance.command_as_list(
            raw=True,
            pickle_id=local_task_job.pickle_id,
            mark_success=local_task_job.mark_success,
            job_id=local_task_job.id,
            pool=local_task_job.pool,
            cfg_path=cfg_path,
        ) + ["--error-file", self._error_file.name])
        self.process = None

    def deserialize_run_error(self) -> Optional[Union[str, Exception]]:
        """Return task runtime error if its written to provided error file."""
        return load_error_file(self._error_file)

    def _read_task_logs(self, stream):
        while True:
            line = stream.readline()
            if isinstance(line, bytes):
                line = line.decode('utf-8')
            if not line:
                break
            self.log.info(
                'Job %s: Subtask %s %s',
                self._task_instance.job_id,
                self._task_instance.task_id,
                line.rstrip('\n'),
            )

    def run_command(self, run_with=None):
        """
        Run the task command.

        :param run_with: list of tokens to run the task command with e.g. ``['bash', '-c']``
        :type run_with: list
        :return: the process that was run
        :rtype: subprocess.Popen
        """
        run_with = run_with or []
        full_cmd = run_with + self._command

        self.log.info("Running on host: %s", get_hostname())
        self.log.info('Running: %s', full_cmd)

        if IS_WINDOWS:
            proc = subprocess.Popen(
                full_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                close_fds=True,
                env=os.environ.copy(),
            )
        else:
            proc = subprocess.Popen(
                full_cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                universal_newlines=True,
                close_fds=True,
                env=os.environ.copy(),
                preexec_fn=os.setsid,
            )

        # Start daemon thread to read subprocess logging output
        log_reader = threading.Thread(
            target=self._read_task_logs,
            args=(proc.stdout, ),
        )
        log_reader.daemon = True
        log_reader.start()
        return proc
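
    # Usage sketch (added for illustration; ``runner`` is a hypothetical
    # concrete subclass instance). The ``run_with`` tokens are prepended
    # verbatim to the built command, so lowering the task's scheduling
    # priority might look like:
    #
    #     proc = runner.run_command(run_with=['nice', '-n', '10'])
    #     proc.wait()
    #
    # The return value is a plain subprocess.Popen, so the usual
    # wait()/poll()/terminate() API applies.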

    def start(self):
        """Start running the task instance in a subprocess."""
        raise NotImplementedError()

    def return_code(self) -> Optional[int]:
        """
        :return: The return code associated with running the task instance or
            None if the task is not yet done.
        :rtype: int
        """
        raise NotImplementedError()

    def terminate(self) -> None:
        """Force kill the running task instance."""
        raise NotImplementedError()

    def on_finish(self) -> None:
        """A callback that should be called when this is done running."""
        if self._cfg_path and os.path.isfile(self._cfg_path):
            if self.run_as_user:
                subprocess.call(['sudo', 'rm', self._cfg_path], close_fds=True)
            else:
                os.remove(self._cfg_path)
        try:
            self._error_file.close()
        except FileNotFoundError:
            # the subprocess already deleted the file, so there is
            # nothing left to clean up
            pass
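

# A minimal sketch, not part of the original snippet: Airflow's real concrete
# runner is StandardTaskRunner, but a hypothetical subclass showing how the
# three abstract methods fit together could look like this.
class SketchTaskRunner(BaseTaskRunner):
    """Illustrative runner: launch the command and track the Popen handle."""

    def start(self):
        # run_command() builds and launches the full task command
        self.process = self.run_command()

    def return_code(self) -> Optional[int]:
        # poll() returns None while the subprocess is still running
        return self.process.poll() if self.process else None

    def terminate(self) -> None:
        if self.process and self.process.poll() is None:
            self.process.kill()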
Exemple #59
0
def _get_temp_fpath():
    from tempfile import NamedTemporaryFile
    tmp_file = NamedTemporaryFile(delete=False)
    tmp_name = tmp_file.name
    tmp_file.close()
    return tmp_name
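
# Note (added): with delete=False the file persists after close(), so the
# caller owns cleanup, e.g.:
#
#   path = _get_temp_fpath()
#   try:
#       ...  # use the file at `path`
#   finally:
#       os.remove(path)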
Exemple #60
0
class TFTransformTest(TestCase):
    schema_txt = """
            feature {
                name: "test_feature"
                value_count {
                    min: 1
                    max: 1
                }
                type: FLOAT
                presence {
                    min_count: 1
                }
            }
        """.encode("utf-8")

    def setUp(self):
        self.schema_file = NamedTemporaryFile(delete=False)
        self.schema_file_name = self.schema_file.name
        self.schema_file.write(self.schema_txt)
        self.schema_file.close()
        self.feature_spec = schema_txt_file_to_feature_spec(
            self.schema_file_name)

        self.train_data = NamedTemporaryFile(suffix=".tfrecords", delete=False)
        self.train_data_file = self.train_data.name
        self.train_data.close()
        train_dicts = [{"test_feature": [0.1]}, {"test_feature": [-0.1]}]
        with tf.python_io.TFRecordWriter(self.train_data_file) as writer:
            for train_dict in train_dicts:
                tf_example = build_tf_example_from_dict(train_dict)
                writer.write(tf_example.SerializeToString())

        self.eval_data = NamedTemporaryFile(suffix=".tfrecords", delete=False)
        self.eval_data_file = self.eval_data.name
        self.eval_data.close()
        eval_dicts = [{"test_feature": [0.2]}, {"test_feature": [-0.2]}]
        with tf.python_io.TFRecordWriter(self.eval_data_file) as writer:
            for eval_dict in eval_dicts:
                tf_example = build_tf_example_from_dict(eval_dict)
                writer.write(tf_example.SerializeToString())

        self.output_dir = mkdtemp()
        self.temp_dir = mkdtemp()

    def test_transform(self):
        pipeline_args = ["--runner=DirectRunner"]
        tft_args = [
            "--training_data=%s" % self.train_data_file,
            "--evaluation_data=%s" % self.eval_data_file,
            "--output_dir=%s" % self.output_dir,
            "--temp_location=%s" % self.temp_dir,
            "--schema_file=%s" % self.schema_file_name
        ]
        args = tft_args + pipeline_args
        TFTransform(preprocessing_fn=dummy_preprocessing_fn).run(args=args)

        # test output structure
        sub_folders = os.listdir(self.output_dir)
        self.assertEqual(
            set(sub_folders),
            {"evaluation", "training", "transform_fn", "transformed_metadata"})
        transformed_train_files = [
            f for f in os.listdir(os.path.join(self.output_dir, "training"))
            if f.endswith(".tfrecords")
        ]
        self.assertEqual(len(transformed_train_files), 1)
        transformed_eval_files = [
            f for f in os.listdir(os.path.join(self.output_dir, "evaluation"))
            if f.endswith(".tfrecords")
        ]
        self.assertEqual(len(transformed_eval_files), 1)
        transform_fn_file = os.path.join(self.output_dir, "transform_fn",
                                         "saved_model.pb")
        self.assertTrue(os.path.exists(transform_fn_file))

        # test transformed training data
        path = os.path.join(self.output_dir, "training",
                            transformed_train_files[0])
        transformed_train = [
            js["features"]["feature"]["test_feature_fx"]["floatList"]["value"]
            for js in parse_tf_records(path)
        ]
        transformed_train.sort(key=lambda x: x[0])
        self.assertEqual(len(transformed_train), 2)
        self.assertEqual(transformed_train, [[-1.0], [1.0]])

        # test transformed evaluation data
        path = os.path.join(self.output_dir, "evaluation",
                            transformed_eval_files[0])
        transformed_eval = [
            js["features"]["feature"]["test_feature_fx"]["floatList"]["value"]
            for js in parse_tf_records(path)
        ]
        transformed_eval.sort(key=lambda x: x[0])
        self.assertEqual(len(transformed_eval), 2)
        # transformed_eval is derived from the z-score transformation based on the training data
        # eval_value = (raw_value - train_mean) / train_std_dev
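        # Worked numbers from the fixtures above: training values 0.1 and
        # -0.1 give train_mean = 0.0 and train_std_dev = 0.1, so the eval
        # values map to (0.2 - 0.0) / 0.1 = 2.0 and (-0.2 - 0.0) / 0.1 = -2.0.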
        self.assertEqual(transformed_eval, [[-2.0], [2.0]])

    def test_no_train_no_transform_fn_dir(self):
        pipeline_args = ["--runner=DirectRunner"]
        tft_args = [
            "--evaluation_data=%s" % self.eval_data_file,
            "--output_dir=%s" % self.output_dir,
            "--temp_location=%s" % self.temp_dir,
            "--schema_file=%s" % self.schema_file_name
        ]
        args = tft_args + pipeline_args
        with self.assertRaises(ValueError):
            TFTransform(preprocessing_fn=dummy_preprocessing_fn).run(args=args)

    def tearDown(self):
        os.remove(self.train_data_file)
        os.remove(self.eval_data_file)
        os.remove(self.schema_file_name)
        shutil.rmtree(self.output_dir)
        shutil.rmtree(self.temp_dir)