Example #1
def convert(sig: Type[GeneSignature],
            db: RegionRankingDatabase,
            delineation: Delineation,
            fraction: float = 0.80) -> Type[GeneSignature]:
    """
    Convert a signature of gene symbols to a signature of region identifiers.

    :param sig: The signature of gene symbols.
    :param db: The region database.
    :param delineation: The regulatory region delineation for genes.
    :param fraction: The fraction of overlap to take into account.
    :return: The signature of region identifiers.
    """
    assert sig
    assert db
    assert delineation

    # Every gene is transformed into a dictionary that maps the name of a feature to the weight of the corresponding gene.
    # These mappings are then combined, taking the maximum when multiple values exist for a key.
    identifier2weight = merge_with(max, (dict(
        zip(
            map(
                attrgetter('name'),
                db.regions.intersection(load(delineation).get(gene),
                                        fraction=fraction)),
            repeat(weight))) for gene, weight in sig.gene2weight.items()))

    #region_identifiers = frozenset(
    #    map(attrgetter('name'),
    #        mapcat(list,
    #            map(partial(db.regions.intersection, fraction=fraction),
    #               map(load(delineation).get, sig.genes)))))
    return sig.copy(gene2weight=identifier2weight,
                    nomenclature=REGION_NOMENCLATURE)
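
The core of the conversion is merge_with(max) over one small dict per gene, and merge_with accepts a single iterable of dicts. A minimal sketch of that combination rule, with made-up region identifiers and weights in place of a real ranking database:

from toolz import merge_with

# One dict per gene, mapping region identifier -> that gene's weight;
# merge_with(max) keeps the largest weight when genes share a region.
per_gene = [{"chr1:100-200": 0.5}, {"chr1:100-200": 0.9, "chr2:50-80": 0.3}]
print(merge_with(max, per_gene))  # {'chr1:100-200': 0.9, 'chr2:50-80': 0.3}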
Example #2
File: Candidates.py Project: webbcla/c2xg
    def merge_candidates(self, output_files, threshold):

        candidates = []
        print("Merging " + str(len(output_files)) + " files.")

        #Load
        for dict_file in output_files:
            try:
                candidates.append(self.Loader.load_file(dict_file))
            except Exception as e:
                print("ERROR")
                print(e)

        #Merge
        candidates = ct.merge_with(sum, candidates)
        print("\tTOTAL CANDIDATES BEFORE PRUNING: " + str(len(candidates)))

        #Prune
        above_threshold = lambda x: x > threshold
        candidates = ct.valfilter(above_threshold, candidates)
        print("\tTOTAL CANDIDATES AFTER PRUNING: " + str(len(candidates)))

        return candidates
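
The merge-then-prune idiom is easy to check in isolation; a small sketch with toy count dicts, no Loader required:

import cytoolz as ct

counts = [{"a": 2, "b": 1}, {"a": 3, "c": 1}]
merged = ct.merge_with(sum, counts)             # {'a': 5, 'b': 1, 'c': 1}
pruned = ct.valfilter(lambda n: n > 1, merged)  # {'a': 5}
print(merged, pruned)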
Example #3
from collections.abc import Mapping, Sequence
from functools import partial

import toolz as tz


def recmerge(*objs, merge_sequences=False):
    """Recursively merge an arbitrary number of collections. For conflicting
    values, later collections to the right are given priority. By default
    (merge_sequences=False), sequences are treated as a normal value and not
    merged.

    Args:
        *objs: collections to merge
        merge_sequences: whether to merge values that are sequences

    Returns: merged collection
    """
    if isinstance(objs, tuple) and len(objs) == 1:
        # Unwrap: when recmerge is used as the combining function of
        # merge_with, it receives a single list of values per key.
        objs = objs[0]
    if all([isinstance(obj, Mapping) for obj in objs]):
        # Merges all the collections, recursively applies merging to the combined values
        return tz.merge_with(
            partial(recmerge, merge_sequences=merge_sequences), *objs)
    elif all([isinstance(obj, Sequence) for obj in objs]) and merge_sequences:
        # Merges sequence values by concatenation
        return list(tz.concat(objs))
    else:
        # If objs does not consist of mappings (or mergeable sequences), simply pick the last one
        return tz.last(objs)
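
A quick check of the right-bias and the merge_sequences switch, assuming recmerge as defined above:

a = {"x": {"u": 1}, "tags": [1]}
b = {"x": {"v": 2}, "tags": [2]}
print(recmerge(a, b))                        # {'x': {'u': 1, 'v': 2}, 'tags': [2]}
print(recmerge(a, b, merge_sequences=True))  # {'x': {'u': 1, 'v': 2}, 'tags': [1, 2]}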
Example #4
def batchify_data_records(data_records: List[api.DataRecord]) -> api.BatchDataRecords:
    """Stack a list of DataRecord into BatchRecord. This process converts a list
    of tuples comprising of dicts {str: float/array} into tuples of dict {str: array}.
    Float/array is concatenated along the first dimension. See Example.

    Args:
        data_records: list[DataRecord], list of individual data records.

    Returns:
        BatchDataRecords, batch data records.

    Example:
    ::

        data_record_1 = ({"input_1": 1, "input_2": 2}, {"output_1": 3})
        data_record_2 = ({"input_1": 2, "input_2": 4}, {"output_1": 6})
        batch_data_records = (
            {"input_1": arr([1, 2], "input_2": arr([2, 4])},
            {"output_1": arr([3, 6])}
        )
    """
    batch_data_records = tuple(
        cytoolz.merge_with(np.array, ii) for ii in zip(*data_records)
    )
    return batch_data_records  # type: ignore
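
The zip/merge_with combination can be reproduced with plain dicts and NumPy; a sketch with two hypothetical records, no api module required:

import numpy as np
import cytoolz

records = [({"input_1": 1}, {"output_1": 3}), ({"input_1": 2}, {"output_1": 6})]
# zip(*records) groups inputs with inputs and outputs with outputs;
# merge_with(np.array) then stacks each key's values.
batch = tuple(cytoolz.merge_with(np.array, part) for part in zip(*records))
print(batch)  # ({'input_1': array([1, 2])}, {'output_1': array([3, 6])})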
Example #5
    def interleave(cls, datasets: List[Dataset],
                   identifier: Identifier) -> Dataset:
        """Interleave a list of datasets."""
        return cls.from_batch(
            tz.merge_with(tz.interleave,
                          *[dataset[:] for dataset in datasets]),
            identifier=identifier,
        )
Example #6
    def from_batches(cls,
                     batches: Sequence[Batch],
                     identifier: Identifier = None) -> Dataset:
        """Convert a list of batches to a dataset."""
        return cls.from_batch(
            tz.merge_with(tz.concat, *batches),
            identifier=identifier,
        )
Example #7
def concatenate_tweets(tweets, dictionary, characterizer, step=10000):
    if not tweets.count():
        return []

    all_bows = []

    for tweets in partition_all(step, queryset_iterator(tweets.only('text'), chunksize=step)):
        bows = []

        for tweet in tweets:
            keywords = characterizer.tokenize(tweet.text)
            bow = dictionary.doc2bow(set(keywords), allow_update=False)
            bows.append(dict(bow))

        all_bows.append(merge_with(sum, *bows))

    return list(merge_with(sum, all_bows).items())
Example #8
File: Association.py Project: webbcla/c2xg
    def merge_ngrams(self, files=None, n_gram_threshold=1):

        all_ngrams = []

        #Get a list of ngram files
        if files is None:
            files = self.Loader.list_output(type="ngrams")

        #Break into lists of 20 files
        file_list = ct.partition_all(20, files)

        for files in file_list:

            ngrams = []  #Initialize holding list

            #Load
            for dict_file in files:
                try:
                    ngrams.append(self.Loader.load_file(dict_file))
                except Exception:
                    print("Not loading " + str(dict_file))

            #Merge
            ngrams = ct.merge_with(sum, ngrams)

            print("\tSUB-TOTAL NGRAMS: " + str(len(ngrams)))
            print("\tSUB-TOTAL WORDS: " + str(ngrams["TOTAL"]))
            print("\n")

            all_ngrams.append(ngrams)

        #Now merge everything
        all_ngrams = ct.merge_with(sum, all_ngrams)

        print("\tTOTAL NGRAMS: " + str(len(all_ngrams)))
        print("\tTOTAL WORDS: " + str(all_ngrams["TOTAL"]))

        #Now enforce threshold
        keepable = lambda x: x > n_gram_threshold
        all_ngrams = ct.valfilter(keepable, all_ngrams)

        print("\tAfter pruning:")
        print("\tTOTAL NGRAMS: " + str(len(all_ngrams)))

        return all_ngrams
Example #9
def test_curried_namespace():
    def should_curry(func):
        if not callable(func) or isinstance(func, curry):
            return False
        nargs = enhanced_num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        else:
            return nargs == 1 and enhanced_has_keywords(func)

    def curry_namespace(ns):
        return dict((
            name,
            curry(f) if should_curry(f) else f,
        ) for name, f in ns.items() if '__' not in name)

    all_auto_curried = curry_namespace(vars(eth_utils))

    inferred_namespace = valfilter(callable, all_auto_curried)
    curried_namespace = valfilter(callable, eth_utils.curried.__dict__)

    if inferred_namespace != curried_namespace:
        missing = set(inferred_namespace) - set(curried_namespace)
        if missing:
            to_insert = sorted("%s," % f for f in missing)
            raise AssertionError(
                'There are missing functions in eth_utils.curried:\n' +
                '\n'.join(to_insert))
        extra = set(curried_namespace) - set(inferred_namespace)
        if extra:
            raise AssertionError(
                'There are extra functions in eth_utils.curried:\n' +
                '\n'.join(sorted(extra)))
        unequal = merge_with(list, inferred_namespace, curried_namespace)
        unequal = valfilter(lambda x: x[0] != x[1], unequal)
        to_curry = keyfilter(lambda x: should_curry(getattr(eth_utils, x)),
                             unequal)
        if to_curry:
            to_curry_formatted = sorted('{0} = curry({0})'.format(f)
                                        for f in to_curry)
            raise AssertionError(
                'There are missing functions to curry in eth_utils.curried:\n'
                + '\n'.join(to_curry_formatted))
        elif unequal:
            not_to_curry_formatted = sorted(unequal)
            raise AssertionError(
                'Missing functions NOT to curry in eth_utils.curried:\n' +
                '\n'.join(not_to_curry_formatted))
        else:
            raise AssertionError("unexplained difference between %r and %r" % (
                inferred_namespace,
                curried_namespace,
            ))
Example #10
    def getFuncsIndex(self):
        for elt in ast.walk(self.node):
            if isinstance(elt, ast.Name):
                indice = self.getIndice(elt)
                if hasattr(elt, "func"):
                    self.funcs = merge_with(lambda x: reduce(add, x, []),
                                            self.funcs, {elt.func: [indice]})
        for key in self.funcs.keys():
            self.funcsIndex[key] = itemmap(lambda kv, key=key:
                                           (kv[0],
                                            (int(kv[0] in self.funcs[key]), )),
                                           self.index)
Example #11
File: genesig.py Project: verohan/pySCENIC
    def union(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
        """
        Creates a new :class:`GeneSignature` instance which is the union of this signature and the other supplied
        signature.

        The weight associated with the genes in the intersection is the maximum of the weights in the composing signatures.

        :param other: The other :class:`GeneSignature`.
        :return: the new :class:`GeneSignature` instance.
        """
        return self.copy(name="({} | {})".format(self.name, other.name) if self.name != other.name else self.name,
                         gene2weight=frozendict(merge_with(max, self.gene2weight, other.gene2weight)))
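
The weight rule here is simply merge_with(max) over the two weight maps; a standalone sketch with plain dicts in place of frozendict:

from toolz import merge_with

a = {"TP53": 0.8, "MYC": 0.2}
b = {"MYC": 0.9, "EGFR": 0.4}
# Genes present in both signatures keep the larger weight.
print(merge_with(max, a, b))  # {'TP53': 0.8, 'MYC': 0.9, 'EGFR': 0.4}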
Example #12
    def chain(
        cls,
        datasets: List[Dataset],
        identifier: Identifier,
    ) -> Dataset:
        """Chain a list of datasets."""
        return cls.from_batch(
            tz.merge_with(
                tz.compose(list, tz.concat),
                *[dataset[:] for dataset in datasets],
            ),
            identifier=identifier,
        )
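
Composing list with concat makes merge_with concatenate each column's lists across datasets; a minimal sketch with toy batches:

import toolz as tz

batches = [{"text": ["a", "b"], "label": [0, 1]}, {"text": ["c"], "label": [1]}]
# merge_with hands each key the list of per-batch column lists.
print(tz.merge_with(tz.compose(list, tz.concat), *batches))
# {'text': ['a', 'b', 'c'], 'label': [0, 1, 1]}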
Example #13
def test_class_sigs():
    """Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(
        lambda x: not isinstance(x, BuiltinFunctionType), cytoolz_dict
    )

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ["excepts", "juxt", "memoize", "flip"]:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        # For Cython < 0.25
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        doc = cytoolz_func.__doc__
        # For Cython >= 0.25
        toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec(
            *toolz_spec,
            **{"formatvalue": lambda x: "=" + getattr(x, "__name__", repr(x))}
        )
        doc_alt = doc.replace("Py_ssize_t ", "")
        if not (toolz_sig in doc or toolz_sig_alt in doc_alt):
            message = (
                "cytoolz.%s does not have correct function signature."
                "\n\nExpected: %s"
                "\n\nDocstring in cytoolz is:\n%s"
                % (key, toolz_sig, cytoolz_func.__doc__)
            )
            assert False, message
Example #14
File: genesig.py Project: verohan/pySCENIC
    def intersection(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
        """
        Creates a new :class:`GeneSignature` instance which is the intersection of this signature and the supplied other
        signature.

        The weight associated with the genes in the intersection is the maximum of the weights in the composing signatures.

        :param other: The other :class:`GeneSignature`.
        :return: the new :class:`GeneSignature` instance.
        """
        genes = set(self.gene2weight.keys()).intersection(set(other.gene2weight.keys()))
        return self.copy(name="({} & {})".format(self.name, other.name) if self.name != other.name else self.name,
                         gene2weight=frozendict(keyfilter(lambda k: k in genes,
                                                          merge_with(max, self.gene2weight, other.gene2weight))))
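
keyfilter then restricts the max-merged map to the shared genes; again a sketch with plain dicts:

from toolz import keyfilter, merge_with

a = {"TP53": 0.8, "MYC": 0.2}
b = {"MYC": 0.9, "EGFR": 0.4}
shared = set(a) & set(b)
print(keyfilter(lambda k: k in shared, merge_with(max, a, b)))  # {'MYC': 0.9}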
Example #15
def test_curried_namespace():
    exceptions = import_module("cytoolz.curried.exceptions")
    namespace = {}

    def should_curry(func):
        if not callable(func) or isinstance(func, cytoolz.curry):
            return False
        nargs = cytoolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and cytoolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        return {
            name: cytoolz.curry(f) if should_curry(f) else f
            for name, f in ns.items() if "__" not in name
        }

    from_cytoolz = curry_namespace(vars(cytoolz))
    from_exceptions = curry_namespace(vars(exceptions))
    namespace.update(cytoolz.merge(from_cytoolz, from_exceptions))

    namespace = cytoolz.valfilter(callable, namespace)
    curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)

    if namespace != curried_namespace:
        missing = set(namespace) - set(curried_namespace)
        if missing:
            raise AssertionError(
                "There are missing functions in cytoolz.curried:\n    %s" %
                "    \n".join(sorted(missing)))
        extra = set(curried_namespace) - set(namespace)
        if extra:
            raise AssertionError(
                "There are extra functions in cytoolz.curried:\n    %s" %
                "    \n".join(sorted(extra)))
        unequal = cytoolz.merge_with(list, namespace, curried_namespace)
        unequal = cytoolz.valfilter(lambda x: x[0] != x[1], unequal)
        messages = []
        for name, (orig_func, auto_func) in sorted(unequal.items()):
            if name in from_exceptions:
                messages.append(
                    "%s should come from cytoolz.curried.exceptions" % name)
            elif should_curry(getattr(cytoolz, name)):
                messages.append("%s should be curried from cytoolz" % name)
            else:
                messages.append(
                    "%s should come from cytoolz and NOT be curried" % name)
        raise AssertionError("\n".join(messages))
Example #16
    def from_batches(
        cls,
        batches: Sequence[Batch],
        identifier: Identifier = None,
        dataset_fmt: str = "in_memory",
    ) -> Dataset:
        """Convert a list of batches to a dataset."""

        return cls.from_batch(
            tz.merge_with(
                tz.compose(list, tz.concat),
                *batches,
            ),
            identifier=identifier,
            dataset_fmt=dataset_fmt,
        )
Example #17
    def __iter__(self):
        self.order = []
        user_ids = list(set(self.tweets.values_list(self.key, flat=True)))

        for user_pks in partition_all(self.step, user_ids):
            queryset = self.tweets.filter(user_id__in=user_pks)
            for key_id, tweet_set in groupby(queryset.only('text', self.key).order_by(self.key), key=self.key_func):
                bows = []

                for tweet in tweet_set:
                    keywords = self.characterizer.tokenize(tweet.text)
                    bow = self.dictionary.doc2bow(set(keywords), allow_update=False)
                    bows.append(dict(bow))

                self.order.append(key_id)
                yield list(merge_with(sum, *bows).items())
Example #18
    def collect(self):
        from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily

        yield GaugeMetricFamily(
            "dask_scheduler_clients",
            "Number of clients connected.",
            value=len(self.server.clients),
        )

        yield GaugeMetricFamily(
            "dask_scheduler_desired_workers",
            "Number of workers scheduler needs for task graph.",
            value=self.server.adaptive_target(),
        )

        worker_states = GaugeMetricFamily(
            "dask_scheduler_workers",
            "Number of workers known by scheduler.",
            labels=["state"],
        )
        worker_states.add_metric(["connected"], len(self.server.workers))
        worker_states.add_metric(["saturated"], len(self.server.saturated))
        worker_states.add_metric(["idle"], len(self.server.idle))
        yield worker_states

        tasks = GaugeMetricFamily(
            "dask_scheduler_tasks",
            "Number of tasks known by scheduler.",
            labels=["state"],
        )

        task_counter = merge_with(
            sum, (tp.states for tp in self.server.task_prefixes.values()))

        yield CounterMetricFamily(
            "dask_scheduler_tasks_forgotten",
            ("Total number of processed tasks no longer in memory and already "
             "removed from the scheduler job queue. Note task groups on the "
             "scheduler which have all tasks in the forgotten state are not included."
             ),
            value=task_counter.get("forgotten", 0.0),
        )

        for state in ALL_TASK_STATES:
            tasks.add_metric([state], task_counter.get(state, 0.0))
        yield tasks
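
The aggregation behind task_counter is merge_with(sum) over a generator of per-prefix state dicts; toy stand-ins for tp.states show the shape:

from toolz import merge_with

states = ({"memory": 2, "processing": 1}, {"memory": 1, "forgotten": 4})
print(merge_with(sum, (s for s in states)))
# {'memory': 3, 'processing': 1, 'forgotten': 4}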
Example #19
def test_curried_namespace():
    exceptions = import_module('cytoolz.curried.exceptions')
    namespace = {}

    def should_curry(func):
        if not callable(func) or isinstance(func, cytoolz.curry):
            return False
        nargs = cytoolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and cytoolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        return dict(
            (name, cytoolz.curry(f) if should_curry(f) else f)
            for name, f in ns.items() if '__' not in name
        )

    from_cytoolz = curry_namespace(vars(cytoolz))
    from_exceptions = curry_namespace(vars(exceptions))
    namespace.update(cytoolz.merge(from_cytoolz, from_exceptions))

    namespace = cytoolz.valfilter(callable, namespace)
    curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)

    if namespace != curried_namespace:
        missing = set(namespace) - set(curried_namespace)
        if missing:
            raise AssertionError('There are missing functions in cytoolz.curried:\n    %s'
                                 % '    \n'.join(sorted(missing)))
        extra = set(curried_namespace) - set(namespace)
        if extra:
            raise AssertionError('There are extra functions in cytoolz.curried:\n    %s'
                                 % '    \n'.join(sorted(extra)))
        unequal = cytoolz.merge_with(list, namespace, curried_namespace)
        unequal = cytoolz.valfilter(lambda x: x[0] != x[1], unequal)
        messages = []
        for name, (orig_func, auto_func) in sorted(unequal.items()):
            if name in from_exceptions:
                messages.append('%s should come from cytoolz.curried.exceptions' % name)
            elif should_curry(getattr(cytoolz, name)):
                messages.append('%s should be curried from cytoolz' % name)
            else:
                messages.append('%s should come from cytoolz and NOT be curried' % name)
        raise AssertionError('\n'.join(messages))
Example #20
def create_slice(args):
    # Unpack args
    dataset, slice_membership, slice_batches, i, batch_size, slice_cache_hash = args

    # Create a new empty slice
    sl = Slice.from_dict({})

    # Create a Slice "copy" of the Dataset
    sl.__dict__.update(dataset.__dict__)
    sl._identifier = None

    # Filter
    sl = sl.filter(
        lambda example, idx: bool(slice_membership[idx, i]),
        with_indices=True,
        input_columns=["index"],
        batch_size=batch_size,
        cache_file_name=str(
            dataset.logdir / ("cache-" + str(abs(slice_cache_hash)) + "-filter.arrow")
        ),
    )

    slice_batch = tz.merge_with(tz.compose(list, tz.concat), slice_batches)

    # FIXME(karan): interaction tape history is wrong here, esp. with augmentation/attacks

    # Map
    if len(sl):
        sl = sl.map(
            lambda batch, indices: tz.valmap(
                lambda v: v[indices[0] : indices[0] + batch_size], slice_batch
            ),
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            remove_columns=sl.column_names,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(slice_cache_hash)) + ".arrow")
            ),
        )

    return sl
Example #21
def test_docstrings_uptodate():
    import toolz

    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, "__doc__", ""), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(
                differ.compare(toolz_doc.splitlines(),
                               cytoolz_doc.splitlines()))
            fulldiff = list(diff)
            # remove additional lines at the beginning
            while diff and diff[0].startswith("+"):
                diff.pop(0)
            # remove additional lines at the end
            while diff and diff[-1].startswith("+"):
                diff.pop()

            def checkbad(line):
                return line.startswith("+") and not (
                    "# doctest: +SKIP" in line and key in skipped_doctests)

            if any(map(checkbad, diff)):
                assert False, "Error: cytoolz.%s has a bad docstring:\n%s\n" % (
                    key,
                    "\n".join(fulldiff),
                )
Example #22
def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
        correctly embedded the function signature as done in ``toolz``.
    """
    import toolz
    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ['excepts', 'juxt']:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message
Example #23
def get_vocab(df, phraser=None, stop=None, nlp=None, column="Text", workers=1):
    """
    Gets vocab
    :param df:
    :return:
    """
    chunksize = int(len(df) / workers)

    pool_instance = mp.Pool(processes=workers, maxtasksperchild=1)
    vocab = pool_instance.map(partial(process_vocab,
                                      phraser=phraser,
                                      stop=stop,
                                      nlp=nlp),
                              ct.partition(chunksize, df.loc[:, column].values),
                              chunksize=1)
    pool_instance.close()
    pool_instance.join()

    vocab = ct.merge_with(sum, vocab)

    return vocab
Example #24
def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
        correctly embedded the function signature as done in ``toolz``.
    """
    import toolz
    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message
Example #25
    def from_batches(cls, batches: List[Batch]):
        """Create an AbstractDataset from a list of batches."""
        return cls.from_batch(
            tz.merge_with(tz.compose(list, tz.concat), *batches))
Example #26
def deep_merge(*dicts):
    return merge_with(merge_if_dicts, *dicts)
Example #27
def merge_if_dicts(values: Sequence[Dict[Any, Any]]) -> Any:
    if all(isinstance(item, Mapping) for item in values):
        return merge_with(merge_if_dicts, *values)
    else:
        return values[-1]
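
Taken together with deep_merge from the previous example, this yields right-biased recursive merging. For example, assuming both definitions (and merge_with) are in scope:

print(deep_merge({"a": {"x": 1}}, {"a": {"y": 2}, "b": 3}))
# {'a': {'x': 1, 'y': 2}, 'b': 3}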
Example #28
def deep_merge(*dicts: Dict[Any, Any]) -> Dict[Any, Any]:
    return merge_with(merge_if_dicts, *dicts)
Example #29
    def __init__(
        self,
        *args,
        column_names: List[str] = None,
        info: DatasetInfo = None,
        split: Optional[NamedSplit] = None,
    ):

        # Data is a dictionary of lists
        self._data = {}

        # Single argument
        if len(args) == 1:
            assert column_names is None, "Don't pass in column_names."
            # The data is passed in
            data = args[0]

            # `data` is a dictionary
            if isinstance(data, dict) and len(data):
                # Assert all columns are the same length
                self._assert_columns_all_equal_length(data)
                self._data = data

            # `data` is a list
            elif isinstance(data, list) and len(data):
                # Transpose the list of dicts to a dict of lists i.e. a batch
                data = tz.merge_with(list, *data)
                # Assert all columns are the same length
                self._assert_columns_all_equal_length(data)
                self._data = data

            # `data` is a datasets.Dataset
            elif isinstance(data, datasets.Dataset):
                self._data = data[:]
                info, split = data.info, data.split

        # No argument
        elif len(args) == 0:

            # Use column_names to setup the data dictionary
            if column_names:
                self._data = {k: [] for k in column_names}

        # Setup the DatasetInfo
        info = info.copy() if info is not None else DatasetInfo()
        AbstractDataset.__init__(self, info=info, split=split)

        # Create attributes for all columns and visible columns
        self.all_columns = list(self._data.keys())
        self.visible_columns = None

        # Create attributes for visible rows
        self.visible_rows = None

        # Initialization
        self._initialize_state()

        logger.info(
            f"Created `InMemoryDataset` with {len(self)} rows and "
            f"{len(self.column_names)} columns."
        )
Example #30
    def evaluate(
        self,
        dataset: DataPanel,
        input_columns: List[str],
        output_columns: List[str],
        batch_size: int = 32,
        metrics: List[str] = None,
        coerce_fn: Callable = None,
    ):

        # TODO(karan): generalize to TF2

        # Reset the dataset format
        dataset.reset_format()
        dataset.set_format(columns=input_columns + output_columns)

        # TODO(karan): check that the DataPanel conforms to the task definition
        # TODO(karan): figure out how the output_columns will be used by the metrics
        pass

        predictions = []
        targets = []

        # Loop and apply the prediction function
        # TODO(karan): not using .map() here in order to get more fine-grained
        #  control over devices
        for idx in range(0, len(dataset), batch_size):
            # Create the batch
            batch = dataset[idx:idx + batch_size]

            # Predict on the batch
            prediction_dict = self.predict_batch(batch=batch,
                                                 input_columns=input_columns)

            # Coerce the predictions
            if coerce_fn:
                prediction_dict = coerce_fn(prediction_dict)

            # Grab the raw target key/values
            target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

            # TODO(karan): general version for non-classification problems
            # TODO(karan): move this to the right device
            if self.is_classifier:
                target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

            # TODO(karan): incremental metric computation here
            # Append the predictions and targets
            predictions.append(prediction_dict)
            targets.append(target_dict)

        # Consolidate the predictions and targets
        if self.is_classifier:
            # TODO(karan): Need to store predictions and outputs from the model
            predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"),
                                        *predictions)
            targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
        else:
            predictions = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *predictions)
            targets = tz.merge_with(
                lambda x: list(itertools.chain.from_iterable(x)), *targets)

        # Compute the metrics
        # TODO(karan): generalize this code to support metric computation for any task

        # Assumes classification, so the output_columns contains a single key for the
        # label
        if self.is_classifier:
            assert len(
                output_columns) == 1  # , "Only supports classification."
            num_classes = self.task.output_schema.features[list(
                self.task.output_schema.columns)[0]].num_classes

        labels = targets[list(targets.keys())[0]]

        if metrics is None:
            if self.task is None:
                raise ValueError(
                    "Must specify metrics if model not associated with task")
            metrics = self.task.metrics

        pred = predictions["pred"].to(self.device)
        target = labels.to(self.device)

        evaluation_dict = {
            metric: compute_metric(metric, pred, target, num_classes)
            for metric in metrics
        }

        # Reset the data format
        dataset.reset_format()

        return evaluation_dict
Example #31
def merge_with(func, d, *dicts, **kwargs):
    return cytoolz.merge_with(func, d, *dicts, **kwargs)
Example #32
def merge(left, right):
    return cytoolz.merge_with(sum, left, right)
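
A quick check of this pairwise sum-merge, assuming the definition above:

print(merge({"a": 1, "b": 2}, {"a": 3}))  # {'a': 4, 'b': 2}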
Example #33
def merge_if_dicts(values):
    if all(isinstance(item, Mapping) for item in values):
        return merge_with(merge_if_dicts, *values)
    else:
        return values[-1]