Ejemplo n.º 1
0
    def test_transaction_rollback(self):
        """Check that raising inside ``with conn.transaction`` rolls the transaction back.

        One tuple is inserted in a transaction that completes normally. A second
        tuple is inserted in a transaction that raises DataJointError. Afterwards
        exactly one tuple must remain, and none with ``subject_id = 2``.
        """
        rows = np.array(
            [(1, "Peter", "mouse"), (2, "Klara", "monkey")],
            self.relation.heading.as_dtype,
        )
        kept, discarded = rows[0], rows[1]

        self.relation.delete()

        with self.conn.transaction:
            self.relation.insert1(kept)

        try:
            with self.conn.transaction:
                self.relation.insert1(discarded)
                raise DataJointError("Testing rollback")
        except DataJointError:
            pass  # expected: the raise is what triggers the rollback

        remaining = len(self.relation)
        assert_equal(
            remaining, 1,
            "Length is not 1. Expected because rollback should have happened.",
        )
        leaked = len(self.relation & "subject_id = 2")
        assert_equal(
            leaked, 0,
            "Length is not 0. Expected because rollback should have happened.",
        )
Ejemplo n.º 2
0
    def __getitem__(self, item):
        """
        Fetch the requested attributes of the single tuple as separate outputs.

        ``datajoint.key`` (PRIMARY_KEY) is a special value requesting the entire
        primary key. It is returned as an array view of the fetched result
        restricted to the primary-key fields. Ordinary attributes are returned
        as the value from the single fetched row.

        :return: one value when item is a single string, integer or PRIMARY_KEY;
            otherwise a tuple with one entry per requested element.
        :raise DataJointError: if the projection does not yield exactly one tuple.

        Examples:

        >>> a, b = relation['a', 'b']
        >>> a, b, key = relation['a', 'b', datajoint.key]
        >>> results = relation['a':'z']    # return attributes a-z as a tuple
        >>> results = relation[:-1]   # return all but the last attribute

        """
        want_single = (isinstance(item, str) or item is PRIMARY_KEY
                       or isinstance(item, int))
        item, attributes = prepare_attributes(self._relation, item)

        result = self._relation.project(*attributes).fetch()
        if len(result) != 1:
            raise DataJointError('Fetch1 should only return one tuple')

        def extract(attribute):
            # Plain attribute: take its value from the only row.
            if attribute is not PRIMARY_KEY:
                return result[attribute][0]
            # PRIMARY_KEY: reinterpret the result's buffer using only the
            # primary-key fields (their original dtypes and offsets).
            key_dtype = np.dtype({
                name: result.dtype.fields[name]
                for name in self._relation.primary_key
            })
            return np.ndarray(result.shape, key_dtype, result, 0, result.strides)

        values = tuple(map(extract, item))
        if want_single:
            return values[0]
        return values
Ejemplo n.º 3
0
def prepare_attributes(relation, item):
    """
    Normalize the item passed to fetch.__getitem__ (including slices).

    :param relation: the relation that created the fetch object
    :param item: the item passed to __getitem__. Can be a string, PRIMARY_KEY,
        an integer attribute position, a tuple, a list, or a slice. Slice bounds
        may be attribute names or integer positions. Name bounds follow Python's
        half-open convention, so the stop attribute is excluded.
    :return: a pair ``(item, attributes)``. ``item`` is the normalized sequence
        of requested entries and may still contain PRIMARY_KEY. ``attributes``
        is a tuple of the attribute names in ``item`` with PRIMARY_KEY removed.
    :raise DataJointError: if item is not one of the types above (not iterable)
    """
    if isinstance(item, str) or item is PRIMARY_KEY:
        item = (item, )
    elif isinstance(item, int):
        item = (relation.heading.names[item], )
    elif isinstance(item, slice):
        names = relation.heading.names
        # Attribute-name bounds are translated to their positions in the heading.
        start = names.index(item.start) if isinstance(item.start, str) else item.start
        stop = names.index(item.stop) if isinstance(item.stop, str) else item.stop
        item = names[slice(start, stop, item.step)]
    try:
        attributes = tuple(i for i in item if i is not PRIMARY_KEY)
    except TypeError:
        # item is not iterable, so it is none of the supported index types;
        # the TypeError itself is an implementation detail, so hide it.
        raise DataJointError(
            "Index must be a slice, a tuple, a list, a string, or an integer.") from None
    return item, attributes
Ejemplo n.º 4
0
    def __call__(self):
        """
        Return the relation's only tuple as an ordered mapping.

        Meant to be called when the relation is expected to hold exactly one tuple.

        :return: OrderedDict of attribute name -> value in heading order. Values
            of blob attributes are passed through ``unpack``.
        :raise DataJointError: if the first fetched row is missing or empty, or if
            a second row exists.
        """
        heading = self._relation.heading

        cursor = self._relation.cursor(as_dict=True)
        first = cursor.fetchone()
        # Probing for a second row detects relations with more than one tuple.
        if not first or cursor.fetchone():
            raise DataJointError(
                'fetch1 should only be used for relations with exactly one tuple'
            )

        record = OrderedDict()
        for name in heading.names:
            record[name] = unpack(first[name]) if heading[name].is_blob else first[name]
        return record
Ejemplo n.º 5
0
def from_camel_case(s):
    """
    Convert a CamelCase name into an underscore (_) separated lower-case name.

    Every capital letter is lower-cased. Each capital after the first is also
    prefixed with an underscore.

    :param s: string in CamelCase notation
    :returns: string in under_score notation
    :raises DataJointError: if ``s`` is not entirely alphanumeric or does not
        begin with a capital letter

    Example:

    >>> from_camel_case("TableName")
    'table_name'
    """

    def convert(match):
        # Group 1 (\B[A-Z]) is a capital inside the word, so it gets a '_' prefix.
        return ('_' if match.groups()[0] else '') + match.group(0).lower()

    # Anchor with \Z so the entire string is validated, not only its prefix
    # (re.match alone accepted names like "Table_Name" or "Table-Name!").
    if not re.match(r'[A-Z][a-zA-Z0-9]*\Z', s):
        raise DataJointError(
            'ClassName must be alphanumeric in CamelCase, begin with a capital letter')
    return re.sub(r'(\B[A-Z])|(\b[A-Z])', convert, s)
Ejemplo n.º 6
0
        def log_key(self, key):
            """
            Annotate ``key`` with git information about the code of ``cls`` and insert it.

            The method starts in the directory containing the source file of
            ``cls``. It walks up the directory tree until it finds an entry named
            ``.git`` and opens that repository. ``key`` is then updated in place
            with these fields:

            * ``sha1``: the HEAD commit's SHA-1.
            * ``branch``: the name reported by ``name_rev``.
            * ``modified``: a 0/1 flag.
            * ``head_date``: the commit's authored date.

            Finally the key is inserted with ``self.insert1(key, skip_duplicates=True)``.

            :param key: mapping to annotate; it is modified in place.
            :raise DataJointError: if no directory above the source file contains ``.git``.
            :raise PedanticError: if modifications are detected and either
                ``cls._raise_error_on_modified`` or ``_FAIL_ON_ERROR`` is truthy.
            """
            # Walk up with os.path rather than splitting on '/', so the search also
            # works where the path separator is not '/' (e.g. Windows).
            directory = os.path.dirname(inspect.getabsfile(cls))
            while not os.path.exists(os.path.join(directory, '.git')):
                parent = os.path.dirname(directory)
                if parent == directory:  # filesystem root already checked
                    raise DataJointError("%s.GitKey could not find a .git directory for %s" % (cls.__name__, cls.__name__))
                directory = parent
            repo = git.Repo(directory)

            sha1, branch = repo.head.commit.name_rev.split()
            # Presumably detects "modified:" entries in `git status` output.
            # A substring test is used because find(...) > 0 would miss index 0.
            modified = int("modified" in repo.git.status())
            if modified:
                if not cls._raise_error_on_modified and not _FAIL_ON_ERROR:
                    warnings.warn(
                        'You have uncommited changes. Consider committing the changes before running populate.')
                else:
                    raise PedanticError('You have uncommited changes. Commit changes before running populate!')
            key['sha1'] = sha1
            key['branch'] = branch
            key['modified'] = modified
            key['head_date'] = datetime.datetime.fromtimestamp(repo.head.commit.authored_date)
            self.insert1(key, skip_duplicates=True)