Example #1
    def test_pickle(self):
        a = T.scalar()  # the a is for 'anonymous' (un-named).
        x, s = T.scalars('xs')

        f = function([x, In(a, value=1.0, name='a'), In(s, value=0.0, update=s+a*x, mutable=True)], s+a*x)

        try:
            # Note that here we also test protocol 0 on purpose, since it
            # should work (even though one should not use it).
            g = pickle.loads(pickle.dumps(f, protocol=0))
            g = pickle.loads(pickle.dumps(f, protocol=-1))
        except NotImplementedError as e:
            if e.args[0].startswith('DebugMode is not picklable'):
                return
            else:
                raise
        # if they both return, assume  that they return equivalent things.
        # print [(k,id(k)) for k in f.finder.keys()]
        # print [(k,id(k)) for k in g.finder.keys()]

        self.assertFalse(g.container[0].storage is f.container[0].storage)
        self.assertFalse(g.container[1].storage is f.container[1].storage)
        self.assertFalse(g.container[2].storage is f.container[2].storage)
        self.assertFalse(x in g.container)
        self.assertFalse(x in g.value)

        self.assertFalse(g.value[1] is f.value[1])  # should not have been copied
        self.assertFalse(g.value[2] is f.value[2])  # should have been copied because it is mutable.
        self.assertFalse((g.value[2] != f.value[2]).any())  # its contents should be identical

        self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
        self.assertTrue(f(2, 1) == g(2))  # they should be in sync, default value should be copied.
        f(1, 2)  # put them out of sync
        self.assertFalse(f(1, 2) == g(1, 2))  # they should not be equal anymore.
Example #2
 def test_subclass_coerce_pymimedata(self):
     md = PyMimeData(data=0)
     md2 = PMDSubclass.coerce(md)
     self.assertTrue(isinstance(md2, PMDSubclass))
     self.assertTrue(md2.hasFormat(PyMimeData.MIME_TYPE))
     self.assertFalse(md2.hasFormat(PyMimeData.NOPICKLE_MIME_TYPE))
     self.assertEqual(md2.data(PyMimeData.MIME_TYPE).data(), dumps(int)+dumps(0))
Example #3
    def save(self, path, items):
        # TODO: purge old cache
        with atomic_file(path) as f:
            c = 0
            f.write(struct.pack("I", c))
            # check whether items are marshalable and compatible with broadcast
            can_marshal = marshalable(items)
            for v in items:
                if can_marshal:
                    try:
                        r = 0, marshal.dumps(v)
                    except Exception:
                        r = 1, cPickle.dumps(v, -1)
                        can_marshal = False
                else:
                    r = 1, cPickle.dumps(v, -1)
                f.write(msgpack.packb(r))
                c += 1
                yield v

            bytes = f.tell()
            if bytes > 10 << 20:
                logger.warning("cached result is %dMB (larger than 10MB)", bytes >> 20)
            # count
            f.seek(0)
            f.write(struct.pack("I", c))
Example #4
    def __init__(self, data=None, pickle=True):
        """ Initialise the instance.
        """
        QtCore.QMimeData.__init__(self)

        # Keep a local reference to be returned if possible.
        self._local_instance = data

        if pickle:
            if data is not None:
                # We may not be able to pickle the data.
                try:
                    pdata = dumps(data)
                    # This format (as opposed to using a single sequence) allows
                    # the type to be extracted without unpickling the data.
                    self.setData(self.MIME_TYPE, dumps(data.__class__) + pdata)
                except (PickleError, TypeError, AttributeError):
                    # if pickle fails, still try to create a draggable
                    warnings.warn(("Could not pickle dragged object %s, " +
                            "using %s mimetype instead") % (repr(data),
                            self.NOPICKLE_MIME_TYPE), RuntimeWarning)
                    self.setData(self.NOPICKLE_MIME_TYPE, str2bytes(str(id(data))))

        else:
            self.setData(self.NOPICKLE_MIME_TYPE, str2bytes(str(id(data))))
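The two-pickle layout above (type pickle followed by instance pickle) is what lets a drop target inspect the type cheaply. Below is a minimal sketch, assuming only the standard pickle module, of how the class can be recovered without unpickling the payload; peek_type is a hypothetical helper, not part of PyMimeData:

from pickle import loads

def peek_type(mime_payload):
    # loads() stops at the first STOP opcode, so only the leading pickled
    # class is deserialized; the instance pickle appended after it is
    # simply ignored as trailing data.
    return loads(mime_payload)

For example, peek_type(dumps(int) + dumps(0)) returns int.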
Example #5
def writePickleZip (outputFile, data, log=None) :
    '''Utility to write a pickle to disk.

       NOTE: This first attempts to use protocol 2 for greater portability
             across python versions. If that fails and we're using py3,
             we attempt again with the highest protocol available. Advances in
             py3.x allow large datasets to be pickled.

       outputFile : Name of the file to write. The extension should be pkl.gz
       data       : Data to write to the file
       log        : Logger to use
    '''
    if not outputFile.endswith('pkl.gz') :
        raise Exception('The file must end in the pkl.gz extension.')
    if log is not None :
        log.info('Compressing to [' + outputFile + ']')

    # attempt to pickle with protocol 2 --
    # protocol 2 is the highest protocol supported in py2.x. If we can
    # get away with pickling with this protocol, it will provide better
    # portability across python releases.
    try :
        with gzip.open(outputFile, 'wb') as f :
            f.write(cPickle.dumps(data, protocol=2))

    # TODO: find exact error thrown while pickling large networks
    except Exception as ex :
        import sys
        # large objects cannot be pickled in py2.x, so if we're using py3.x,
        # let's attempt the pickle again with the current highest.
        if sys.version_info >= (3, 0) :
            with gzip.open(outputFile.replace('pkl', 'pkl3'), 'wb') as f :
                f.write(cPickle.dumps(data, protocol=cPickle.HIGHEST_PROTOCOL))
        else : raise ex
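A minimal usage sketch; the file name and payload below are illustrative:

model_state = {'weights': [0.1, 0.2, 0.3], 'epoch': 7}
writePickleZip('model_state.pkl.gz', model_state)
# On py3, a failure with protocol 2 would retry into 'model_state.pkl3.gz'
# using the highest available protocol.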
Example #6
def test_run_in_runclassmethod():
    class C(object):
        @classmethod
        def fn(cls, *args, **kwargs):
            return cls, args, kwargs
    C.__name__ = 'C_' + str(uuid4()).replace('-', '_')
    C.__module__ = 'pytest_shutil.run'
    with patch('pytest_shutil.run.execnet') as execnet:
        gw = execnet.makegateway.return_value
        chan = gw.remote_exec.return_value
        chan.receive.return_value = cPickle.dumps(sentinel.ret)
        c = C()
        with patch.object(run, C.__name__, C, create=True):
            run.run_in_subprocess(c.fn, python='sentinel.python')(ARG, kw=KW)
            ((s,), _) = chan.send.call_args
            if sys.version_info < (3, 0, 0):
                # Class methods are not pickleable in Python 2.
                assert cPickle.loads(s) == (run._invoke_method, (C, 'fn', ARG), {'kw': KW})
            else:
                # Class methods are pickleable in Python 3.
                assert cPickle.loads(s) == (c.fn, (ARG,), {'kw': KW})
            ((remote_fn,), _) = gw.remote_exec.call_args
            ((chan.receive.return_value,), _) = chan.send.call_args
            remote_fn(chan)
            chan.send.assert_called_with(cPickle.dumps((C, (ARG,), {'kw': KW}), protocol=0))
Example #7
 def test_pickle(self):
     md = PyMimeData(data=0)
     self.assertEqual(md._local_instance, 0)
     self.assertTrue(md.hasFormat(PyMimeData.MIME_TYPE))
     self.assertFalse(md.hasFormat(PyMimeData.NOPICKLE_MIME_TYPE))
     self.assertEqual(md.data(PyMimeData.MIME_TYPE).data(),
                      dumps(int)+dumps(0))
Example #8
    def start(self, cache, job_specs, callback):
        """Run jobs on the backend, blocking until their completion.

        Args:
            cache: The persistent cache which should be set on the backend
            job_specs: The job specification (see
                owls_parallel.backends.ParallelizationBackend)
            callback: The job notification callback, not used by this backend
        """
        # Create the result list
        results = []

        # Go through each job and create a batch job for it
        for spec in itervalues(job_specs):
            # Create the job content
            batch_script = _BATCH_TEMPLATE.format(**{
                "cache": b64encode(dumps(cache)),
                "job": b64encode(dumps(spec)),
            })

            # Create an on-disk handle
            script_name = '{0}.py'.format(uuid4().hex)
            script_path = join(self._path, script_name)

            # Write it to file
            with open(script_path, 'w') as f:
                f.write(batch_script)

            # Submit the batch job and record the job id
            results.append(self._submit(self._path, script_name))

        # All done
        return results
Example #9
 def test_pickling(self):
     """intbitset - pickling"""
     from six.moves import cPickle
     for set1 in self.sets + [[]]:
         self.assertEqual(self.intbitset(set1), cPickle.loads(cPickle.dumps(self.intbitset(set1), -1)))
     for set1 in self.sets + [[]]:
         self.assertEqual(self.intbitset(set1, trailing_bits=True), cPickle.loads(cPickle.dumps(self.intbitset(set1, trailing_bits=True), -1)))
Example #10
def _make_pickleable(fn):
    # return a pickleable function followed by a tuple of initial arguments
    # could use partial but this is more efficient
    try:
        cPickle.dumps(fn, protocol=0)
    except TypeError:
        pass
    else:
        return fn, ()
    if inspect.ismethod(fn):
        name, self_ = fn.__name__, fn.__self__
        if self_ is None:  # Python 2 unbound method
            self_ = fn.im_class
        return _invoke_method, (self_, name)
    elif inspect.isfunction(fn) and fn.__module__ in sys.modules:
        cls, name = _find_class_from_staticmethod(fn)
        if (cls, name) != (None, None):
            try:
                cPickle.dumps((cls, name), protocol=0)
            except cPickle.PicklingError:
                pass
            else:
                return _invoke_method, (cls, name)
    # Fall back to sending the source code
    return _evaluate_fn_source, (textwrap.dedent(inspect.getsource(fn)),)
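A rough usage sketch, under the assumption that _invoke_method forwards to getattr(obj, name)(*args, **kwargs); Greeter is a made-up module-level class:

class Greeter(object):
    def hello(self, name):
        return 'hello ' + name

fn, args = _make_pickleable(Greeter().hello)
# On Python 3 the bound method pickles as-is, so fn is the method and
# args is ().  On Python 2 instance methods don't pickle, so fn is
# _invoke_method and args carries (instance, 'hello').
result = fn(*(args + ('world',)))   # 'hello world' either way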
Example #11
def test_store_update_context():
    src.val.nested = 'nested{{src.ccdid}}'
    src.val['ccdid'] = 2
    src['srcdir'] = 'obs{{ src.obsid }}/{{src.nested}}'
    files['srcdir'] = '{{ src.srcdir }}'
    src['obsid'] = 123
    src.val['ccdid'] = 2
    src['ra'] = 1.4343256789
    src['ra'].format = '%.4f'
    files['evt2'] = 'obs{{ src.obsid }}/{{src.nested}}/acis_evt2'
    src.val.nested = 'nested{{src.ccdid}}'

    tmp = pickle.dumps(src)
    tmp2 = pickle.dumps(files)

    src.clear()
    files.clear()

    assert src['ra'].val is None
    assert files['evt2'].val is None

    src.update(pickle.loads(tmp))
    files.update(pickle.loads(tmp2))

    assert str(src['ra']) == '1.4343'
    assert str(src['srcdir']) == 'obs123/nested2'
    assert files['srcdir'].rel == 'data/obs123/nested2'
    assert files.rel.srcdir == 'data/obs123/nested2'
    assert str(files['evt2.fits']) == 'data/obs123/nested2/acis_evt2.fits'
Example #12
File: hash.py Project: Javacym/jug
def hash_update(M, elems):
    '''
    M = hash_update(M, elems)

    Update the hash object ``M`` with the sequence ``elems``.

    Parameters
    ----------
    M : hashlib object
        An object on which the update method will be called
    elems : sequence of 2-tuples

    Returns
    -------
    M : hashlib object
        This is the same object as the argument
    '''
    from six.moves import cPickle as pickle
    from six.moves import map
    import six

    try:
        import numpy as np
    except ImportError:
        np = None
    for n,e in elems:
        M.update(pickle.dumps(n))
        if hasattr(e, '__jug_hash__'):
            M.update(e.__jug_hash__())
        elif type(e) in (list, tuple):
            M.update(repr(type(e)).encode('utf-8'))
            hash_update(M, enumerate(e))
        elif type(e) == set:
            M.update(six.b('set'))
            # With randomized hashing, different runs of Python might result in
            # different orders, so sort. We cannot trust that all the elements
            # in the set will be comparable, so we convert them to their hashes
            # beforehand.
            items = list(map(hash_one, e))
            items.sort()
            hash_update(M, enumerate(items))
        elif type(e) == dict:
            M.update(six.b('dict'))
            items = [(hash_one(k),v) for k,v in e.items()]
            items.sort(key=(lambda k_v:k_v[0]))

            hash_update(M, items)
        elif np is not None and type(e) == np.ndarray:
            M.update(six.b('np.ndarray'))
            M.update(pickle.dumps(e.dtype))
            M.update(pickle.dumps(e.shape))
            try:
                buffer = e.data
                M.update(buffer)
            except Exception:
                M.update(e.copy().data)
        else:
            M.update(pickle.dumps(e))
    return M
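A small usage sketch; the values are illustrative (dict and set elements additionally rely on the sibling hash_one helper):

import hashlib

M = hashlib.sha1()
hash_update(M, enumerate(['spam', (1, 2), 3.5]))  # (index, element) 2-tuples
digest = M.hexdigest()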
Example #13
 def test_coerce_list_pymimedata(self):
     md = PyMimeData(data=0)
     md2 = PyMimeData.coerce([md])
     self.assertEqual(md2._local_instance, [0])
     self.assertTrue(md2.hasFormat(PyMimeData.MIME_TYPE))
     self.assertFalse(md2.hasFormat(PyMimeData.NOPICKLE_MIME_TYPE))
     self.assertEqual(md2.data(PyMimeData.MIME_TYPE).data(),
                      dumps(list)+dumps([0]))
Example #14
    def test_get(self):
        # mock the pickle and set it to the data variable
        test_pickle = pickle.dumps(
            {pickle.dumps(self.test_key): self.test_value}, protocol=2)
        self.test_cache.data = pickle.loads(test_pickle)

        #assert
        self.assertEquals(self.test_cache.get(self.test_key), self.test_value)
        self.assertEquals(self.test_cache.get(self.bad_key), None)
Example #15
 def _encodeMetadata(self, metadata):
     # metadata format is:
     #    - first line with trailing \x00: comment or empty comment
     #    - then: pickled metadata (incl. comment)
     try:
         comment = metadata['sys_metadata']['comment']
         comment = dumps(comment)
     except KeyError:
         comment = ''
     return b'\x00\n'.join((comment, dumps(metadata, HIGHEST_PROTOCOL)))
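A hedged sketch of a matching decoder; _decodeMetadata is hypothetical, it assumes loads is imported alongside dumps, and splitting on the first b'\x00\n' is illustrative rather than robust (a binary pickle can in principle contain those bytes):

 def _decodeMetadata(self, encoded):
     # The part before the separator is the (possibly pickled) comment;
     # the remainder is the pickled metadata dict, which already includes
     # the comment, so it is all we need to unpickle.
     _comment, _, pickled = encoded.partition(b'\x00\n')
     return loads(pickled)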
Example #16
    def test_serialization(self):
        e = smqtk.representation.classification_element.memory\
            .MemoryClassificationElement('test', 0)

        e2 = cPickle.loads(cPickle.dumps(e))
        self.assertEqual(e, e2)

        e.set_classification(a=0, b=1)
        e2 = cPickle.loads(cPickle.dumps(e))
        self.assertEqual(e, e2)
Example #17
 def test_get(self):
     env = Envelope('*****@*****.**', ['*****@*****.**', '*****@*****.**'])
     envelope_raw = cPickle.dumps(env)
     delivered_indexes_raw = cPickle.dumps([0])
     self.storage.redis.hmget('test:asdf', 'envelope', 'attempts', 'delivered_indexes').AndReturn((envelope_raw, 13, delivered_indexes_raw))
     self.mox.ReplayAll()
     get_env, attempts = self.storage.get('asdf')
     self.assertEqual('*****@*****.**', get_env.sender)
     self.assertEqual(['*****@*****.**'], get_env.recipients)
     self.assertEqual(13, attempts)
Example #18
def save_sesh(dict_objs, file='skrfSesh.p', module='skrf', exclude_prefix='_'):
    '''
    Save all `skrf` objects in the local namespace.

    This is used to save current workspace in a hurry, by passing it the
    output of :func:`locals` (see Examples). Note this can be
    used for other modules as well by passing a different `module` name.

    Parameters
    ------------
    dict_objs : dict
        dictionary containing `skrf` objects. See the Example.
    file : str or file-object, optional
        the file to save all objects to
    module : str, optional
        the module name to grep for.
    exclude_prefix: str, optional
        don't save objects whose names have this prefix.

    See Also
    ----------
    read : read a skrf object
    write : write skrf object[s]
    read_all : read all skrf objects in a directory
    write_all : write dictionary of skrf objects to a directory


    Examples
    ---------
    Write out all skrf objects in current namespace.

    >>> rf.write_all(locals(), 'mysesh.p')


    '''
    objects = {}
    print('pickling: ')
    for k in dict_objs:
        try:
            if module in inspect.getmodule(dict_objs[k]).__name__:
                try:
                    pickle.dumps(dict_objs[k])
                    if k[0] != '_':
                        objects[k] = dict_objs[k]
                        print(k+', ')
                finally:
                    pass

        except(AttributeError, TypeError):
            pass
    if len(objects) == 0:
        print('nothing')

    write(file, objects)
Example #19
def assert_pickle_idempotent(obj):
    '''Assert that obj does not change (w.r.t. ==) under repeated picklings
    '''
    from six.moves.cPickle import dumps, loads
    obj1 = loads(dumps(obj))
    obj2 = loads(dumps(obj1))
    obj3 = loads(dumps(obj2))
    assert_equivalent(obj, obj1)
    assert_equivalent(obj, obj2)
    assert_equivalent(obj, obj3)
    assert type(obj) is type(obj3)
Example #20
 def _set_object_data(self, data):
     if cb.Open():
         try:
             cdo = wx.CustomDataObject(PythonObjectFormat)
             cdo.SetData(dumps(data.__class__) + dumps(data))
             # fixme: There seem to be cases where the '-1' value creates
             # pickles that can't be unpickled (e.g. some TraitDictObject's)
             #cdo.SetData(dumps(data, -1))
             cb.SetData(cdo)
         finally:
             cb.Close()
             cb.Flush()
Example #21
File: task.py Project: douban/dpark
 def _prepare(self, items):
     items = list(items)
     try:
         if marshalable(items):
             is_marshal, d = True, marshal.dumps(items)
         else:
             is_marshal, d = False, cPickle.dumps(items, -1)
     except ValueError:
         is_marshal, d = False, cPickle.dumps(items, -1)
     data = compress(d)
     size = len(data)
     return (is_marshal, data), size
Example #22
 def setUp(self):
     """Set up hopefully unique universes."""
     # _n marks named universes/atomgroups/pickled strings
     self.universe = mda.Universe(PDB_small, PDB_small, PDB_small)
     self.universe_n = mda.Universe(PDB_small, PDB_small, PDB_small,
                                    anchor_name="test1")
     self.ag = self.universe.atoms[:20]  # prototypical AtomGroup
     self.ag_n = self.universe_n.atoms[:10]
     self.pickle_str = cPickle.dumps(self.ag,
                                     protocol=cPickle.HIGHEST_PROTOCOL)
     self.pickle_str_n = cPickle.dumps(self.ag_n,
                                       protocol=cPickle.HIGHEST_PROTOCOL)
Example #23
 def test_pickling(self):
     try:
         self.stream = cPickle.loads(cPickle.dumps(self.stream))
         # regression test: pickling of an unpickled stream used to fail
         cPickle.dumps(self.stream)
         server_data = self.stream.get_epoch_iterator()
         expected_data = get_stream().get_epoch_iterator()
         for _, s, e in zip(range(3), server_data, expected_data):
             for data in zip(s, e):
                 assert_allclose(*data, rtol=1e-3)
     except AssertionError as e:
         raise SkipTest("Skip test_that failed with: {}".format(e))
     assert_raises(StopIteration, next, server_data)
Example #24
def run_task(task_data):
    try:
        gc.disable()
        task, task_try_id = loads(decompress(task_data))
        ttid = TTID(task_try_id)
        Accumulator.clear()
        result = task.run(ttid.ttid)
        env.task_stats.bytes_max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024
        accUpdate = Accumulator.values()
        MutableDict.flush()

        if marshalable(result):
            try:
                flag, data = 0, marshal.dumps(result)
            except Exception:
                flag, data = 1, cPickle.dumps(result, -1)

        else:
            flag, data = 1, cPickle.dumps(result, -1)
        data = compress(data)

        if len(data) > TASK_RESULT_LIMIT:
            # shuffle_id start from 1
            swd = ShuffleWorkDir(0, task.id, ttid.task_try)
            tmppath = swd.alloc_tmp(len(data))
            with open(tmppath, 'wb') as f:
                f.write(data)
                f.close()
            path = swd.export(tmppath)
            data = '/'.join(
                [env.server_uri] + path.split('/')[-3:]
            )
            flag += 2

        return TaskState.finished, cPickle.dumps(((flag, data), accUpdate, env.task_stats), -1)
    except FetchFailed as e:
        return TaskState.failed, TaskEndReason.fetch_failed, str(e), cPickle.dumps(e)
    except Exception as e:
        import traceback
        msg = traceback.format_exc()
        ename = e.__class__.__name__
        fatal_exceptions = (DparkUserFatalError, ArithmeticError,
                            ValueError, LookupError, SyntaxError,
                            TypeError, AssertionError)
        prefix = "FATAL" if isinstance(e, fatal_exceptions) else "FAILED"
        return TaskState.failed, '{}_EXCEPTION_{}'.format(prefix, ename), msg, cPickle.dumps(e)
    finally:
        gc.collect()
        gc.enable()
Example #25
def dbsafe_encode(value, compress_object=False):
    """
    We use deepcopy() here to avoid a problem with cPickle, where dumps
    can generate different character streams for the same lookup value if
    they are referenced differently.

    The reason this is important is because we do all of our lookups as
    simple string matches, thus the character streams must be the same
    for the lookups to work properly. See tests.py for more information.
    """
    if not compress_object:
        value = b64encode(dumps(deepcopy(value)))
    else:
        value = b64encode(compress(dumps(deepcopy(value))))
    return PickledObject(value)
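For symmetry, a hedged sketch of the inverse helper; the real django-picklefield code may differ, and b64decode, decompress and loads are assumed to be imported alongside their counterparts:

def dbsafe_decode(value, compress_object=False):
    # Reverse of dbsafe_encode: base64-decode, optionally decompress,
    # then unpickle back into the original object.
    if not compress_object:
        return loads(b64decode(value))
    return loads(decompress(b64decode(value)))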
Example #26
def test_run_in_runstr():
    source = """def fn(*args, **kwargs):
    return args, kwargs
"""
    with patch('pytest_shutil.run.execnet') as execnet:
        gw = execnet.makegateway.return_value
        chan = gw.remote_exec.return_value
        chan.receive.return_value = cPickle.dumps(sentinel.ret)
        run.run_in_subprocess(source, python='sentinel.python')(ARG, kw=KW)
        ((s,), _) = chan.send.call_args
        assert cPickle.loads(s) == (run._evaluate_fn_source, (source, ARG,), {'kw': KW})
        ((remote_fn,), _) = gw.remote_exec.call_args
        ((chan.receive.return_value,), _) = chan.send.call_args
        remote_fn(chan)
        chan.send.assert_called_with(cPickle.dumps(((ARG,), {'kw': KW}), protocol=0))
Example #27
 def write(self, envelope, timestamp):
     envelope_raw = cPickle.dumps(envelope, cPickle.HIGHEST_PROTOCOL)
     while True:
         id = uuid.uuid4().hex
         key = self._get_key(id)
         if self.redis.hsetnx(key, 'envelope', envelope_raw):
             queue_raw = cPickle.dumps((timestamp, id),
                                       cPickle.HIGHEST_PROTOCOL)
             pipe = self.redis.pipeline()
             pipe.hmset(key, {'timestamp': timestamp,
                              'attempts': 0})
             pipe.rpush(self.queue_key, queue_raw)
             pipe.execute()
             log.write(id, envelope)
             return id
Example #28
    def test_post_classifier_failures(self):
        pickle_data = pickle.dumps(DummyClassifier.from_config({}))
        enc_data = base64.b64encode(pickle_data)
        bad_data = base64.b64encode(pickle.dumps(object()))
        old_label = 'dummy'
        new_label = 'dummy2'
        lock_clfr_str = '['

        with self.app.test_client() as cli:
            rv = cli.post('/classifier', data={
                'label': old_label,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 400)
            resp_data = json.loads(rv.data.decode())
            self.assertMessage(resp_data,
                               "Label '%s' already exists in classifier"
                               " collection." % old_label)
            self.assertEqual(resp_data['label'], old_label)

            rv = cli.post('/classifier', data={'label': old_label})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No state base64 data provided.")

            rv = cli.post('/classifier', data={'bytes_b64': enc_data})
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "No descriptive label provided.")

            rv = cli.post('/classifier', data={
                'label': old_label,
                'lock_label': lock_clfr_str,
                'bytes_b64': enc_data,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Invalid boolean value for 'lock_label'."
                               " Was given: '%s'" % lock_clfr_str)

            rv = cli.post('/classifier', data={
                'label': new_label,
                'bytes_b64': bad_data,
            })
            self.assertStatus(rv, 400)
            self.assertMessage(json.loads(rv.data.decode()),
                               "Data added for label '%s' is not a"
                               " Classifier." % new_label)
Example #29
 def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
     self._createdir()  # Cache dir can be deleted at any time.
     fname = self._key_to_file(key, version)
     self._cull()  # make some room if necessary
     fd, tmp_path = tempfile.mkstemp(dir=self._dir)
     renamed = False
     try:
         with io.open(fd, 'wb') as f:
             expiry = self.get_backend_timeout(timeout)
             f.write(pickle.dumps(expiry, -1))
             f.write(zlib.compress(pickle.dumps(value), -1))
         file_move_safe(tmp_path, fname, allow_overwrite=True)
         renamed = True
     finally:
         if not renamed:
             os.remove(tmp_path)
Example #30
    def set_extra_data(self, value):
        """Save extra data to the object.

        :param value: what you want to replace extra_data with.
        :type value: dict
        """
        self._extra_data = base64.b64encode(cPickle.dumps(value))
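A plausible getter counterpart, sketched here for illustration only (the project's actual accessor may differ):

    def get_extra_data(self):
        """Return the object previously stored with :meth:`set_extra_data`."""
        return cPickle.loads(base64.b64decode(self._extra_data))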
Example #31
 def testSerialization(self):
   desc = specs.BoundedArray([1, 5], np.float32, -1, 1, "test")
   self.assertEqual(pickle.loads(pickle.dumps(desc)), desc)
Example #32
def serialize_to_string(data):
    """
    Dump arbitrary Python object `data` to a string that is base64 encoded
    pickle data.
    """
    return binascii.b2a_base64(pickle.dumps(data)).decode('utf-8')
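A hedged inverse for round-tripping; deserialize_from_string is a hypothetical name:

def deserialize_from_string(text):
    """
    Load the Python object encoded by `serialize_to_string` from its
    base64 pickle representation.
    """
    return pickle.loads(binascii.a2b_base64(text.encode('utf-8')))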
Example #33
 def test_pickling(self):
     dataset = cPickle.loads(cPickle.dumps(self.dataset))
     assert_equal(len(dataset.nodes), 1)
Example #34
class CoreTest(parameterized.TestCase):
    def setUp(self):
        super(CoreTest, self).setUp()
        self.model = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
        self.data = core.MjData(self.model)

    def _assert_attributes_equal(self, actual_obj, expected_obj,
                                 attr_to_compare):
        for name in attr_to_compare:
            actual_value = getattr(actual_obj, name)
            expected_value = getattr(expected_obj, name)
            try:
                if isinstance(expected_value, np.ndarray):
                    np.testing.assert_array_equal(actual_value, expected_value)
                else:
                    self.assertEqual(actual_value, expected_value)
            except AssertionError as e:
                self.fail(
                    "Attribute '{}' differs from expected value: {}".format(
                        name, str(e)))

    def testLoadXML(self):
        with open(HUMANOID_XML_PATH, "r") as f:
            xml_string = f.read()
        model = core.MjModel.from_xml_string(xml_string)
        core.MjData(model)
        with self.assertRaises(TypeError):
            core.MjModel()
        with self.assertRaises(core.Error):
            core.MjModel.from_xml_path("/path/to/nonexistent/model/file.xml")

        xml_with_warning = """
        <mujoco>
          <size njmax='2'/>
            <worldbody>
              <body pos='0 0 0'>
                <geom type='box' size='.1 .1 .1'/>
              </body>
              <body pos='0 0 0'>
                <joint type='slide' axis='1 0 0'/>
                <geom type='box' size='.1 .1 .1'/>
              </body>
            </worldbody>
        </mujoco>"""

        # This model should compile successfully, but raise a warning on the first
        # simulation step.
        model = core.MjModel.from_xml_string(xml_with_warning)
        data = core.MjData(model)
        with mock.patch.object(core, "logging") as mock_logging:
            mjlib.mj_step(model.ptr, data.ptr)
        mock_logging.warning.assert_called_once_with(
            "Pre-allocated constraint buffer is full. Increase njmax above 2. "
            "Time = 0.0000.")

    def testLoadXMLWithAssetsFromString(self):
        core.MjModel.from_xml_string(MODEL_WITH_ASSETS, assets=ASSETS)
        with self.assertRaises(core.Error):
            # Should fail to load without the assets
            core.MjModel.from_xml_string(MODEL_WITH_ASSETS)

    def testVFSFilenameTooLong(self):
        limit = core._MAX_VFS_FILENAME_CHARACTERS
        contents = "fake contents"
        valid_filename = "a" * limit
        with core._temporary_vfs({valid_filename: contents}):
            pass
        invalid_filename = "a" * (limit + 1)
        expected_message = core._VFS_FILENAME_TOO_LONG.format(
            length=(limit + 1), limit=limit, filename=invalid_filename)
        with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
            with core._temporary_vfs({invalid_filename: contents}):
                pass

    def testSaveLastParsedModelToXML(self):
        save_xml_path = os.path.join(OUT_DIR, "tmp_humanoid.xml")

        not_last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
        last_parsed = core.MjModel.from_xml_path(HUMANOID_XML_PATH)

        # Modify the model before saving it in order to confirm that the changes are
        # written to the XML.
        last_parsed.geom_pos.flat[:] = np.arange(last_parsed.geom_pos.size)

        core.save_last_parsed_model_to_xml(save_xml_path,
                                           check_model=last_parsed)

        loaded = core.MjModel.from_xml_path(save_xml_path)
        self._assert_attributes_equal(last_parsed, loaded, ["geom_pos"])
        core.MjData(loaded)

        # Test that `check_model` results in a ValueError if it is not the most
        # recently parsed model.
        with self.assertRaisesWithLiteralMatch(ValueError,
                                               core._NOT_LAST_PARSED_ERROR):
            core.save_last_parsed_model_to_xml(save_xml_path,
                                               check_model=not_last_parsed)

    def testBinaryIO(self):
        bin_path = os.path.join(OUT_DIR, "tmp_humanoid.mjb")
        self.model.save_binary(bin_path)
        core.MjModel.from_binary_path(bin_path)
        byte_string = self.model.to_bytes()
        core.MjModel.from_byte_string(byte_string)

    def testDimensions(self):
        self.assertEqual(self.data.qpos.shape[0], self.model.nq)
        self.assertEqual(self.data.qvel.shape[0], self.model.nv)
        self.assertEqual(self.model.body_pos.shape, (self.model.nbody, 3))

    def testStep(self):
        t0 = self.data.time
        mjlib.mj_step(self.model.ptr, self.data.ptr)
        self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
        self.assertTrue(np.all(np.isfinite(self.data.qpos[:])))
        self.assertTrue(np.all(np.isfinite(self.data.qvel[:])))

    def testMultipleData(self):
        data2 = core.MjData(self.model)
        self.assertNotEqual(self.data.ptr, data2.ptr)
        t0 = self.data.time
        mjlib.mj_step(self.model.ptr, self.data.ptr)
        self.assertEqual(self.data.time, t0 + self.model.opt.timestep)
        self.assertEqual(data2.time, 0)

    def testMultipleModel(self):
        model2 = core.MjModel.from_xml_path(HUMANOID_XML_PATH)
        self.assertNotEqual(self.model.ptr, model2.ptr)
        self.model.opt.timestep += 0.001
        self.assertEqual(self.model.opt.timestep, model2.opt.timestep + 0.001)

    def testModelName(self):
        self.assertEqual(self.model.name, "humanoid")

    @parameterized.named_parameters(
        ("_copy", lambda x: x.copy()),
        ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),
    )
    def testCopyOrPickleModel(self, func):
        timestep = 0.12345
        self.model.opt.timestep = timestep
        body_pos = self.model.body_pos + 1
        self.model.body_pos[:] = body_pos
        model2 = func(self.model)
        self.assertNotEqual(model2.ptr, self.model.ptr)
        self.assertEqual(model2.opt.timestep, timestep)
        np.testing.assert_array_equal(model2.body_pos, body_pos)

    @parameterized.named_parameters(
        ("_copy", lambda x: x.copy()),
        ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),
    )
    def testCopyOrPickleData(self, func):
        for _ in range(10):
            mjlib.mj_step(self.model.ptr, self.data.ptr)
        data2 = func(self.data)
        attr_to_compare = ("time", "energy", "qpos", "xpos")
        self.assertNotEqual(data2.ptr, self.data.ptr)
        self._assert_attributes_equal(data2, self.data, attr_to_compare)
        for _ in range(10):
            mjlib.mj_step(self.model.ptr, self.data.ptr)
            mjlib.mj_step(data2.model.ptr, data2.ptr)
        self._assert_attributes_equal(data2, self.data, attr_to_compare)

    @parameterized.named_parameters(
        ("_copy", lambda x: x.copy()),
        ("_pickle_unpickle", lambda x: cPickle.loads(cPickle.dumps(x))),
    )
    def testCopyOrPickleStructs(self, func):
        for _ in range(10):
            mjlib.mj_step(self.model.ptr, self.data.ptr)
        data2 = func(self.data)
        self.assertNotEqual(data2.ptr, self.data.ptr)
        attr_to_compare = ("warning", "timer", "solver")
        self._assert_attributes_equal(self.data, data2, attr_to_compare)
        for _ in range(10):
            mjlib.mj_step(self.model.ptr, self.data.ptr)
            mjlib.mj_step(data2.model.ptr, data2.ptr)
        self._assert_attributes_equal(self.data, data2, attr_to_compare)

    @parameterized.parameters(("right_foot", "body", 6),
                              ("right_foot", enums.mjtObj.mjOBJ_BODY, 6),
                              ("left_knee", "joint", 11),
                              ("left_knee", enums.mjtObj.mjOBJ_JOINT, 11))
    def testNamesIds(self, name, object_type, object_id):
        output_id = self.model.name2id(name, object_type)
        self.assertEqual(object_id, output_id)
        output_name = self.model.id2name(object_id, object_type)
        self.assertEqual(name, output_name)

    def testNamesIdsExceptions(self):
        with six.assertRaisesRegex(self, core.Error, "does not exist"):
            self.model.name2id("nonexistent_body_name", "body")
        with six.assertRaisesRegex(self, core.Error,
                                   "is not a valid object type"):
            self.model.name2id("right_foot", "nonexistent_type_name")

    def testNamelessObject(self):
        # The model in humanoid.xml contains a single nameless camera.
        name = self.model.id2name(0, "camera")
        self.assertEqual("", name)

    def testWarningCallback(self):
        self.data.qpos[0] = np.inf
        with mock.patch.object(core, "logging") as mock_logging:
            mjlib.mj_step(self.model.ptr, self.data.ptr)
        mock_logging.warning.assert_called_once_with(
            "Nan, Inf or huge value in QPOS at DOF 0. The simulation is unstable. "
            "Time = 0.0000.")

    def testErrorCallback(self):
        with mock.patch.object(core, "logging") as mock_logging:
            mjlib.mj_activate(b"nonexistent_activation_key")
        mock_logging.fatal.assert_called_once_with(
            "Could not open activation key file nonexistent_activation_key")

    def testSingleCallbackContext(self):

        callback_was_called = [False]

        def callback(unused_model, unused_data):
            callback_was_called[0] = True

        mjlib.mj_step(self.model.ptr, self.data.ptr)
        self.assertFalse(callback_was_called[0])

        class DummyError(RuntimeError):
            pass

        try:
            with core.callback_context("mjcb_passive", callback):

                # Stepping invokes the `mjcb_passive` callback.
                mjlib.mj_step(self.model.ptr, self.data.ptr)
                self.assertTrue(callback_was_called[0])

                # Exceptions should not prevent `mjcb_passive` from being reset.
                raise DummyError("Simulated exception.")

        except DummyError:
            pass

        # `mjcb_passive` should have been reset to None.
        callback_was_called[0] = False
        mjlib.mj_step(self.model.ptr, self.data.ptr)
        self.assertFalse(callback_was_called[0])

    def testNestedCallbackContexts(self):

        last_called = [None]
        outer_called = "outer called"
        inner_called = "inner called"

        def outer(unused_model, unused_data):
            last_called[0] = outer_called

        def inner(unused_model, unused_data):
            last_called[0] = inner_called

        with core.callback_context("mjcb_passive", outer):

            # This should execute `outer` a few times.
            mjlib.mj_step(self.model.ptr, self.data.ptr)
            self.assertEqual(last_called[0], outer_called)

            with core.callback_context("mjcb_passive", inner):

                # This should execute `inner` a few times.
                mjlib.mj_step(self.model.ptr, self.data.ptr)
                self.assertEqual(last_called[0], inner_called)

            # When we exit the inner context, the `mjcb_passive` callback should be
            # reset to `outer`.
            mjlib.mj_step(self.model.ptr, self.data.ptr)
            self.assertEqual(last_called[0], outer_called)

        # When we exit the outer context, the `mjcb_passive` callback should be
        # reset to None, and stepping should not affect `last_called`.
        last_called[0] = None
        mjlib.mj_step(self.model.ptr, self.data.ptr)
        self.assertIsNone(last_called[0])

    def testDisableFlags(self):
        xml_string = """
    <mujoco>
      <option gravity="0 0 -9.81"/>
      <worldbody>
        <geom name="floor" type="plane" pos="0 0 0" size="10 10 0.1"/>
        <body name="cube" pos="0 0 0.1">
          <geom type="box" size="0.1 0.1 0.1" mass="1"/>
          <site name="cube_site" type="box" size="0.1 0.1 0.1"/>
          <joint type="slide"/>
        </body>
      </worldbody>
      <sensor>
        <touch name="touch_sensor" site="cube_site"/>
      </sensor>
    </mujoco>
    """
        model = core.MjModel.from_xml_string(xml_string)
        data = core.MjData(model)
        for _ in range(100):  # Let the simulation settle for a while.
            mjlib.mj_step(model.ptr, data.ptr)

        # With gravity and contact enabled, the cube should be stationary and the
        # touch sensor should give a reading of ~9.81 N.
        self.assertAlmostEqual(data.qvel[0], 0, places=4)
        self.assertAlmostEqual(data.sensordata[0], 9.81, places=2)

        # If we disable both contacts and gravity then the cube should remain
        # stationary and the touch sensor should read zero.
        with model.disable("contact", "gravity"):
            mjlib.mj_step(model.ptr, data.ptr)
        self.assertAlmostEqual(data.qvel[0], 0, places=4)
        self.assertEqual(data.sensordata[0], 0)

        # If we disable contacts but not gravity then the cube should fall through
        # the floor.
        with model.disable(enums.mjtDisableBit.mjDSBL_CONTACT):
            for _ in range(10):
                mjlib.mj_step(model.ptr, data.ptr)
        self.assertLess(data.qvel[0], -0.1)

    def testDisableFlagsExceptions(self):
        with six.assertRaisesRegex(self, ValueError, "not a valid flag name"):
            with self.model.disable("invalid_flag_name"):
                pass
        with six.assertRaisesRegex(self, ValueError,
                                   "not a value in `enums.mjtDisableBit`"):
            with self.model.disable(-99):
                pass

    @parameterized.named_parameters(
        ("MjModel", lambda _: core.MjModel.from_xml_path(HUMANOID_XML_PATH),
         "mj_deleteModel"),
        ("MjData", lambda self: core.MjData(self.model), "mj_deleteData"),
        ("MjvScene", lambda _: core.MjvScene(), "mjv_freeScene"))
    def testFree(self, constructor, destructor_name):
        for _ in range(5):
            destructor = getattr(mjlib, destructor_name)
            with mock.patch.object(core.mjlib,
                                   destructor_name,
                                   wraps=destructor) as mock_destructor:
                wrapper = constructor(self)

            expected_address = ctypes.addressof(wrapper.ptr.contents)
            wrapper.free()
            self.assertIsNone(wrapper.ptr)

            mock_destructor.assert_called_once()
            pointer = mock_destructor.call_args[0][0]
            actual_address = ctypes.addressof(pointer.contents)
            self.assertEqual(expected_address, actual_address)

            # Explicit freeing should not break any automatic GC triggered later.
            del wrapper
            gc.collect()

    @parameterized.parameters(
        # The tip is .5 meters from the cart so we expect its horizontal velocity
        # to be 1m/s + .5m*1rad/s = 1.5m/s.
        dict(
            qpos=[0., 0.],  # Pole pointing upwards.
            qvel=[1., 1.],
            expected_linvel=[1.5, 0., 0.],
            expected_angvel=[0., 1., 0.],
        ),
        # For the same velocities but with the pole pointing down, we expect the
        # velocities to cancel, making the global tip velocity now equal to
        # 1m/s - 0.5m*1rad/s = 0.5m/s.
        dict(
            qpos=[0., np.pi],  # Pole pointing downwards.
            qvel=[1., 1.],
            expected_linvel=[0.5, 0., 0.],
            expected_angvel=[0., 1., 0.],
        ),
        # In the site's local frame, which is now flipped w.r.t the world, the
        # velocity is in the negative x direction.
        dict(
            qpos=[0., np.pi],  # Pole pointing downwards.
            qvel=[1., 1.],
            expected_linvel=[-0.5, 0., 0.],
            expected_angvel=[0., 1., 0.],
            local=True,
        ),
    )
    def testObjectVelocity(self,
                           qpos,
                           qvel,
                           expected_linvel,
                           expected_angvel,
                           local=False):
        cartpole = """
    <mujoco>
      <worldbody>
        <body name='cart'>
          <joint type='slide' axis='1 0 0'/>
          <geom name='cart' type='box' size='0.2 0.2 0.2'/>
          <body name='pole'>
            <joint name='hinge' type='hinge' axis='0 1 0'/>
            <geom name='mass' pos='0 0 .5' size='0.04'/>
          </body>
        </body>
      </worldbody>
    </mujoco>
    """
        model = core.MjModel.from_xml_string(cartpole)
        data = core.MjData(model)
        data.qpos[:] = qpos
        data.qvel[:] = qvel
        mjlib.mj_step1(model.ptr, data.ptr)
        linvel, angvel = data.object_velocity("mass",
                                              "geom",
                                              local_frame=local)
        np.testing.assert_array_almost_equal(linvel, expected_linvel)
        np.testing.assert_array_almost_equal(angvel, expected_angvel)

    def testFreeMjrContext(self):
        for _ in range(5):
            renderer = _render.Renderer(640, 480)
            with mock.patch.object(
                    core.mjlib, "mjr_freeContext",
                    wraps=mjlib.mjr_freeContext) as mock_destructor:
                mjr_context = core.MjrContext(self.model, renderer)
                expected_address = ctypes.addressof(mjr_context.ptr.contents)
                mjr_context.free()

            self.assertIsNone(mjr_context.ptr)
            mock_destructor.assert_called_once()
            pointer = mock_destructor.call_args[0][0]
            actual_address = ctypes.addressof(pointer.contents)
            self.assertEqual(expected_address, actual_address)

            # Explicit freeing should not break any automatic GC triggered later.
            del mjr_context
            renderer.free()
            del renderer
            gc.collect()

    def testSceneGeomsAttribute(self):
        scene = core.MjvScene(model=self.model)
        self.assertEqual(scene.ngeom, 0)
        self.assertEmpty(scene.geoms)
        geom_types = (enums.mjtObj.mjOBJ_BODY, enums.mjtObj.mjOBJ_GEOM,
                      enums.mjtObj.mjOBJ_SITE)
        for geom_type in geom_types:
            scene.ngeom += 1
            scene.ptr.contents.geoms[scene.ngeom - 1].objtype = geom_type
        self.assertLen(scene.geoms, len(geom_types))
        np.testing.assert_array_equal(scene.geoms.objtype, geom_types)

    def testInvalidFontScale(self):
        invalid_font_scale = 99
        with self.assertRaisesWithLiteralMatch(
                ValueError,
                core._INVALID_FONT_SCALE.format(invalid_font_scale)):
            core.MjrContext(
                model=self.model,
                gl_context=None,  # Don't need a context for this test.
                font_scale=invalid_font_scale)
Example #35
def md5_hash(obj):
    pickled = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return md5(pickled).hexdigest()
Example #36
def test_gc_never_pickles_temporaries():
    x = T.dvector()

    for i in xrange(2):  # TODO: 30 causes like LONG compilation due to MERGE
        if i:
            r = r + r / 10
        else:
            r = x

    optimizer = None
    optimizer = 'fast_run'

    for f_linker, g_linker in [(theano.PerformLinker(allow_gc=True),
                                theano.PerformLinker(allow_gc=False)),
                               (theano.OpWiseCLinker(allow_gc=True),
                                theano.OpWiseCLinker(allow_gc=False))]:

        # f_linker has garbage collection

        # g_linker has no garbage collection

        f = theano.function([x],
                            r,
                            mode=theano.Mode(optimizer=optimizer,
                                             linker=f_linker))
        g = theano.function([x],
                            r,
                            mode=theano.Mode(optimizer=optimizer,
                                             linker=g_linker))

        len_pre_f = len(pickle.dumps(f))
        len_pre_g = len(pickle.dumps(g))

        # We can't compare the content or the length of the string
        # between f and g for two reasons: we store some timing
        # information as floats, which won't be the same on each run,
        # and different floats can have different lengths when printed.

        def a(fn):
            return len(pickle.dumps(fn.maker))

        assert a(f) == a(f)  # some sanity checks on the pickling mechanism
        assert a(g) == a(g)  # some sanity checks on the pickling mechanism

        def b(fn):
            return len(
                pickle.dumps(
                    theano.compile.function_module._pickle_Function(fn)))

        assert b(f) == b(f)  # some sanity checks on the pickling mechanism

        def c(fn):
            return len(pickle.dumps(fn))

        assert c(f) == c(f)  # some sanity checks on the pickling mechanism
        assert c(g) == c(g)  # some sanity checks on the pickling mechanism

        # now run the function once to create temporaries within the no-gc
        # linker
        f(numpy.ones(100, dtype='float64'))
        g(numpy.ones(100, dtype='float64'))

        # serialize the functions again
        post_f = pickle.dumps(f)
        post_g = pickle.dumps(g)
        len_post_f = len(post_f)
        len_post_g = len(post_g)

        # assert that f() didn't cause the function to grow
        # allow_gc should leave the function un-changed by calling
        assert len_pre_f == len_post_f

        # assert that g() didn't cause g to grow because temporaries
        # that weren't collected shouldn't be pickled anyway
        # Allow for a couple of bytes of difference, since timing info,
        # for instance, can be represented as text of varying size.
        assert abs(len_post_f - len_post_g) < 128, (f_linker, len_post_f,
                                                    len_post_g)
Example #37
 def b(fn):
     return len(
         pickle.dumps(
             theano.compile.function_module._pickle_Function(fn)))
Example #38
def zpickle(data):
    """Given any data structure, returns a zlib compressed pickled serialization."""
    return zlib.compress(pickle.dumps(data, 4))  # Keep this constant as we upgrade from python 2 to 3.
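An illustrative inverse; zunpickle is a hypothetical name:

def zunpickle(blob):
    """Given a zlib compressed pickled serialization, return the original data."""
    return pickle.loads(zlib.decompress(blob))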
Example #39
    def test_multiple_functions(self):
        a = T.scalar()  # the a is for 'anonymous' (un-named).
        x, s = T.scalars('xs')
        v = T.vector('v')

        # put in some inputs
        list_of_things = [s, x, v]

        # some derived thing, whose inputs aren't all in the list
        list_of_things.append(a * x + s)

        f1 = function([
            x,
            In(a, value=1.0, name='a'),
            In(s, value=0.0, update=s + a * x, mutable=True)
        ], s + a * x)
        list_of_things.append(f1)

        # now put in a function sharing container with the previous one
        f2 = function([
            x,
            In(a, value=1.0, name='a'),
            In(s, value=f1.container[s], update=s + a * x, mutable=True)
        ], s + a * x)
        list_of_things.append(f2)

        assert isinstance(f2.container[s].storage, list)
        assert f2.container[s].storage is f1.container[s].storage

        # now put in a function with non-scalar
        v_value = numpy.asarray([2, 3, 4.], dtype=config.floatX)
        f3 = function([x, In(v, value=v_value)], x + v)
        list_of_things.append(f3)

        # try to pickle the entire things
        try:
            saved_format = pickle.dumps(list_of_things, protocol=-1)
            new_list_of_things = pickle.loads(saved_format)
        except NotImplementedError as e:
            if e.args[0].startswith('DebugMode is not picklable'):
                return
            else:
                raise

        # now test our recovered new_list_of_things
        # it should be totally unrelated to the original
        # it should be interdependent in the same way as the original

        ol = list_of_things
        nl = new_list_of_things

        for i in range(4):
            assert nl[i] != ol[i]
            assert nl[i].type == ol[i].type
            assert nl[i].type is not ol[i].type

        # see if the implicit input got stored
        assert ol[3].owner.inputs[1] is s
        assert nl[3].owner.inputs[1] is not s
        assert nl[3].owner.inputs[1].type == s.type

        # moving on to the functions...
        for i in range(4, 7):
            assert nl[i] != ol[i]

        # looking at function number 1, input 's'
        assert nl[4][nl[0]] is not ol[4][ol[0]]
        assert nl[4][nl[0]] == ol[4][ol[0]]
        assert nl[4](3) == ol[4](3)

        # looking at function number 2, input 's'
        # make sure it's shared with the first function
        assert ol[4].container[ol[0]].storage is ol[5].container[ol[0]].storage
        assert nl[4].container[nl[0]].storage is nl[5].container[nl[0]].storage
        assert nl[5](3) == ol[5](3)
        assert nl[4].value[nl[0]] == 6

        assert numpy.all(nl[6][nl[2]] == numpy.asarray([2, 3., 4]))
Example #40
 def pickled(self):
     return cPickle.dumps(self)
Example #41
    def _build_index(self, descriptors):
        """
        Internal method to be implemented by sub-classes to build the index
        with the given descriptor data elements.

        Subsequent calls to this method should rebuild the current index.  This
        method shall not add to the existing index nor raise an exception to as
        to protect the current index.

        Implementation Notes:
            - We keep a cache file serialization around for our index in case
                sub-processing occurs so as to be able to recover from the
                underlying C data not being there. This could cause issues if
                a main or child process rebuild's the index, as we clear the
                old cache away.

        :param descriptors: Iterable of descriptor elements to build index
            over.
        :type descriptors:
            collections.Iterable[smqtk.representation.DescriptorElement]

        """
        with self._model_lock:
            # Not caring about restoring the index because we're just making a
            # new one.
            self._log.info("Building new FLANN index")

            self._log.debug("Caching descriptor elements")
            self._descr_cache = list(descriptors)
            # Cache descriptors if we have an element
            if self._descr_cache_elem and self._descr_cache_elem.writable():
                self._log.debug("Caching descriptors: %s",
                                self._descr_cache_elem)
                self._descr_cache_elem.set_bytes(
                    cPickle.dumps(self._descr_cache, -1))

            params = {
                "target_precision":
                self._build_target_precision,
                "sample_fraction":
                self._build_sample_frac,
                "log_level":
                ("info" if self._log.getEffectiveLevel() <= logging.DEBUG else
                 "warning")
            }
            if self._build_autotune:
                params['algorithm'] = "autotuned"
            if self._rand_seed is not None:
                params['random_seed'] = self._rand_seed
            pyflann.set_distance_type(self._distance_method)

            self._log.debug("Accumulating descriptor vectors into matrix for "
                            "FLANN")
            pts_array = elements_to_matrix(self._descr_cache,
                                           report_interval=1.0)

            self._log.debug('Building FLANN index')
            self._flann = pyflann.FLANN()
            self._flann_build_params = self._flann.build_index(
                pts_array, **params)
            del pts_array

            if self._index_elem and self._index_elem.writable():
                self._log.debug("Caching index: %s", self._index_elem)
                # FLANN wants to write to a file, so make a temp file, then
                # read it in, putting bytes into element.
                fd, fp = tempfile.mkstemp()
                try:
                    self._flann.save_index(fp)
                    # Use the file descriptor to create the file object.
                    # This avoids reopening the file and will automatically
                    # close the file descriptor on exiting the with block.
                    # fdopen() is required because in Python 2 open() does
                    # not accept a file descriptor.
                    with os.fdopen(fd, 'rb') as f:
                        self._index_elem.set_bytes(f.read())
                finally:
                    os.remove(fp)
            if self._index_param_elem and self._index_param_elem.writable():
                self._log.debug("Caching index params: %s",
                                self._index_param_elem)
                state = {
                    'b_autotune': self._build_autotune,
                    'b_target_precision': self._build_target_precision,
                    'b_sample_frac': self._build_sample_frac,
                    'distance_method': self._distance_method,
                    'flann_build_params': self._flann_build_params,
                }
                self._index_param_elem.set_bytes(cPickle.dumps(state, -1))

            self._pid = multiprocessing.current_process().pid
Example #42
def test_snk_pickle(fname):
    sn = csv_dist_sink(fname)
    ck = pickle.dumps(sn)
    pickle.loads(ck)
Example #43
def test_src_pickle(fname):
    sr = csv_dist_source(fname)
    ck = pickle.dumps(sr)
    pickle.loads(ck)
Example #44
 def testSerialization(self):
   desc = specs.DiscreteArray(2, np.int32, "test")
   self.assertEqual(pickle.loads(pickle.dumps(desc)), desc)
Example #45
 def c(fn):
     return len(pickle.dumps(fn))
Example #46
    def test_seeds_AAB(self):
        # launch 3 simultaneous experiments with seeds A, A, B.
        # Verify all experiments run to completion.
        # Verify first two experiments run identically.
        # Verify third experiment runs differently.

        exp_keys = ['A0', 'A1', 'B']
        seeds = [1, 1, 2]
        n_workers = 2
        jobs_per_thread = 6
        # -- total jobs = 2 * 6 = 12
        # -- divided by 3 experiments: 4 jobs per fmin
        max_evals = (n_workers * jobs_per_thread) // len(exp_keys)

        # -- should not matter which domain is used here
        domain = gauss_wave2()

        pickle.dumps(domain.expr)
        pickle.dumps(passthrough)

        worker_threads = [
            threading.Thread(target=TestExperimentWithThreads.worker_thread_fn,
                             args=(('hostname', ii), jobs_per_thread, 30.0))
            for ii in range(n_workers)
        ]

        with TempMongo() as tm:
            mj = tm.mongo_jobs('foodb')
            print(mj)
            trials_list = [
                MongoTrials(tm.connection_string('foodb'), key)
                for key in exp_keys
            ]

            fmin_threads = [
                threading.Thread(
                    target=TestExperimentWithThreads.fmin_thread_fn,
                    args=(domain.expr, trials, max_evals, seed))
                for seed, trials in zip(seeds, trials_list)
            ]

            try:
                [th.start() for th in worker_threads + fmin_threads]
            finally:
                print('joining worker threads...')
                [th.join() for th in worker_threads + fmin_threads]

            # -- not using an exp_key gives a handle to all the trials
            #    in foodb
            all_trials = MongoTrials(tm.connection_string('foodb'))
            self.assertEqual(len(all_trials), n_workers * jobs_per_thread)

            # Verify that the fmin calls terminated correctly:
            for trials in trials_list:
                self.assertEqual(trials.count_by_state_synced(JOB_STATE_DONE),
                                 max_evals)
                self.assertEqual(
                    trials.count_by_state_unsynced(JOB_STATE_DONE), max_evals)
                self.assertEqual(len(trials), max_evals)

            # Verify that the first two experiments match.
            # (Do these need sorting by trial id?)
            trials_A0, trials_A1, trials_B0 = trials_list
            self.assertEqual([t['misc']['vals'] for t in trials_A0.trials],
                             [t['misc']['vals'] for t in trials_A1.trials])

            # Verify that the last experiment does not match.
            # (Do these need sorting by trial id?)
            self.assertNotEqual([t['misc']['vals'] for t in trials_A0.trials],
                                [t['misc']['vals'] for t in trials_B0.trials])
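
The comments at the top of this test state the seeding contract: equal seeds must reproduce identical trials and a different seed must not. A hedged single-process sketch of the same contract with a plain fmin call follows; the objective and search space are placeholders, and the rstate form depends on the hyperopt version.

import numpy as np
from hyperopt import Trials, fmin, hp, tpe


def _run(seed, max_evals=4):
    trials = Trials()
    # NOTE: newer hyperopt releases expect rstate=np.random.default_rng(seed)
    # instead of a RandomState instance.
    fmin(lambda x: x ** 2, hp.uniform('x', -1.0, 1.0), algo=tpe.suggest,
         max_evals=max_evals, trials=trials,
         rstate=np.random.RandomState(seed))
    return [t['misc']['vals'] for t in trials.trials]


assert _run(1) == _run(1)   # same seed: identical suggestions
assert _run(1) != _run(2)   # different seed: (almost surely) different ones
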
Example #47
0
 def a(fn):
     return len(pickle.dumps(fn.maker))
Example #48
0
 def object_to_param_str(change):
     """Convert a change object into a format suitable for passing in job
     parameters
     """
     return b64encode(compress(cPickle.dumps(change))).decode('utf8')
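
The example only shows the encoding direction. A hedged sketch of the inverse follows, assuming the compress in the original is zlib.compress; the function name is hypothetical.

from base64 import b64decode
from zlib import decompress

from six.moves import cPickle


# Assumes the compress()/b64encode() pair above is zlib + base64.
def param_str_to_object(param_str):
    # Reverse of object_to_param_str: undo base64, decompress, then unpickle.
    return cPickle.loads(decompress(b64decode(param_str.encode('utf8'))))
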
Example #49
0
 def serialise_tracker(tracker):
     dialogue = tracker.as_dialogue()
     return pickler.dumps(dialogue)
Example #50
0
 def write(self, path):
     output_file = os.path.join(path, 'metadata')
     with atomic_file(output_file) as f:
         f.write(pickle.dumps(self, -1))
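
The write side above pickles the whole object into a 'metadata' file through atomic_file. A hedged read counterpart follows, assuming the same file name; plain open() suffices here since reading needs no atomicity, and the function name is illustrative.

import os
import pickle


# Assumes the same 'metadata' file name used by write() above.
def read(path):
    # Load the object pickled by write() back from the metadata file.
    with open(os.path.join(path, 'metadata'), 'rb') as f:
        return pickle.load(f)
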
Example #51
0
 def dumps(self):
     return cPickle.dumps(self)
Example #52
0
 def test_data_stream_pickling(self):
     stream = DataStream(H5PYDataset(self.h5file, which_sets=('train', )),
                         iteration_scheme=SequentialScheme(100, 10))
     cPickle.loads(cPickle.dumps(stream))
     stream.close()
Example #53
0
 def add_numpy(d):
     metadata = {'dtype': d.dtype, 'shape': d.shape, 'datatype': 'numpy'}
     metadata = pickle.dumps(metadata)
     output.append(metadata)
     output.append(d)
Example #54
0
    def startup(self, recording_requester):
        """
        Prepare for a new run and create/update the abs2prom and prom2abs variables.

        Parameters
        ----------
        recording_requester : object
            Object to which this recorder is attached.
        """
        super(SqliteRecorder, self).startup(recording_requester)

        if not self._database_initialized:
            self._initialize_database()

        # grab the system
        if isinstance(recording_requester, Driver):
            system = recording_requester._problem.model
        elif isinstance(recording_requester, System):
            system = recording_requester
        elif isinstance(recording_requester, Problem):
            system = recording_requester.model
        else:
            system = recording_requester._system

        # grab all of the units and type (collective calls)
        states = system._list_states_allprocs()
        desvars = system.get_design_vars(True)
        responses = system.get_responses(True)
        objectives = system.get_objectives(True)
        constraints = system.get_constraints(True)
        inputs = system._var_allprocs_abs_names['input']
        outputs = system._var_allprocs_abs_names['output']
        full_var_set = [(inputs, 'input'), (outputs, 'output'),
                        (desvars, 'desvar'), (responses, 'response'),
                        (objectives, 'objective'), (constraints, 'constraint')]

        if self.connection:
            # merge current abs2prom and prom2abs with this system's version
            for io in ['input', 'output']:
                for v in system._var_abs2prom[io]:
                    self._abs2prom[io][v] = system._var_abs2prom[io][v]
                for v in system._var_allprocs_prom2abs_list[io]:
                    if v not in self._prom2abs[io]:
                        self._prom2abs[io][v] = (
                            system._var_allprocs_prom2abs_list[io][v])
                    else:
                        self._prom2abs[io][v] = list(
                            set(self._prom2abs[io][v])
                            | set(system._var_allprocs_prom2abs_list[io][v]))

            for var_set, var_type in full_var_set:
                for name in var_set:
                    if name not in self._abs2meta:
                        self._abs2meta[name] = (
                            system._var_allprocs_abs2meta[name].copy())
                        self._abs2meta[name]['type'] = set()
                        if name in states:
                            self._abs2meta[name]['explicit'] = False

                    if var_type not in self._abs2meta[name]['type']:
                        self._abs2meta[name]['type'].add(var_type)
                    self._abs2meta[name]['explicit'] = True

            for name in inputs:
                self._abs2meta[name] = (
                    system._var_allprocs_abs2meta[name].copy())
                self._abs2meta[name]['type'] = set()
                self._abs2meta[name]['type'].add('input')
                self._abs2meta[name]['explicit'] = True
                if name in states:
                    self._abs2meta[name]['explicit'] = False

            # store the updated abs2prom and prom2abs
            abs2prom = pickle.dumps(self._abs2prom)
            prom2abs = pickle.dumps(self._prom2abs)
            abs2meta = pickle.dumps(self._abs2meta)

            with self.connection as c:
                c.execute(
                    "UPDATE metadata SET abs2prom=?, prom2abs=?, abs2meta=?",
                    (abs2prom, prom2abs, abs2meta))
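
The startup method above stores the merged mappings as pickled blobs in the metadata table. A hedged sketch of reading them back follows, assuming a plain sqlite3 connection to the same database file and the column names used in the UPDATE statement; the helper name is illustrative.

import pickle
import sqlite3


# Assumes the metadata table and column names from the UPDATE above.
def load_recorder_metadata(db_filename):
    con = sqlite3.connect(db_filename)
    try:
        row = con.execute(
            "SELECT abs2prom, prom2abs, abs2meta FROM metadata").fetchone()
    finally:
        con.close()
    # Unpickle each blob back into its dict form.
    return tuple(pickle.loads(blob) for blob in row)
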
Example #55
0
 def update_parameters(self, delta: list):
     request = urllib2.Request('http://{}/update'.format(self.master_url),
                               pickle.dumps(delta, -1),
                               headers=self.headers)
     return urllib2.urlopen(request).read()
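
update_parameters above POSTs a pickled delta to the master's /update endpoint. A hedged sketch of a matching receive side follows; Flask is an illustrative choice that the original example does not specify, and apply_delta is a hypothetical hook for merging the delta into the master's parameters. Unpickling request bodies is only safe between trusted hosts.

import pickle

from flask import Flask, request

app = Flask(__name__)


@app.route('/update', methods=['POST'])
def update():
    # Recover the parameter delta pickled by the client above.
    delta = pickle.loads(request.data)
    apply_delta(delta)  # hypothetical: merge delta into the master weights
    return 'ok'
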
Example #56
0
 def test_pickle_cpu(self):
     fs2_serialized = pickle.dumps(self.fs2)
     fs2_loaded = pickle.loads(fs2_serialized)
     self.assertTrue((self.fs2.b.p == fs2_loaded.b.p).all())
     self.assertTrue((self.fs2.fs1.a.p == fs2_loaded.fs1.a.p).all())
Example #57
0
def serialize(obj):
    """Serialize a Python object using pickle and encode it as an array of
    float32 values so that it can be fed into the workspace. See deserialize().
    """
    return np.fromstring(pickle.dumps(obj), dtype=np.uint8).astype(np.float32)
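
The docstring above points at a deserialize() counterpart. A minimal sketch of it, under the assumption that the float32 values round-trip exactly back to the original uint8 bytes (they do, since each value started as a byte in 0..255):

import pickle

import numpy as np


def deserialize(arr):
    # Cast back to uint8, recover the raw byte string, and unpickle.
    return pickle.loads(arr.astype(np.uint8).tobytes())
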
Example #58
0
    def make_db(self, data_name, output_dir='data/train', train_ratio=0.8):
        if data_name == 'train':
            div = 'train'
            data_path_list = opt.train_data_list
        elif data_name == 'dev':
            div = 'dev'
            data_path_list = opt.dev_data_list
        elif data_name == 'test':
            div = 'test'
            data_path_list = opt.test_data_list
        else:
            assert False, '%s is not a valid data name' % data_name

        all_train = train_ratio >= 1.0
        all_dev = train_ratio == 0.0

        np.random.seed(17)
        self.logger.info('make database from data(%s) with train_ratio(%s)' %
                         (data_name, train_ratio))

        self.load_y_vocab()
        num_input_chunks = self._preprocessing(Data,
                                               data_path_list,
                                               div,
                                               chunk_size=opt.chunk_size)
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)

        data_fout = h5py.File(os.path.join(output_dir, 'data.h5py'), 'w')
        meta_fout = open(os.path.join(output_dir, 'meta'), 'wb')

        reader = Reader(data_path_list, div, None, None)
        tmp_size = reader.get_size()
        train_indices, train_size = self.get_train_indices(
            tmp_size, train_ratio)

        dev_size = tmp_size - train_size
        if all_dev:
            train_size = 1
            dev_size = tmp_size
        if all_train:
            dev_size = 1
            train_size = tmp_size

        train = data_fout.create_group('train')
        dev = data_fout.create_group('dev')
        self.create_dataset(train, train_size, len(self.y_vocab))
        self.create_dataset(dev, dev_size, len(self.y_vocab))
        self.logger.info('train_size ~ %s, dev_size ~ %s' %
                         (train_size, dev_size))

        sample_idx = 0
        dataset = {'train': train, 'dev': dev}
        num_samples = {'train': 0, 'dev': 0}
        chunk_size = opt.db_chunk_size
        chunk = {
            'train': self.init_chunk(chunk_size, len(self.y_vocab)),
            'dev': self.init_chunk(chunk_size, len(self.y_vocab))
        }
        chunk_order = list(range(num_input_chunks))
        np.random.shuffle(chunk_order)
        for input_chunk_idx in chunk_order:
            path = os.path.join(self.tmp_chunk_tpl % input_chunk_idx)
            self.logger.info('processing %s ...' % path)
            data = list(enumerate(cPickle.loads(open(path, 'rb').read())))
            np.random.shuffle(data)
            for data_idx, (pid, y, vw) in data:
                if y is None:
                    continue
                v, w = vw
                is_train = train_indices[sample_idx + data_idx]
                if all_dev:
                    is_train = False
                if all_train:
                    is_train = True
                if v is None:
                    continue
                c = chunk['train'] if is_train else chunk['dev']
                idx = c['num']
                c['uni'][idx] = v
                c['w_uni'][idx] = w
                c['cate'][idx] = y
                c['num'] += 1
                if not is_train:
                    c['pid'].append(np.string_(pid))
                for t in ['train', 'dev']:
                    if chunk[t]['num'] >= chunk_size:
                        self.copy_chunk(dataset[t],
                                        chunk[t],
                                        num_samples[t],
                                        with_pid_field=t == 'dev')
                        num_samples[t] += chunk[t]['num']
                        chunk[t] = self.init_chunk(chunk_size,
                                                   len(self.y_vocab))
            sample_idx += len(data)
        for t in ['train', 'dev']:
            if chunk[t]['num'] > 0:
                self.copy_chunk(dataset[t],
                                chunk[t],
                                num_samples[t],
                                with_pid_field=t == 'dev')
                num_samples[t] += chunk[t]['num']

        for div in ['train', 'dev']:
            ds = dataset[div]
            size = num_samples[div]
            shape = (size, opt.max_len)
            ds['uni'].resize(shape)
            ds['w_uni'].resize(shape)
            ds['cate'].resize((size, len(self.y_vocab)))

        data_fout.close()
        meta = {'y_vocab': self.y_vocab}
        meta_fout.write(cPickle.dumps(meta, 2))
        meta_fout.close()

        self.logger.info('# of classes: %s' % len(meta['y_vocab']))
        self.logger.info('# of samples on train: %s' % num_samples['train'])
        self.logger.info('# of samples on dev: %s' % num_samples['dev'])
        self.logger.info('data: %s' % os.path.join(output_dir, 'data.h5py'))
        self.logger.info('meta: %s' % os.path.join(output_dir, 'meta'))
Example #59
0
 def test_pickle_cpu(self):
     s = pickle.dumps(self.fs)
     fs2 = pickle.loads(s)
     self.check_equal_fs(self.fs, fs2)
Example #60
0
 def test_unpack_picklable(self):
     wrapper = Unpack(self.stream_np)
     epoch = wrapper.get_epoch_iterator()
     cPickle.dumps(epoch)