Example #1
def test_class_sigs():
    """Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(
        lambda x: not isinstance(x, BuiltinFunctionType), cytoolz_dict
    )

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ["excepts", "juxt", "memoize", "flip"]:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        # For Cython < 0.25
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        doc = cytoolz_func.__doc__
        # For Cython >= 0.25
        toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec(
            *toolz_spec,
            **{"formatvalue": lambda x: "=" + getattr(x, "__name__", repr(x))}
        )
        doc_alt = doc.replace("Py_ssize_t ", "")
        if not (toolz_sig in doc or toolz_sig_alt in doc_alt):
            message = (
                "cytoolz.%s does not have correct function signature."
                "\n\nExpected: %s"
                "\n\nDocstring in cytoolz is:\n%s"
                % (key, toolz_sig, cytoolz_func.__doc__)
            )
            assert False, message
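The pattern above (repeated in Examples #5, #13, and #17 below) intersects two module namespaces: `valfilter` selects by value, `keyfilter` restricts each dict to the other's keys, and `merge_with(identity, ...)` pairs up the two objects that share a name. A minimal sketch of those building blocks on made-up dicts:

from toolz import identity, keyfilter, merge_with, valfilter

a = {"map": 1, "pluck": 2, "only_in_a": 3}
b = {"map": 10, "pluck": 20, "only_in_b": 30}

# restrict each dict to the keys the other one also has
a = keyfilter(lambda k: k in b, a)
b = keyfilter(lambda k: k in a, b)

# merge_with collects the values for each shared key into a list;
# identity leaves that list as-is, giving (toolz, cytoolz)-style pairs
paired = merge_with(identity, a, b)
assert paired == {"map": [1, 10], "pluck": [2, 20]}

# valfilter is the value-side analogue of keyfilter
assert valfilter(lambda v: v[0] < 2, paired) == {"map": [1, 10]}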
Example #2
def test_sig_at_beginning():
    """Test that the function signature is at the beginning of the docstring
    and is followed by exactly one blank line.
    """
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)
    cytoolz_dict = keyfilter(lambda x: x not in skip_sigs, cytoolz_dict)

    for key, val in cytoolz_dict.items():
        doclines = val.__doc__.splitlines()
        assert len(doclines) > 2, "cytoolz.%s docstring too short:\n\n%s" % (
            key,
            val.__doc__,
        )

        sig = "%s(" % aliases.get(key, key)
        assert sig in doclines[0], (
            "cytoolz.%s docstring missing signature at beginning:\n\n%s"
            % (key, val.__doc__)
        )

        assert not doclines[1], (
            "cytoolz.%s docstring missing blank line after signature:\n\n%s"
            % (key, val.__doc__)
        )

        assert doclines[2], (
            "cytoolz.%s docstring too many blank lines after signature:\n\n%s"
            % (key, val.__doc__)
        )
Example #3
    def wordclouds(self, df, stage=0, features="frequency", name=None, stopwords=None):
        """
        Build wordclouds and choose what features to use
        :param df:
        :param stage:
        :param features:
        :param name:
        :param stopwords:
        :return:
        """
        # If only using frequency, use a pure Python method
        if features == "frequency":
            vocab = self._get_wordcloud_frequency_vocab(df, stage)
        # If using TF-IDF, use pre-fit vectorizer
        else:
            # features == "tfidf":
            vocab = self._get_wordcloud_tfidf_vocab(df)

        # Remove defined stopwords
        if stopwords is not None:
            vocab = ct.keyfilter(lambda x: x not in stopwords, vocab)

        # Pass pre-made frequencies to wordcloud, allowing for TF-IDF
        self.wordcloud.generate_from_frequencies(frequencies=vocab)

        # Prepare plot with title, etc
        self._plot_wordcloud(stage, name)
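The stopword step above is a blacklist filter: `keyfilter` keeps exactly the keys for which the predicate is truthy. A small sketch with a made-up vocabulary (the real `vocab` comes from the frequency/TF-IDF helpers):

import toolz as ct  # the snippet above presumably aliases toolz or cytoolz as `ct`

vocab = {"the": 120, "data": 42, "cloud": 17, "a": 99}
stopwords = {"the", "a"}

# keep every term that is NOT a stopword
filtered = ct.keyfilter(lambda w: w not in stopwords, vocab)
assert filtered == {"data": 42, "cloud": 17}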
Example #4
File: genesig.py Project: verohan/pySCENIC
    def head(self, n: int = 5) -> Type['GeneSignature']:
        """
        Returns a gene signature with only the top n targets.
        """
        assert n >= 1, "n must be greater than or equal to one."
        genes = self.genes[0:n]  # Genes are sorted in ascending order according to weight.
        return self.copy(gene2weight=keyfilter(lambda k: k in genes, self.gene2weight))
Example #5
def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
        correctly embedded the function signature as done in ``toolz``.
    """
    import toolz
    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ['excepts', 'juxt', 'memoize', 'flip']:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        # For Cython < 0.25
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        doc = cytoolz_func.__doc__
        # For Cython >= 0.25
        toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec(
            *toolz_spec,
            **{'formatvalue': lambda x: '=' + getattr(x, '__name__', repr(x))}
        )
        doc_alt = doc.replace('Py_ssize_t ', '')
        if not (toolz_sig in doc or toolz_sig_alt in doc_alt):
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message
Example #6
def clean_dialect(dialect):
    """ Make a csv dialect apprpriate for pandas.read_csv """
    dialect = keyfilter(read_csv_kwargs.__contains__, dialect)
    # handle windows
    if dialect['lineterminator'] == '\r\n':
        dialect['lineterminator'] = None

    return dialect
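Note the predicate here: a bound `__contains__` method. `read_csv_kwargs.__contains__` is equivalent to `lambda k: k in read_csv_kwargs`, so `keyfilter` acts as a whitelist. A sketch with a hypothetical stand-in for blaze's `read_csv_kwargs`:

from toolz import keyfilter

# hypothetical stand-in for blaze's read_csv_kwargs whitelist
read_csv_kwargs = {"sep", "lineterminator", "quotechar"}

dialect = {"sep": ",", "lineterminator": "\r\n", "doublequote": True}

# a bound __contains__ behaves like `lambda k: k in read_csv_kwargs`
cleaned = keyfilter(read_csv_kwargs.__contains__, dialect)
assert cleaned == {"sep": ",", "lineterminator": "\r\n"}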
Example #7
File: csv.py Project: Casolt/blaze
def clean_dialect(dialect):
    """ Make a csv dialect apprpriate for pandas.read_csv """
    dialect = keyfilter(read_csv_kwargs.__contains__,
            dialect)
    # handle windows
    if dialect['lineterminator'] == '\r\n':
        dialect['lineterminator'] = None

    return dialect
Example #8
def test_docstrings_uptodate():
    import toolz

    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, "__doc__", ""), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(
                differ.compare(toolz_doc.splitlines(),
                               cytoolz_doc.splitlines()))
            fulldiff = list(diff)
            # remove additional lines at the beginning
            while diff and diff[0].startswith("+"):
                diff.pop(0)
            # remove additional lines at the end
            while diff and diff[-1].startswith("+"):
                diff.pop()

            def checkbad(line):
                return line.startswith("+") and not (
                    "# doctest: +SKIP" in line and key in skipped_doctests)

            if any(map(checkbad, diff)):
                assert False, "Error: cytoolz.%s has a bad docstring:\n%s\n" % (
                    key,
                    "\n".join(fulldiff),
                )
Example #9
File: strtools.py Project: carnby/aguaite
def clean_id_strs(tweet):
    if type(tweet) == dict:
        tweet = keyfilter(lambda x: not x.endswith('id_str'), tweet)

        for key in tweet.keys():
            tweet[key] = clean_id_strs(tweet[key])

    elif type(tweet) == list:
        tweet = map(clean_id_strs, tweet)

    return tweet
Example #10
def test_curried_namespace():
    def should_curry(func):
        if not callable(func) or isinstance(func, curry):
            return False
        nargs = enhanced_num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        else:
            return nargs == 1 and enhanced_has_keywords(func)

    def curry_namespace(ns):
        return dict((
            name,
            curry(f) if should_curry(f) else f,
        ) for name, f in ns.items() if '__' not in name)

    all_auto_curried = curry_namespace(vars(eth_utils))

    inferred_namespace = valfilter(callable, all_auto_curried)
    curried_namespace = valfilter(callable, eth_utils.curried.__dict__)

    if inferred_namespace != curried_namespace:
        missing = set(inferred_namespace) - set(curried_namespace)
        if missing:
            to_insert = sorted("%s," % f for f in missing)
            raise AssertionError(
                'There are missing functions in eth_utils.curried:\n' +
                '\n'.join(to_insert))
        extra = set(curried_namespace) - set(inferred_namespace)
        if extra:
            raise AssertionError(
                'There are extra functions in eth_utils.curried:\n' +
                '\n'.join(sorted(extra)))
        unequal = merge_with(list, inferred_namespace, curried_namespace)
        unequal = valfilter(lambda x: x[0] != x[1], unequal)
        to_curry = keyfilter(lambda x: should_curry(getattr(eth_utils, x)),
                             unequal)
        if to_curry:
            to_curry_formatted = sorted('{0} = curry({0})'.format(f)
                                        for f in to_curry)
            raise AssertionError(
                'There are missing functions to curry in eth_utils.curried:\n'
                + '\n'.join(to_curry_formatted))
        elif unequal:
            not_to_curry_formatted = sorted(unequal)
            raise AssertionError(
                'Missing functions NOT to curry in eth_utils.curried:\n' +
                '\n'.join(not_to_curry_formatted))
        else:
            raise AssertionError("unexplained difference between %r and %r" % (
                inferred_namespace,
                curried_namespace,
            ))
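This test mirrors the idea behind `toolz.curried`: a namespace where every eligible function is pre-wrapped in `curry` so it can be applied one argument at a time. A minimal sketch of what currying buys you with `keyfilter`:

from toolz import curry, keyfilter
from toolz.curried import keyfilter as keyfilter_c  # pre-curried variant

# curry lets the two-argument keyfilter be applied in stages
pick_ab = curry(keyfilter)(lambda k: k in {"a", "b"})
assert pick_ab({"a": 1, "b": 2, "c": 3}) == {"a": 1, "b": 2}

# toolz.curried ships the same functions already curried
assert keyfilter_c(lambda k: k in {"a", "b"})({"a": 1, "c": 3}) == {"a": 1}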
Example #11
def test_docstrings_uptodate():
    import toolz
    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, '__doc__', ''), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(differ.compare(toolz_doc.splitlines(),
                                       cytoolz_doc.splitlines()))
            fulldiff = list(diff)
            # remove additional lines at the beginning
            while diff and diff[0].startswith('+'):
                diff.pop(0)
            # remove additional lines at the end
            while diff and diff[-1].startswith('+'):
                diff.pop()

            def checkbad(line):
                return (line.startswith('+') and
                        not ('# doctest: +SKIP' in line and
                             key in skipped_doctests))

            if any(map(checkbad, diff)):
                assert False, 'Error: cytoolz.%s has a bad docstring:\n%s\n' % (
                    key, '\n'.join(fulldiff))
Example #12
File: csv.py Project: ChrisBeaumont/blaze
    def _get_reader(self, header=None, keep_default_na=False,
                    na_values=na_values, chunksize=None, **kwargs):
        kwargs.setdefault('skiprows', int(bool(self.header)))

        dialect = merge(keyfilter(read_csv_kwargs.__contains__, self.dialect),
                        kwargs)

        # handle windows
        if dialect['lineterminator'] == '\r\n':
            dialect['lineterminator'] = None
        return partial(pd.read_csv, chunksize=chunksize, na_values=na_values,
                       keep_default_na=keep_default_na, encoding=self.encoding,
                       header=header, **dialect)
Example #13
def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
        correctly embedded the function signature as done in ``toolz``.
    """
    import toolz
    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ['excepts', 'juxt']:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message
Example #14
File: carve.py Project: jondot/carve
def treemap(tree, mapfn, key=None, path=()):
    res = tree
    if isinstance(tree, dict):
        res = keyfilter(
            identity,
            itemmap(
                lambda item: treemap(item[1], mapfn, item[0], path +
                                     (item[0], )), tree),
        )

    elif isinstance(tree, list):
        res = list(map(lambda t: treemap(t, mapfn, None, path), tree))

    return xform(key, res, mapfn, path)
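Using `identity` as the predicate makes `keyfilter` drop every entry whose key is falsy (`None`, `''`, `0`), which appears to be how `treemap` prunes keyless results after the `itemmap` pass. A quick sketch of that behavior:

from toolz import identity, keyfilter

d = {"keep": 1, "": 2, None: 3, 0: 4}

# identity returns the key itself, so only truthy keys survive
assert keyfilter(identity, d) == {"keep": 1}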
Example #15
File: genesig.py Project: verohan/pySCENIC
    def intersection(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
        """
        Creates a new :class:`GeneSignature` instance which is the intersection of this signature and the supplied other
        signature.

        The weight associated with the genes in the intersection is the maximum of the weights in the composing signatures.

        :param other: The other :class:`GeneSignature`.
        :return: the new :class:`GeneSignature` instance.
        """
        genes = set(self.gene2weight.keys()).intersection(set(other.gene2weight.keys()))
        return self.copy(name="({} & {})".format(self.name, other.name) if self.name != other.name else self.name,
                         gene2weight=frozendict(keyfilter(lambda k: k in genes,
                                                          merge_with(max, self.gene2weight, other.gene2weight))))
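Examples #15 and #33 implement a weighted-dict intersection: `merge_with(max, ...)` keeps the larger weight for each key, and `keyfilter` then restricts the merge to keys present in both dicts. A sketch with fabricated gene weights:

from toolz import keyfilter, merge_with

a = {"TP53": 0.9, "MYC": 0.4}
b = {"TP53": 0.7, "EGFR": 0.8}

shared = set(a) & set(b)
merged = merge_with(max, a, b)  # per-key max over however many dicts share it
assert keyfilter(lambda k: k in shared, merged) == {"TP53": 0.9}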
Example #16
File: csv.py Project: Casolt/blaze
    def _extend(self, rows):
        mode = 'ab' if PY2 else 'a'
        newline = dict() if PY2 else dict(newline='')
        dialect = keyfilter(to_csv_kwargs.__contains__, self.dialect)
        should_write_newline = self.last_char() != os.linesep
        with csvopen(self, mode=mode, **newline) as f:
            # we have data in the file, append a newline
            if should_write_newline:
                f.write(os.linesep)

            for df in map(partial(bz.into, pd.DataFrame),
                          partition_all(self.chunksize, iter(rows))):
                df.to_csv(f, index=False, header=None, encoding=self.encoding,
                          **dialect)
Example #17
def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
        correctly embedded the function signature as done in ``toolz``.
    """
    import toolz
    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message
Example #18
def run(cfg: HyperParamCfg):
    logger = hyperparam_logger
    train_cfg, params = make_dict(
        toolz.keyfilter(
            lambda x: x in TrainConfig.__annotations__.keys(), cfg._asdict()
        ),
        {},
    )
    solver = cfg.solver_class(params)

    while True:
        assignment = solver.new_assignment()
        concrete_params = assignment.params
        folder = os.path.normpath(
            os.path.abspath(os.path.join(cfg.work_dir, assignment.name))
        )
        logger.info(f"Starting {assignment.name}")
        cfg_path = os.path.join(folder, "config.yml")
        os.makedirs(folder, exist_ok=False)

        concrete_cfg = subs_dict(train_cfg, concrete_params)
        concrete_cfg['logdir'] = folder
        concrete_cfg = subs_dict(TrainConfig.schema(concrete_cfg), {})
        with open(cfg_path, "w") as f:
            yaml.safe_dump(
                {'train': concrete_cfg, 'version': "v0.1"},
                stream=f,
                default_flow_style=False,
            )
        result = train.run_args(
            argparse.Namespace(
                config=cfg_path,
                logdir=None,
                name=assignment.name,
            )
        )
        logger.info(f"Got results:\n{result.describe().to_string()}\n{result}")
        obs = Observation(
            metric=float(np.mean(result['identity'])),
            metric_std=float(np.std(result['identity'])),
            metadata={
                c: float(np.mean(series))
                for c, series in result.items()
            }
        )
        solver.report(assignment, obs)
Example #19
File: csv.py Project: pgnepal/blaze
    def reader(self, header=None, keep_default_na=False,
               na_values=na_values, chunksize=None, **kwargs):
        kwargs.setdefault('skiprows', int(bool(self.header)))

        dialect = merge(keyfilter(read_csv_kwargs.__contains__, self.dialect),
                        kwargs)
        filename, ext = os.path.splitext(self.path)
        ext = ext.lstrip('.')
        # handle windows
        if dialect['lineterminator'] == '\r\n':
            dialect['lineterminator'] = None
        reader = pd.read_csv(self.path, compression={'gz': 'gzip',
                                                     'bz2': 'bz2'}.get(ext),
                             chunksize=chunksize, na_values=na_values,
                             keep_default_na=keep_default_na,
                             encoding=self.encoding, header=header, **dialect)
        return reader
Example #20
    def _extend(self, rows):
        mode = 'ab' if PY2 else 'a'
        newline = dict() if PY2 else dict(newline='')
        dialect = keyfilter(to_csv_kwargs.__contains__, self.dialect)
        should_write_newline = self.last_char() != os.linesep
        with csvopen(self, mode=mode, **newline) as f:
            # we have data in the file, append a newline
            if should_write_newline:
                f.write(os.linesep)

            for df in map(partial(bz.into, pd.DataFrame),
                          partition_all(self.chunksize, iter(rows))):
                df.to_csv(f,
                          index=False,
                          header=None,
                          encoding=self.encoding,
                          **dialect)
Example #21
    def apply(
        self,
        batch: DataPanel,
        columns: List[str],
        skeleton_batches: List[DataPanel],
        slice_membership: np.ndarray,
        *args,
        **kwargs,
    ) -> Tuple[List[DataPanel], np.ndarray]:

        # Group the batch into inputs and output
        batch_inputs = tz.keyfilter(lambda k: k in columns[:-1], batch)
        batch_inputs = [
            OrderedDict(zip(batch_inputs, t))
            for t in zip(*batch_inputs.values())
        ]

        batch_output = [int(e) for e in batch[columns[-1]]]

        # Create a fake dataset for textattack
        fake_dataset = list(zip(batch_inputs, batch_output))

        # Attack the dataset
        outputs = list(self.attack.attack_dataset(fake_dataset))

        for i, output in enumerate(outputs):
            # Check if the goal succeeded
            if output.perturbed_result.goal_status == 0:
                # If success, fill out the skeleton batch
                for (
                        key,
                        val,
                ) in output.perturbed_result.attacked_text._text_input.items():
                    # TODO(karan): support num_attacked_texts > 1
                    skeleton_batches[0][key][i] = val

                # # Fill the perturbed output: *this was incorrect, removing this
                # statement*
                # # TODO(karan): delete this snippet
                # skeleton_batches[0][columns[-1]][i] = output.perturbed_result.output
            else:
                # Unable to attack the example: set its slice membership to zero
                slice_membership[i, 0] = 0

        return skeleton_batches, slice_membership
Example #22
File: csv.py Project: pgnepal/blaze
    def _extend(self, rows):
        mode = 'ab' if PY2 else 'a'
        newline = dict() if PY2 else dict(newline='')
        dialect = keyfilter(to_csv_kwargs.__contains__, self.dialect)
        should_write_newline = self.last_char() != os.linesep
        f = self.open(self.path, mode, **newline)

        try:
            # we have data in the file, append a newline
            if should_write_newline:
                f.write(os.linesep)

            for df in map(partial(bz.into, pd.DataFrame),
                          partition_all(self.chunksize, iter(rows))):
                df.to_csv(f, index=False, header=None, **dialect)
        finally:
            try:
                f.close()
            except AttributeError:
                pass
Example #23
def test_sig_at_beginning():
    """ Test that the function signature is at the beginning of the docstring
        and is followed by exactly one blank line.
    """
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)
    cytoolz_dict = keyfilter(lambda x: x not in skip_sigs, cytoolz_dict)

    for key, val in cytoolz_dict.items():
        doclines = val.__doc__.splitlines()
        assert len(doclines) > 2, (
            'cytoolz.%s docstring too short:\n\n%s' % (key, val.__doc__))

        sig = '%s(' % aliases.get(key, key)
        assert sig in doclines[0], (
            'cytoolz.%s docstring missing signature at beginning:\n\n%s'
            % (key, val.__doc__))

        assert not doclines[1], (
            'cytoolz.%s docstring missing blank line after signature:\n\n%s'
            % (key, val.__doc__))

        assert doclines[2], (
            'cytoolz.%s docstring too many blank lines after signature:\n\n%s'
            % (key, val.__doc__))
Example #24
    def select_columns(self, columns: List[str]) -> Batch:
        """Select a subset of columns."""
        for col in columns:
            assert col in self._data
        return tz.keyfilter(lambda k: k in columns, self._data)
Example #25
File: csv.py Project: ChrisBeaumont/blaze
    def __init__(self, path, mode='rt', schema=None, columns=None, types=None,
                 typehints=None, dialect=None, header=None, open=open,
                 nrows_discovery=50, chunksize=1024,
                 encoding=sys.getdefaultencoding(), **kwargs):
        if 'r' in mode and not os.path.isfile(path):
            raise ValueError('CSV file "%s" does not exist' % path)

        if schema is None and 'w' in mode:
            raise ValueError('Please specify schema for writable CSV file')

        self.path = path
        self.mode = mode
        self.open = {'gz': gzip.open, 'bz2': bz2.BZ2File}.get(ext(path), open)
        self.header = header
        self._abspath = os.path.abspath(path)
        self.chunksize = chunksize
        self.encoding = encoding

        sample = get_sample(self)
        self.dialect = dialect = discover_dialect(sample, dialect, **kwargs)

        if header is None:
            header = has_header(sample, encoding=encoding)
        elif isinstance(header, int):
            dialect['header'] = header
            header = True

        reader_dialect = keyfilter(read_csv_kwargs.__contains__, dialect)
        if not schema and 'w' not in mode:
            if not types:
                data = self._reader(skiprows=1 if header else 0,
                                    nrows=nrows_discovery, as_recarray=True,
                                    index_col=False,
                                    header=0 if header else None,
                                    **reader_dialect).tolist()
                types = discover(data)
                rowtype = types.subshape[0]
                if isinstance(rowtype[0], Tuple):
                    types = types.subshape[0][0].dshapes
                    types = [unpack(t) for t in types]
                    types = [string if t == null else t for t in types]
                    types = [t if isinstance(t, Option) or t == string
                             else Option(t) for t in types]
                elif (isinstance(rowtype[0], Fixed) and
                        isinstance(rowtype[1], CType)):
                    types = int(rowtype[0]) * [rowtype[1]]
                else:
                    raise ValueError("Could not discover schema from data.\n"
                                     "Please specify schema.")
            if not columns:
                if header:
                    columns = first(self._reader(skiprows=0, nrows=1,
                                                 header=None, **reader_dialect
                                                 ).itertuples(index=False))
                else:
                    columns = ['_%d' % i for i in range(len(types))]
            if typehints:
                types = [typehints.get(c, t) for c, t in zip(columns, types)]

            schema = dshape(Record(list(zip(columns, types))))

        self._schema = schema
        self.header = header
Example #26
    def keyfilter(self, predicate):
        return fdict(cytoolz.keyfilter(predicate, self))
Example #27
    def evaluate(
        self,
        dataset: DataPanel,
        input_columns: List[str],
        output_columns: List[str],
        batch_size: int = 32,
        metrics: List[str] = None,
        coerce_fn: Callable = None,
    ):

        # TODO(karan): generalize to TF2

        # Reset the dataset format
        dataset.reset_format()
        dataset.set_format(columns=input_columns + output_columns)

        # TODO(karan): check that the DataPanel conforms to the task definition
        # TODO(karan): figure out how the output_columns will be used by the metrics
        pass

        predictions = []
        targets = []

        # Loop and apply the prediction function
        # TODO(karan): not using .map() here in order to get more fine-grained
        #  control over devices
        for idx in range(0, len(dataset), batch_size):
            # Create the batch
            batch = dataset[idx:idx + batch_size]

            # Predict on the batch
            prediction_dict = self.predict_batch(batch=batch,
                                                 input_columns=input_columns)

            # Coerce the predictions
            if coerce_fn:
                prediction_dict = coerce_fn(prediction_dict)

            # Grab the raw target key/values
            target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

            # TODO(karan): general version for non-classification problems
            # TODO(karan): move this to the right device
            if self.is_classifier:
                target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

            # TODO(karan): incremental metric computation here
            # Append the predictions and targets
            predictions.append(prediction_dict)
            targets.append(target_dict)

        # Consolidate the predictions and targets
        if self.is_classifier:
            # TODO(karan): Need to store predictions and outputs from the model
            predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"),
                                        *predictions)
            targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
        else:
            predictions = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *predictions)
            targets = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *targets)

        # Compute the metrics
        # TODO(karan): generalize this code to support metric computation for any task

        # Assumes classification, so the output_columns contains a single key for the
        # label
        if self.is_classifier:
            assert len(
                output_columns) == 1  # , "Only supports classification."
            num_classes = self.task.output_schema.features[list(
                self.task.output_schema.columns)[0]].num_classes

        labels = targets[list(targets.keys())[0]]

        if metrics is None:
            if self.task is None:
                raise ValueError(
                    "Must specify metrics if model not associated with task")
            metrics = self.task.metrics

        pred = predictions["pred"].to(self.device)
        target = labels.to(self.device)

        evaluation_dict = {
            metric: compute_metric(metric, pred, target, num_classes)
            for metric in metrics
        }

        # Reset the data format
        dataset.reset_format()

        return evaluation_dict
Example #28
File: utils.py Project: ys-zhang/corona
    def re_select(self, pattern):
        pattern = re.compile(pattern)
        return list(keyfilter(pattern.fullmatch, self._modules).values())
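`re_select` works because `Pattern.fullmatch` returns a truthy `Match` object on a hit and `None` otherwise, so a compiled pattern's bound method doubles as a `keyfilter` predicate. A sketch with a hypothetical module dict:

import re
from toolz import keyfilter

modules = {"layer1": "Linear", "layer2": "ReLU", "head": "Softmax"}

# Match objects are truthy, None is falsy: fullmatch is a valid predicate
selected = keyfilter(re.compile(r"layer\d+").fullmatch, modules)
assert selected == {"layer1": "Linear", "layer2": "ReLU"}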
Example #29
File: text.py Project: dav009/entivaluator
def parse_gs_line(file_line):
    data = json.loads(file_line)
    return keyfilter(lambda k: k in ["text", "docId"], data)
Example #30
File: caches.py Project: kivo360/jamboree
def omit(blacklist, d):
    return keyfilter(lambda k: k not in blacklist, d)
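This `omit` helper is the blacklist mirror of the whitelist filters elsewhere on this page. A short usage sketch with a made-up record:

from toolz import keyfilter

def omit(blacklist, d):
    # keep every key NOT named in the blacklist
    return keyfilter(lambda k: k not in blacklist, d)

record = {"user": "ann", "password": "hunter2", "email": "a@example.com"}
assert omit({"password"}, record) == {"user": "ann", "email": "a@example.com"}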
Example #31
    def uncached_example(cls, example: Dict, copy=True) -> Dict:
        """Return example with the "cache" and "slices" columns removed."""
        return tz.keyfilter(
            lambda k: k not in ["cache", "slices"],
            deepcopy(example) if copy else example,
        )
Example #32
    def uncached_batch(cls, batch: Batch, copy=True) -> Batch:
        """Return batch with the "cache" and "slices" columns removed."""
        return tz.keyfilter(lambda k: k not in ["cache", "slices"],
                            deepcopy(batch) if copy else batch)
Example #33
File: genesig.py Project: simonvh/pySCENIC
    def _intersection_impl(self, other):
        genes = set(self.gene2weight.keys()).intersection(
            set(other.gene2weight.keys()))
        return frozendict(
            keyfilter(lambda k: k in genes,
                      merge_with(max, self.gene2weight, other.gene2weight)))
Example #34
    def get_retriable_missing(self) -> List[Hash32]:
        retriable = cytoolz.keyfilter(
            lambda k: time.time() - k > self.reply_timeout, self.missing)
        self.missing = cytoolz.dissoc(self.missing, *retriable.keys())
        return list(cytoolz.concat(retriable.values()))
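The final example chains three toolz helpers: `keyfilter` selects entries old enough to retry, `dissoc` returns the dict without those keys, and `concat` flattens the per-entry hash lists. A standalone sketch of the same bookkeeping with fabricated timestamps:

import time
from toolz import concat, dissoc, keyfilter

reply_timeout = 5.0
now = time.time()
# keys are request timestamps, values are lists of still-missing hashes
missing = {now - 10.0: ["h1", "h2"], now: ["h3"]}

retriable = keyfilter(lambda t: time.time() - t > reply_timeout, missing)
missing = dissoc(missing, *retriable.keys())  # drop the entries we retry

assert list(concat(retriable.values())) == ["h1", "h2"]
assert list(missing.values()) == [["h3"]]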