def convert(sig: Type[GeneSignature], db: RegionRankingDatabase, delineation: Delineation, fraction: float = 0.80) -> Type[GeneSignature]:
    """
    Convert a signature of gene symbols to a signature of region identifiers.

    :param sig: The signature of gene symbols.
    :param db: The region database.
    :param delineation: The regulatory region delineation for genes.
    :param fraction: The fraction of overlap to take into account.
    :return: The signature of region identifiers.
    """
    assert sig
    assert db
    assert delineation

    # Every gene is transformed into a dictionary that maps the name of a feature to the
    # weight of the corresponding gene. These mappings are then combined, taking the
    # maximum when multiple values exist for a key.
    identifier2weight = merge_with(max, (dict(
        zip(
            map(attrgetter('name'),
                db.regions.intersection(load(delineation).get(gene), fraction=fraction)),
            repeat(weight)))
        for gene, weight in sig.gene2weight.items()))

    #region_identifiers = frozenset(
    #    map(attrgetter('name'),
    #        mapcat(list,
    #               map(partial(db.regions.intersection, fraction=fraction),
    #                   map(load(delineation).get, sig.genes)))))

    return sig.copy(gene2weight=identifier2weight, nomenclature=REGION_NOMENCLATURE)

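# A minimal, standalone sketch of the merge_with(max, ...) behaviour the function above
# relies on: values that share a key are collected into a list and reduced with the
# supplied function (here, max). Assumes only that toolz is installed; the region names
# and weights are made up for illustration.
from toolz import merge_with

region_weights = merge_with(max, {"chr1:100-200": 0.4}, {"chr1:100-200": 0.9, "chr2:50-80": 0.1})
assert region_weights == {"chr1:100-200": 0.9, "chr2:50-80": 0.1}
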
def merge_candidates(self, output_files, threshold):

    candidates = []
    print("Merging " + str(len(output_files)) + " files.")

    # Load
    for dict_file in output_files:
        try:
            candidates.append(self.Loader.load_file(dict_file))
        except Exception as e:
            print("ERROR")
            print(e)

    # Merge
    candidates = ct.merge_with(sum, [x for x in candidates])
    print("\tTOTAL CANDIDATES BEFORE PRUNING: " + str(len(list(candidates.keys()))))

    # Prune
    above_threshold = lambda x: x > threshold
    candidates = ct.valfilter(above_threshold, candidates)
    print("\tTOTAL CANDIDATES AFTER PRUNING: " + str(len(list(candidates.keys()))))

    return candidates

def recmerge(*objs, merge_sequences=False):
    """Recursively merge an arbitrary number of collections.

    For conflicting values, later collections to the right are given priority.
    By default (merge_sequences=False), sequences are treated as a normal value
    and not merged.

    Args:
        *objs: collections to merge
        merge_sequences: whether to merge values that are sequences

    Returns:
        merged collection
    """
    if isinstance(objs, tuple) and len(objs) == 1:
        # A squeeze operation, since merge_with passes the grouped values as a
        # single list, i.e. objs == (list_of_objs,)
        objs = objs[0]
    if all([isinstance(obj, Mapping) for obj in objs]):
        # Merges all the collections, recursively applying merging to the combined values
        return tz.merge_with(
            partial(recmerge, merge_sequences=merge_sequences), *objs)
    elif all([isinstance(obj, Sequence) for obj in objs]) and merge_sequences:
        # Merges sequence values by concatenation
        return list(tz.concat(objs))
    else:
        # If objs does not contain mappings, simply pick the last one
        return tz.last(objs)

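# Usage sketch for the recmerge above (assumes recmerge and its imports are in scope).
# Nested mappings merge key-by-key, conflicting scalars resolve to the right-most value,
# and sequences are only concatenated when merge_sequences=True. The config dicts are
# illustrative only.
merged = recmerge(
    {"model": {"lr": 0.1, "layers": [64, 64]}},
    {"model": {"lr": 0.01}},
)
assert merged == {"model": {"lr": 0.01, "layers": [64, 64]}}
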
def batchify_data_records(data_records: List[api.DataRecord]) -> api.BatchDataRecords:
    """Stack a list of DataRecord into BatchDataRecords.

    This converts a list of tuples of dicts {str: float/array} into a tuple of
    dicts {str: array}, where each float/array is stacked along the first
    dimension. See Example.

    Args:
        data_records: list[DataRecord], list of individual data records.

    Returns:
        BatchDataRecords, batch data records.

    Example:
    ::

        data_record_1 = ({"input_1": 1, "input_2": 2}, {"output_1": 3})
        data_record_2 = ({"input_1": 2, "input_2": 4}, {"output_1": 6})
        batch_data_records = (
            {"input_1": arr([1, 2]), "input_2": arr([2, 4])},
            {"output_1": arr([3, 6])},
        )
    """
    batch_data_records = tuple(
        cytoolz.merge_with(np.array, ii) for ii in zip(*data_records)
    )
    return batch_data_records  # type: ignore

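# Standalone sketch of the stacking idea behind batchify_data_records, with the
# api.DataRecord type mocked as a plain tuple of dicts. Assumes only numpy and
# cytoolz are installed.
import cytoolz
import numpy as np

record_1 = ({"input_1": 1, "input_2": 2}, {"output_1": 3})
record_2 = ({"input_1": 2, "input_2": 4}, {"output_1": 6})

batch = tuple(cytoolz.merge_with(np.array, ii) for ii in zip(record_1, record_2))
assert batch[0]["input_1"].tolist() == [1, 2]
assert batch[1]["output_1"].tolist() == [3, 6]
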
def interleave(cls, datasets: List[Dataset], identifier: Identifier) -> Dataset:
    """Interleave a list of datasets."""
    return cls.from_batch(
        tz.merge_with(tz.interleave, *[dataset[:] for dataset in datasets]),
        identifier=identifier,
    )

def from_batches(cls, batches: Sequence[Batch], identifier: Identifier = None) -> Dataset:
    """Convert a list of batches to a dataset."""
    return cls.from_batch(
        tz.merge_with(tz.concat, *batches),
        identifier=identifier,
    )

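# Minimal sketch of merging column batches with tz.merge_with(tz.concat, ...) as above.
# tz.concat yields a lazy iterator per column, which is why other variants in this
# collection wrap it as tz.compose(list, tz.concat). Assumes toolz is installed; the
# batches are made up.
import toolz as tz

b1 = {"text": ["a", "b"], "label": [0, 1]}
b2 = {"text": ["c"], "label": [1]}

merged = tz.valmap(list, tz.merge_with(tz.concat, b1, b2))
assert merged == {"text": ["a", "b", "c"], "label": [0, 1, 1]}
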
def concatenate_tweets(tweets, dictionary, characterizer, step=10000):
    if not tweets.count():
        return []

    all_bows = []

    # renamed the chunk variable so it no longer shadows the `tweets` queryset
    for tweet_chunk in partition_all(step, queryset_iterator(tweets.only('text'), chunksize=step)):
        bows = []
        for tweet in tweet_chunk:
            keywords = characterizer.tokenize(tweet.text)
            bow = dictionary.doc2bow(set(keywords), allow_update=False)
            bows.append(dict(bow))
        all_bows.append(merge_with(sum, *bows))

    return list(merge_with(sum, all_bows).items())

def merge_ngrams(self, files=None, n_gram_threshold=1):

    all_ngrams = []

    # Get a list of ngram files
    if files is None:
        files = self.Loader.list_output(type="ngrams")

    # Break into lists of 20 files
    file_list = ct.partition_all(20, files)

    for files in file_list:

        ngrams = []  # Initialize holding list

        # Load
        for dict_file in files:
            try:
                ngrams.append(self.Loader.load_file(dict_file))
            except Exception:
                print("Not loading " + str(dict_file))

        # Merge
        ngrams = ct.merge_with(sum, [x for x in ngrams])

        print("\tSUB-TOTAL NGRAMS: " + str(len(list(ngrams.keys()))))
        print("\tSUB-TOTAL WORDS: " + str(ngrams["TOTAL"]))
        print("\n")

        all_ngrams.append(ngrams)

    # Now merge everything
    all_ngrams = ct.merge_with(sum, [x for x in all_ngrams])

    print("\tTOTAL NGRAMS: " + str(len(list(all_ngrams.keys()))))
    print("\tTOTAL WORDS: " + str(all_ngrams["TOTAL"]))

    # Now enforce threshold
    keepable = lambda x: x > n_gram_threshold
    all_ngrams = ct.valfilter(keepable, all_ngrams)

    print("\tAfter pruning:")
    print("\tTOTAL NGRAMS: " + str(len(list(all_ngrams.keys()))))

    return all_ngrams

def test_curried_namespace():
    def should_curry(func):
        if not callable(func) or isinstance(func, curry):
            return False
        nargs = enhanced_num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        else:
            return nargs == 1 and enhanced_has_keywords(func)

    def curry_namespace(ns):
        return dict((
            name,
            curry(f) if should_curry(f) else f,
        ) for name, f in ns.items() if '__' not in name)

    all_auto_curried = curry_namespace(vars(eth_utils))
    inferred_namespace = valfilter(callable, all_auto_curried)
    curried_namespace = valfilter(callable, eth_utils.curried.__dict__)

    if inferred_namespace != curried_namespace:
        missing = set(inferred_namespace) - set(curried_namespace)
        if missing:
            to_insert = sorted("%s," % f for f in missing)
            raise AssertionError(
                'There are missing functions in eth_utils.curried:\n' +
                '\n'.join(to_insert))
        extra = set(curried_namespace) - set(inferred_namespace)
        if extra:
            raise AssertionError(
                'There are extra functions in eth_utils.curried:\n' +
                '\n'.join(sorted(extra)))
        unequal = merge_with(list, inferred_namespace, curried_namespace)
        unequal = valfilter(lambda x: x[0] != x[1], unequal)
        to_curry = keyfilter(lambda x: should_curry(getattr(eth_utils, x)), unequal)
        if to_curry:
            to_curry_formatted = sorted('{0} = curry({0})'.format(f) for f in to_curry)
            raise AssertionError(
                'There are missing functions to curry in eth_utils.curried:\n' +
                '\n'.join(to_curry_formatted))
        elif unequal:
            not_to_curry_formatted = sorted(unequal)
            raise AssertionError(
                'Missing functions NOT to curry in eth_utils.curried:\n' +
                '\n'.join(not_to_curry_formatted))
        else:
            raise AssertionError("unexplained difference between %r and %r" % (
                inferred_namespace, curried_namespace,
            ))

def getFuncsIndex(self):
    for elt in ast.walk(self.node):
        if isinstance(elt, ast.Name):
            indice = self.getIndice(elt)
            if hasattr(elt, "func"):
                self.funcs = merge_with(lambda x: reduce(add, x, []),
                                        self.funcs, {elt.func: [indice]})
    for key in self.funcs.keys():
        self.funcsIndex[key] = itemmap(
            lambda kv, key=key: (kv[0], (int(kv[0] in self.funcs[key]), )),
            self.index)

def union(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
    """
    Creates a new :class:`GeneSignature` instance which is the union of this signature
    and the other supplied signature.

    The weight associated with the genes in the intersection is the maximum of the weights
    in the composing signatures.

    :param other: The other :class:`GeneSignature`.
    :return: the new :class:`GeneSignature` instance.
    """
    return self.copy(
        name="({} | {})".format(self.name, other.name) if self.name != other.name else self.name,
        gene2weight=frozendict(merge_with(max, self.gene2weight, other.gene2weight)))

def chain(
    cls,
    datasets: List[Dataset],
    identifier: Identifier,
) -> Dataset:
    """Chain a list of datasets."""
    return cls.from_batch(
        tz.merge_with(
            tz.compose(list, tz.concat),
            *[dataset[:] for dataset in datasets],
        ),
        identifier=identifier,
    )

def test_class_sigs():
    """Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(
        lambda x: not isinstance(x, BuiltinFunctionType), cytoolz_dict
    )

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ["excepts", "juxt", "memoize", "flip"]:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        # For Cython < 0.25
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        doc = cytoolz_func.__doc__

        # For Cython >= 0.25
        toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec(
            *toolz_spec,
            **{"formatvalue": lambda x: "=" + getattr(x, "__name__", repr(x))}
        )
        doc_alt = doc.replace("Py_ssize_t ", "")

        if not (toolz_sig in doc or toolz_sig_alt in doc_alt):
            message = (
                "cytoolz.%s does not have correct function signature."
                "\n\nExpected: %s"
                "\n\nDocstring in cytoolz is:\n%s"
                % (key, toolz_sig, cytoolz_func.__doc__)
            )
            assert False, message

def intersection(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
    """
    Creates a new :class:`GeneSignature` instance which is the intersection of this
    signature and the supplied other signature.

    The weight associated with the genes in the intersection is the maximum of the weights
    in the composing signatures.

    :param other: The other :class:`GeneSignature`.
    :return: the new :class:`GeneSignature` instance.
    """
    genes = set(self.gene2weight.keys()).intersection(set(other.gene2weight.keys()))
    return self.copy(
        name="({} & {})".format(self.name, other.name) if self.name != other.name else self.name,
        gene2weight=frozendict(keyfilter(lambda k: k in genes,
                                         merge_with(max, self.gene2weight, other.gene2weight))))

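# Standalone sketch of the weight-combination logic in the union/intersection methods
# above, using plain dicts instead of GeneSignature objects. Gene names and weights are
# made up; assumes toolz is installed.
from toolz import keyfilter, merge_with

a = {"TP53": 0.8, "MYC": 0.4}
b = {"MYC": 0.9, "EGFR": 0.2}

union_weights = merge_with(max, a, b)  # union keeps the max weight per gene
shared = set(a) & set(b)
intersection_weights = keyfilter(lambda k: k in shared, union_weights)

assert union_weights == {"TP53": 0.8, "MYC": 0.9, "EGFR": 0.2}
assert intersection_weights == {"MYC": 0.9}
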
def test_curried_namespace():
    exceptions = import_module("cytoolz.curried.exceptions")
    namespace = {}

    def should_curry(func):
        if not callable(func) or isinstance(func, cytoolz.curry):
            return False
        nargs = cytoolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and cytoolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        return {
            name: cytoolz.curry(f) if should_curry(f) else f
            for name, f in ns.items()
            if "__" not in name
        }

    from_cytoolz = curry_namespace(vars(cytoolz))
    from_exceptions = curry_namespace(vars(exceptions))
    namespace.update(cytoolz.merge(from_cytoolz, from_exceptions))

    namespace = cytoolz.valfilter(callable, namespace)
    curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)

    if namespace != curried_namespace:
        missing = set(namespace) - set(curried_namespace)
        if missing:
            raise AssertionError(
                "There are missing functions in cytoolz.curried:\n %s"
                % " \n".join(sorted(missing)))
        extra = set(curried_namespace) - set(namespace)
        if extra:
            raise AssertionError(
                "There are extra functions in cytoolz.curried:\n %s"
                % " \n".join(sorted(extra)))
        unequal = cytoolz.merge_with(list, namespace, curried_namespace)
        unequal = cytoolz.valfilter(lambda x: x[0] != x[1], unequal)
        messages = []
        for name, (orig_func, auto_func) in sorted(unequal.items()):
            if name in from_exceptions:
                messages.append(
                    "%s should come from cytoolz.curried.exceptions" % name)
            elif should_curry(getattr(cytoolz, name)):
                messages.append("%s should be curried from cytoolz" % name)
            else:
                messages.append(
                    "%s should come from cytoolz and NOT be curried" % name)
        raise AssertionError("\n".join(messages))

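# Sketch of the "pair up and compare" trick used in the test above: merge_with(list)
# lines the two namespaces' values up per key, so keys whose paired values differ can be
# filtered out with valfilter. The toy namespaces are illustrative only; assumes cytoolz
# is installed.
from cytoolz import merge_with, valfilter

inferred = {"f": "curried_f", "g": "g"}
actual = {"f": "curried_f", "g": "curried_g"}

unequal = valfilter(lambda pair: pair[0] != pair[1], merge_with(list, inferred, actual))
assert unequal == {"g": ["g", "curried_g"]}
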
def from_batches(
    cls,
    batches: Sequence[Batch],
    identifier: Identifier = None,
    dataset_fmt: str = "in_memory",
) -> Dataset:
    """Convert a list of batches to a dataset."""
    return cls.from_batch(
        tz.merge_with(
            tz.compose(list, tz.concat),
            *batches,
        ),
        identifier=identifier,
        dataset_fmt=dataset_fmt,
    )

def __iter__(self):
    self.order = []
    user_ids = list(set(self.tweets.values_list(self.key, flat=True)))

    for user_pks in partition_all(self.step, user_ids):
        queryset = self.tweets.filter(user_id__in=user_pks)

        for key_id, tweet_set in groupby(queryset.only('text', self.key).order_by(self.key),
                                         key=self.key_func):
            bows = []
            for tweet in tweet_set:
                keywords = self.characterizer.tokenize(tweet.text)
                bow = self.dictionary.doc2bow(set(keywords), allow_update=False)
                bows.append(dict(bow))

            self.order.append(key_id)
            yield list(merge_with(sum, *bows).items())

def collect(self):
    from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily

    yield GaugeMetricFamily(
        "dask_scheduler_clients",
        "Number of clients connected.",
        value=len(self.server.clients),
    )

    yield GaugeMetricFamily(
        "dask_scheduler_desired_workers",
        "Number of workers scheduler needs for task graph.",
        value=self.server.adaptive_target(),
    )

    worker_states = GaugeMetricFamily(
        "dask_scheduler_workers",
        "Number of workers known by scheduler.",
        labels=["state"],
    )
    worker_states.add_metric(["connected"], len(self.server.workers))
    worker_states.add_metric(["saturated"], len(self.server.saturated))
    worker_states.add_metric(["idle"], len(self.server.idle))
    yield worker_states

    tasks = GaugeMetricFamily(
        "dask_scheduler_tasks",
        "Number of tasks known by scheduler.",
        labels=["state"],
    )

    task_counter = merge_with(
        sum, (tp.states for tp in self.server.task_prefixes.values()))

    yield CounterMetricFamily(
        "dask_scheduler_tasks_forgotten",
        ("Total number of processed tasks no longer in memory and already "
         "removed from the scheduler job queue. Note task groups on the "
         "scheduler which have all tasks in the forgotten state are not included."),
        value=task_counter.get("forgotten", 0.0),
    )

    for state in ALL_TASK_STATES:
        tasks.add_metric([state], task_counter.get(state, 0.0))
    yield tasks

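# Sketch of the task-state aggregation used above: each task prefix exposes a dict of
# per-state counts, and merge_with(sum, ...) folds them into a single counter. Assumes
# toolz is installed; the prefix dicts are mocked.
from toolz import merge_with

prefix_states = [
    {"memory": 3, "processing": 1},
    {"memory": 2, "released": 4},
]
task_counter = merge_with(sum, prefix_states)
assert task_counter == {"memory": 5, "processing": 1, "released": 4}
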
def test_curried_namespace():
    exceptions = import_module('cytoolz.curried.exceptions')
    namespace = {}

    def should_curry(func):
        if not callable(func) or isinstance(func, cytoolz.curry):
            return False
        nargs = cytoolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and cytoolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        return dict(
            (name, cytoolz.curry(f) if should_curry(f) else f)
            for name, f in ns.items() if '__' not in name
        )

    from_cytoolz = curry_namespace(vars(cytoolz))
    from_exceptions = curry_namespace(vars(exceptions))
    namespace.update(cytoolz.merge(from_cytoolz, from_exceptions))

    namespace = cytoolz.valfilter(callable, namespace)
    curried_namespace = cytoolz.valfilter(callable, cytoolz.curried.__dict__)

    if namespace != curried_namespace:
        missing = set(namespace) - set(curried_namespace)
        if missing:
            raise AssertionError('There are missing functions in cytoolz.curried:\n %s'
                                 % ' \n'.join(sorted(missing)))
        extra = set(curried_namespace) - set(namespace)
        if extra:
            raise AssertionError('There are extra functions in cytoolz.curried:\n %s'
                                 % ' \n'.join(sorted(extra)))
        unequal = cytoolz.merge_with(list, namespace, curried_namespace)
        unequal = cytoolz.valfilter(lambda x: x[0] != x[1], unequal)
        messages = []
        for name, (orig_func, auto_func) in sorted(unequal.items()):
            if name in from_exceptions:
                messages.append('%s should come from cytoolz.curried.exceptions' % name)
            elif should_curry(getattr(cytoolz, name)):
                messages.append('%s should be curried from cytoolz' % name)
            else:
                messages.append('%s should come from cytoolz and NOT be curried' % name)
        raise AssertionError('\n'.join(messages))

def create_slice(args):
    # Unpack args
    dataset, slice_membership, slice_batches, i, batch_size, slice_cache_hash = args

    # Create a new empty slice
    sl = Slice.from_dict({})

    # Create a Slice "copy" of the Dataset
    sl.__dict__.update(dataset.__dict__)
    sl._identifier = None

    # Filter
    sl = sl.filter(
        lambda example, idx: bool(slice_membership[idx, i]),
        with_indices=True,
        input_columns=["index"],
        batch_size=batch_size,
        cache_file_name=str(
            dataset.logdir / ("cache-" + str(abs(slice_cache_hash)) + "-filter.arrow")
        ),
    )

    slice_batch = tz.merge_with(tz.compose(list, tz.concat), slice_batches)

    # FIXME(karan): interaction tape history is wrong here, esp. with augmentation/attacks

    # Map
    if len(sl):
        sl = sl.map(
            lambda batch, indices: tz.valmap(
                lambda v: v[indices[0] : indices[0] + batch_size], slice_batch
            ),
            batched=True,
            batch_size=batch_size,
            with_indices=True,
            remove_columns=sl.column_names,
            cache_file_name=str(
                dataset.logdir / ("cache-" + str(abs(slice_cache_hash)) + ".arrow")
            ),
        )

    return sl

def test_docstrings_uptodate():
    import toolz

    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, "__doc__", ""), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(
                differ.compare(toolz_doc.splitlines(), cytoolz_doc.splitlines()))
            fulldiff = list(diff)

            # remove additional lines at the beginning
            while diff and diff[0].startswith("+"):
                diff.pop(0)
            # remove additional lines at the end
            while diff and diff[-1].startswith("+"):
                diff.pop()

            def checkbad(line):
                return line.startswith("+") and not (
                    "# doctest: +SKIP" in line and key in skipped_doctests)

            if any(map(checkbad, diff)):
                assert False, "Error: cytoolz.%s has a bad docstring:\n%s\n" % (
                    key,
                    "\n".join(fulldiff),
                )

def test_class_sigs():
    """
    Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ['excepts', 'juxt']:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message

def get_vocab(df, phraser=None, stop=None, nlp=None, column="Text", workers=1):
    """
    Builds the vocabulary counts for a dataframe column, processing chunks in parallel.

    :param df: dataframe containing the text to process
    :param phraser: optional phrase model, passed through to process_vocab
    :param stop: optional stop-word list, passed through to process_vocab
    :param nlp: optional NLP pipeline, passed through to process_vocab
    :param column: name of the text column to read
    :param workers: number of worker processes
    :return: dict mapping each vocabulary item to its total count
    """
    chunksize = int(len(df) / workers)

    pool_instance = mp.Pool(processes=workers, maxtasksperchild=1)
    vocab = pool_instance.map(
        partial(process_vocab, phraser=phraser, stop=stop, nlp=nlp),
        ct.partition(chunksize, df.loc[:, column].values),
        chunksize=1,
    )
    pool_instance.close()
    pool_instance.join()

    vocab = ct.merge_with(sum, vocab)

    return vocab

def test_class_sigs():
    """
    Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)
        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message

def from_batches(cls, batches: List[Batch]):
    """Create an AbstractDataset from a list of batches."""
    return cls.from_batch(
        tz.merge_with(tz.compose(list, tz.concat), *batches),
    )

def deep_merge(*dicts):
    return merge_with(merge_if_dicts, *dicts)

def merge_if_dicts(values: Sequence[Dict[Any, Any]]) -> Any:
    if all(isinstance(item, Mapping) for item in values):
        return merge_with(merge_if_dicts, *values)
    else:
        return values[-1]

def deep_merge(*dicts: Dict[Any, Any]) -> Dict[Any, Any]:
    return merge_with(merge_if_dicts, *dicts)

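# Usage sketch for the deep_merge/merge_if_dicts pair above (assumes both functions are
# in scope). Nested dicts merge recursively; non-dict conflicts keep the right-most
# value. The config dicts are made up.
merged = deep_merge(
    {"db": {"host": "localhost", "port": 5432}},
    {"db": {"port": 5433}, "debug": True},
)
assert merged == {"db": {"host": "localhost", "port": 5433}, "debug": True}
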
def __init__(
    self,
    *args,
    column_names: List[str] = None,
    info: DatasetInfo = None,
    split: Optional[NamedSplit] = None,
):
    # Data is a dictionary of lists
    self._data = {}

    # Single argument
    if len(args) == 1:
        assert column_names is None, "Don't pass in column_names."
        # The data is passed in
        data = args[0]

        # `data` is a dictionary
        if isinstance(data, dict) and len(data):
            # Assert all columns are the same length
            self._assert_columns_all_equal_length(data)
            self._data = data

        # `data` is a list
        elif isinstance(data, list) and len(data):
            # Transpose the list of dicts to a dict of lists i.e. a batch
            data = tz.merge_with(list, *data)
            # Assert all columns are the same length
            self._assert_columns_all_equal_length(data)
            self._data = data

        # `data` is a datasets.Dataset
        elif isinstance(data, datasets.Dataset):
            self._data = data[:]
            info, split = data.info, data.split

    # No argument
    elif len(args) == 0:
        # Use column_names to setup the data dictionary
        if column_names:
            self._data = {k: [] for k in column_names}

    # Setup the DatasetInfo
    info = info.copy() if info is not None else DatasetInfo()
    AbstractDataset.__init__(self, info=info, split=split)

    # Create attributes for all columns and visible columns
    self.all_columns = list(self._data.keys())
    self.visible_columns = None

    # Create attributes for visible rows
    self.visible_rows = None

    # Initialization
    self._initialize_state()

    logger.info(
        f"Created `InMemoryDataset` with {len(self)} rows and "
        f"{len(self.column_names)} columns."
    )

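# Standalone sketch of the list-of-dicts -> dict-of-lists transpose performed by
# tz.merge_with(list, *data) in the constructor above. Assumes toolz is installed;
# the rows are illustrative.
import toolz as tz

rows = [{"text": "a", "label": 0}, {"text": "b", "label": 1}]
batch = tz.merge_with(list, *rows)
assert batch == {"text": ["a", "b"], "label": [0, 1]}
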
def evaluate(
    self,
    dataset: DataPanel,
    input_columns: List[str],
    output_columns: List[str],
    batch_size: int = 32,
    metrics: List[str] = None,
    coerce_fn: Callable = None,
):
    # TODO(karan): generalize to TF2

    # Reset the dataset format
    dataset.reset_format()
    dataset.set_format(columns=input_columns + output_columns)

    # TODO(karan): check that the DataPanel conforms to the task definition
    # TODO(karan): figure out how the output_columns will be used by the metrics
    pass

    predictions = []
    targets = []

    # Loop and apply the prediction function
    # TODO(karan): not using .map() here in order to get more fine-grained
    #  control over devices
    for idx in range(0, len(dataset), batch_size):
        # Create the batch
        batch = dataset[idx:idx + batch_size]

        # Predict on the batch
        prediction_dict = self.predict_batch(batch=batch, input_columns=input_columns)

        # Coerce the predictions
        if coerce_fn:
            prediction_dict = coerce_fn(prediction_dict)

        # Grab the raw target key/values
        target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

        # TODO(karan): general version for non-classification problems
        # TODO(karan): move this to the right device
        if self.is_classifier:
            target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

        # TODO(karan): incremental metric computation here
        # Append the predictions and targets
        predictions.append(prediction_dict)
        targets.append(target_dict)

    # Consolidate the predictions and targets
    if self.is_classifier:
        # TODO(karan): Need to store predictions and outputs from the model
        predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *predictions)
        targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
    else:
        predictions = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *predictions)
        targets = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *targets)

    # Compute the metrics
    # TODO(karan): generalize this code to support metric computation for any task
    # Assumes classification, so the output_columns contains a single key for the label
    if self.is_classifier:
        assert len(output_columns) == 1  # , "Only supports classification."
        num_classes = self.task.output_schema.features[
            list(self.task.output_schema.columns)[0]
        ].num_classes

    labels = targets[list(targets.keys())[0]]

    if metrics is None:
        if self.task is None:
            raise ValueError("Must specify metrics if model not associated with task")
        metrics = self.task.metrics

    pred = predictions["pred"].to(self.device)
    target = labels.to(self.device)

    evaluation_dict = {
        metric: compute_metric(metric, pred, target, num_classes)
        for metric in metrics
    }

    # Reset the data format
    dataset.reset_format()

    return evaluation_dict

def merge_with(func, d, *dicts, **kwargs):
    return cytoolz.merge_with(func, d, *dicts, **kwargs)

def merge(left, right):
    return cytoolz.merge_with(sum, left, right)

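# Sketch of the counter-style merge above (assumes the merge function is in scope):
# values of shared keys are summed, all other keys pass through unchanged.
assert merge({"a": 1, "b": 2}, {"b": 3, "c": 4}) == {"a": 1, "b": 5, "c": 4}
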
def merge_if_dicts(values):
    if all(isinstance(item, Mapping) for item in values):
        return merge_with(merge_if_dicts, *values)
    else:
        return values[-1]