def test_class_sigs(): """Test that all ``cdef class`` extension types in ``cytoolz`` have correctly embedded the function signature as done in ``toolz``. """ import toolz # only consider items created in both `toolz` and `cytoolz` toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__) cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__) # only test `cdef class` extensions from `cytoolz` cytoolz_dict = valfilter( lambda x: not isinstance(x, BuiltinFunctionType), cytoolz_dict ) # full API coverage should be tested elsewhere toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict) cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict) d = merge_with(identity, toolz_dict, cytoolz_dict) for key, (toolz_func, cytoolz_func) in d.items(): if key in ["excepts", "juxt", "memoize", "flip"]: continue try: # function toolz_spec = inspect.getargspec(toolz_func) except TypeError: try: # curried or partial object toolz_spec = inspect.getargspec(toolz_func.func) except (TypeError, AttributeError): # class toolz_spec = inspect.getargspec(toolz_func.__init__) # For Cython < 0.25 toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec) doc = cytoolz_func.__doc__ # For Cython >= 0.25 toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec( *toolz_spec, **{"formatvalue": lambda x: "=" + getattr(x, "__name__", repr(x))} ) doc_alt = doc.replace("Py_ssize_t ", "") if not (toolz_sig in doc or toolz_sig_alt in doc_alt): message = ( "cytoolz.%s does not have correct function signature." "\n\nExpected: %s" "\n\nDocstring in cytoolz is:\n%s" % (key, toolz_sig, cytoolz_func.__doc__) ) assert False, message
def test_sig_at_beginning():
    """Test that the function signature is at the beginning of the docstring
    and is followed by exactly one blank line.
    """
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)
    cytoolz_dict = keyfilter(lambda x: x not in skip_sigs, cytoolz_dict)

    for key, val in cytoolz_dict.items():
        doclines = val.__doc__.splitlines()
        assert len(doclines) > 2, "cytoolz.%s docstring too short:\n\n%s" % (
            key,
            val.__doc__,
        )

        sig = "%s(" % aliases.get(key, key)
        assert sig in doclines[0], (
            "cytoolz.%s docstring missing signature at beginning:\n\n%s"
            % (key, val.__doc__)
        )

        assert not doclines[1], (
            "cytoolz.%s docstring missing blank line after signature:\n\n%s"
            % (key, val.__doc__)
        )

        assert doclines[2], (
            "cytoolz.%s docstring too many blank lines after signature:\n\n%s"
            % (key, val.__doc__)
        )

def wordclouds(self, df, stage=0, features="frequency", name=None, stopwords=None):
    """
    Build wordclouds and choose what features to use

    :param df:
    :param stage:
    :param features:
    :param name:
    :param stopwords:
    :return:
    """
    # If only using frequency, use a pure Python method
    if features == "frequency":
        vocab = self._get_wordcloud_frequency_vocab(df, stage)
    # If using TF-IDF, use pre-fit vectorizer
    else:  # features == "tfidf"
        vocab = self._get_wordcloud_tfidf_vocab(df)

    # Remove defined stopwords
    if stopwords is not None:
        vocab = ct.keyfilter(lambda x: x not in stopwords, vocab)

    # Pass pre-made frequencies to wordcloud, allowing for TF-IDF
    self.wordcloud.generate_from_frequencies(frequencies=vocab)

    # Prepare plot with title, etc.
    self._plot_wordcloud(stage, name)

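The stopword-removal step above is just a key-based filter over a frequency mapping. A minimal standalone sketch of that step (the vocabulary and stopword values here are made up for illustration):

from toolz import keyfilter

vocab = {"data": 14, "the": 92, "cloud": 7, "and": 55}
stopwords = {"the", "and"}
# keyfilter keeps only the entries whose key passes the predicate.
filtered = keyfilter(lambda word: word not in stopwords, vocab)
assert filtered == {"data": 14, "cloud": 7}
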
def head(self, n: int = 5) -> Type['GeneSignature']:
    """
    Returns a gene signature with only the top n targets.
    """
    assert n >= 1, "n must be greater than or equal to one."
    genes = self.genes[0:n]  # Genes are sorted in ascending order according to weight.
    return self.copy(gene2weight=keyfilter(lambda k: k in genes, self.gene2weight))

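The keyfilter call above subsets gene2weight to the selected genes. A small self-contained sketch of the same pattern, with hypothetical gene names and weights:

from toolz import keyfilter

gene2weight = {"TP53": 0.9, "MYC": 0.7, "EGFR": 0.4}
genes = ["TP53", "MYC"]  # stand-in for self.genes[0:n]
top = keyfilter(lambda k: k in genes, gene2weight)
assert top == {"TP53": 0.9, "MYC": 0.7}
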
def test_class_sigs(): """ Test that all ``cdef class`` extension types in ``cytoolz`` have correctly embedded the function signature as done in ``toolz``. """ import toolz # only consider items created in both `toolz` and `cytoolz` toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__) cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__) # only test `cdef class` extensions from `cytoolz` cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType), cytoolz_dict) # full API coverage should be tested elsewhere toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict) cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict) d = merge_with(identity, toolz_dict, cytoolz_dict) for key, (toolz_func, cytoolz_func) in d.items(): if key in ['excepts', 'juxt', 'memoize', 'flip']: continue try: # function toolz_spec = inspect.getargspec(toolz_func) except TypeError: try: # curried or partial object toolz_spec = inspect.getargspec(toolz_func.func) except (TypeError, AttributeError): # class toolz_spec = inspect.getargspec(toolz_func.__init__) # For Cython < 0.25 toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec) doc = cytoolz_func.__doc__ # For Cython >= 0.25 toolz_sig_alt = toolz_func.__name__ + inspect.formatargspec( *toolz_spec, **{'formatvalue': lambda x: '=' + getattr(x, '__name__', repr(x))} ) doc_alt = doc.replace('Py_ssize_t ', '') if not (toolz_sig in doc or toolz_sig_alt in doc_alt): message = ('cytoolz.%s does not have correct function signature.' '\n\nExpected: %s' '\n\nDocstring in cytoolz is:\n%s' % (key, toolz_sig, cytoolz_func.__doc__)) assert False, message
def clean_dialect(dialect):
    """ Make a csv dialect appropriate for pandas.read_csv """
    dialect = keyfilter(read_csv_kwargs.__contains__, dialect)
    # handle windows
    if dialect['lineterminator'] == '\r\n':
        dialect['lineterminator'] = None
    return dialect

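Note the predicate here is the bound method read_csv_kwargs.__contains__, which keyfilter calls once per key. A runnable sketch using a stand-in for the real read_csv_kwargs set:

from toolz import keyfilter

read_csv_kwargs = {"sep", "lineterminator", "quotechar"}  # stand-in; the real set is larger
dialect = {"sep": ",", "lineterminator": "\r\n", "doublequote": True}
# Equivalent to keyfilter(lambda k: k in read_csv_kwargs, dialect)
cleaned = keyfilter(read_csv_kwargs.__contains__, dialect)
assert cleaned == {"sep": ",", "lineterminator": "\r\n"}
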
def test_docstrings_uptodate():
    import toolz

    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod("toolz"), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod("cytoolz"), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, "__doc__", ""), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(
                differ.compare(toolz_doc.splitlines(), cytoolz_doc.splitlines())
            )
            fulldiff = list(diff)

            # remove additional lines at the beginning
            while diff and diff[0].startswith("+"):
                diff.pop(0)

            # remove additional lines at the end
            while diff and diff[-1].startswith("+"):
                diff.pop()

            def checkbad(line):
                return line.startswith("+") and not (
                    "# doctest: +SKIP" in line and key in skipped_doctests
                )

            if any(map(checkbad, diff)):
                assert False, "Error: cytoolz.%s has a bad docstring:\n%s\n" % (
                    key,
                    "\n".join(fulldiff),
                )

def clean_id_strs(tweet):
    if isinstance(tweet, dict):
        # Drop every key ending in 'id_str', then recurse into the values.
        tweet = keyfilter(lambda x: not x.endswith('id_str'), tweet)
        for key in tweet.keys():
            tweet[key] = clean_id_strs(tweet[key])
    elif isinstance(tweet, list):
        tweet = list(map(clean_id_strs, tweet))
    return tweet

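One level of the recursion above, shown on a made-up tweet fragment; keyfilter drops each offending key before the function would recurse into the remaining values:

from toolz import keyfilter

tweet = {"id": 1, "id_str": "1", "user": {"id": 2, "id_str": "2"}}
top = keyfilter(lambda k: not k.endswith("id_str"), tweet)
assert set(top) == {"id", "user"}  # the full clean_id_strs then recurses into "user"
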
def test_curried_namespace():
    def should_curry(func):
        if not callable(func) or isinstance(func, curry):
            return False
        nargs = enhanced_num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        else:
            return nargs == 1 and enhanced_has_keywords(func)

    def curry_namespace(ns):
        return dict(
            (name, curry(f) if should_curry(f) else f)
            for name, f in ns.items()
            if '__' not in name
        )

    all_auto_curried = curry_namespace(vars(eth_utils))
    inferred_namespace = valfilter(callable, all_auto_curried)
    curried_namespace = valfilter(callable, eth_utils.curried.__dict__)

    if inferred_namespace != curried_namespace:
        missing = set(inferred_namespace) - set(curried_namespace)
        if missing:
            to_insert = sorted("%s," % f for f in missing)
            raise AssertionError(
                'There are missing functions in eth_utils.curried:\n' +
                '\n'.join(to_insert))
        extra = set(curried_namespace) - set(inferred_namespace)
        if extra:
            raise AssertionError(
                'There are extra functions in eth_utils.curried:\n' +
                '\n'.join(sorted(extra)))
        unequal = merge_with(list, inferred_namespace, curried_namespace)
        unequal = valfilter(lambda x: x[0] != x[1], unequal)
        to_curry = keyfilter(lambda x: should_curry(getattr(eth_utils, x)),
                             unequal)
        if to_curry:
            to_curry_formatted = sorted('{0} = curry({0})'.format(f)
                                        for f in to_curry)
            raise AssertionError(
                'There are missing functions to curry in eth_utils.curried:\n' +
                '\n'.join(to_curry_formatted))
        elif unequal:
            not_to_curry_formatted = sorted(unequal)
            raise AssertionError(
                'Missing functions NOT to curry in eth_utils.curried:\n' +
                '\n'.join(not_to_curry_formatted))
        else:
            raise AssertionError("unexplained difference between %r and %r" % (
                inferred_namespace,
                curried_namespace,
            ))

def test_docstrings_uptodate():
    import toolz

    differ = difflib.Differ()

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test functions that have docstrings defined in `toolz`
    toolz_dict = valfilter(lambda x: getattr(x, '__doc__', ''), toolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        # only check if the new docstring *contains* the expected docstring
        toolz_doc = convertdoc(toolz_func)
        cytoolz_doc = cytoolz_func.__doc__
        if toolz_doc not in cytoolz_doc:
            diff = list(differ.compare(toolz_doc.splitlines(),
                                       cytoolz_doc.splitlines()))
            fulldiff = list(diff)

            # remove additional lines at the beginning
            while diff and diff[0].startswith('+'):
                diff.pop(0)

            # remove additional lines at the end
            while diff and diff[-1].startswith('+'):
                diff.pop()

            def checkbad(line):
                return (line.startswith('+') and
                        not ('# doctest: +SKIP' in line and
                             key in skipped_doctests))

            if any(map(checkbad, diff)):
                assert False, 'Error: cytoolz.%s has a bad docstring:\n%s\n' % (
                    key, '\n'.join(fulldiff))

def _get_reader(self, header=None, keep_default_na=False, na_values=na_values,
                chunksize=None, **kwargs):
    kwargs.setdefault('skiprows', int(bool(self.header)))

    dialect = merge(keyfilter(read_csv_kwargs.__contains__, self.dialect),
                    kwargs)

    # handle windows
    if dialect['lineterminator'] == '\r\n':
        dialect['lineterminator'] = None

    return partial(pd.read_csv, chunksize=chunksize, na_values=na_values,
                   keep_default_na=keep_default_na, encoding=self.encoding,
                   header=header, **dialect)

def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        if key in ['excepts', 'juxt']:
            continue
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message

def treemap(tree, mapfn, key=None, path=()):
    res = tree
    if isinstance(tree, dict):
        res = keyfilter(
            identity,
            itemmap(
                lambda item: treemap(item[1], mapfn, item[0], path + (item[0],)),
                tree),
        )
    elif isinstance(tree, list):
        res = list(map(lambda t: treemap(t, mapfn, None, path), tree))
    return xform(key, res, mapfn, path)

def intersection(self, other: Type['GeneSignature']) -> Type['GeneSignature']:
    """
    Creates a new :class:`GeneSignature` instance which is the intersection of this
    signature and the supplied other signature.

    The weight associated with the genes in the intersection is the maximum of the
    weights in the composing signatures.

    :param other: The other :class:`GeneSignature`.
    :return: the new :class:`GeneSignature` instance.
    """
    genes = set(self.gene2weight.keys()).intersection(set(other.gene2weight.keys()))
    return self.copy(
        name="({} & {})".format(self.name, other.name)
             if self.name != other.name else self.name,
        gene2weight=frozendict(
            keyfilter(lambda k: k in genes,
                      merge_with(max, self.gene2weight, other.gene2weight))))

def _extend(self, rows):
    mode = 'ab' if PY2 else 'a'
    newline = dict() if PY2 else dict(newline='')
    dialect = keyfilter(to_csv_kwargs.__contains__, self.dialect)
    should_write_newline = self.last_char() != os.linesep
    with csvopen(self, mode=mode, **newline) as f:
        # we have data in the file, append a newline
        if should_write_newline:
            f.write(os.linesep)

        for df in map(partial(bz.into, pd.DataFrame),
                      partition_all(self.chunksize, iter(rows))):
            df.to_csv(f, index=False, header=None, encoding=self.encoding,
                      **dialect)

def test_class_sigs():
    """ Test that all ``cdef class`` extension types in ``cytoolz`` have
    correctly embedded the function signature as done in ``toolz``.
    """
    import toolz

    # only consider items created in both `toolz` and `cytoolz`
    toolz_dict = valfilter(isfrommod('toolz'), toolz.__dict__)
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)

    # only test `cdef class` extensions from `cytoolz`
    cytoolz_dict = valfilter(lambda x: not isinstance(x, BuiltinFunctionType),
                             cytoolz_dict)

    # full API coverage should be tested elsewhere
    toolz_dict = keyfilter(lambda x: x in cytoolz_dict, toolz_dict)
    cytoolz_dict = keyfilter(lambda x: x in toolz_dict, cytoolz_dict)

    d = merge_with(identity, toolz_dict, cytoolz_dict)
    for key, (toolz_func, cytoolz_func) in d.items():
        try:
            # function
            toolz_spec = inspect.getargspec(toolz_func)
        except TypeError:
            try:
                # curried or partial object
                toolz_spec = inspect.getargspec(toolz_func.func)
            except (TypeError, AttributeError):
                # class
                toolz_spec = inspect.getargspec(toolz_func.__init__)

        toolz_sig = toolz_func.__name__ + inspect.formatargspec(*toolz_spec)
        if toolz_sig not in cytoolz_func.__doc__:
            message = ('cytoolz.%s does not have correct function signature.'
                       '\n\nExpected: %s'
                       '\n\nDocstring in cytoolz is:\n%s'
                       % (key, toolz_sig, cytoolz_func.__doc__))
            assert False, message

def run(cfg: HyperParamCfg):
    logger = hyperparam_logger
    train_cfg, params = make_dict(
        toolz.keyfilter(
            lambda x: x in TrainConfig.__annotations__.keys(), cfg._asdict()
        ),
        {},
    )
    solver = cfg.solver_class(params)
    while True:
        assignment = solver.new_assignment()
        concrete_params = assignment.params
        folder = os.path.normpath(
            os.path.abspath(os.path.join(cfg.work_dir, assignment.name))
        )
        logger.info(f"Starting {assignment.name}")
        cfg_path = os.path.join(folder, "config.yml")
        os.makedirs(folder, exist_ok=False)
        concrete_cfg = subs_dict(train_cfg, concrete_params)
        concrete_cfg['logdir'] = folder
        concrete_cfg = subs_dict(TrainConfig.schema(concrete_cfg), {})
        with open(cfg_path, "w") as f:
            yaml.safe_dump({
                'train': concrete_cfg,
                'version': "v0.1",
            }, stream=f, default_flow_style=False)
        result = train.run_args(
            argparse.Namespace(
                config=cfg_path,
                logdir=None,
                name=assignment.name,
            )
        )
        logger.info(f"Got results:\n{result.describe().to_string()}\n{result}")
        obs = Observation(
            metric=float(np.mean(result['identity'])),
            metric_std=float(np.std(result['identity'])),
            metadata={
                c: float(np.mean(series)) for c, series in result.iteritems()
            }
        )
        solver.report(assignment, obs)

def reader(self, header=None, keep_default_na=False, na_values=na_values,
           chunksize=None, **kwargs):
    kwargs.setdefault('skiprows', int(bool(self.header)))
    dialect = merge(keyfilter(read_csv_kwargs.__contains__, self.dialect),
                    kwargs)
    filename, ext = os.path.splitext(self.path)
    ext = ext.lstrip('.')

    # handle windows
    if dialect['lineterminator'] == '\r\n':
        dialect['lineterminator'] = None

    reader = pd.read_csv(self.path,
                         compression={'gz': 'gzip', 'bz2': 'bz2'}.get(ext),
                         chunksize=chunksize,
                         na_values=na_values,
                         keep_default_na=keep_default_na,
                         encoding=self.encoding,
                         header=header,
                         **dialect)
    return reader

def apply(
    self,
    batch: DataPanel,
    columns: List[str],
    skeleton_batches: List[DataPanel],
    slice_membership: np.ndarray,
    *args,
    **kwargs,
) -> Tuple[List[DataPanel], np.ndarray]:
    # Group the batch into inputs and output
    batch_inputs = tz.keyfilter(lambda k: k in columns[:-1], batch)
    batch_inputs = [
        OrderedDict(zip(batch_inputs, t)) for t in zip(*batch_inputs.values())
    ]

    batch_output = [int(e) for e in batch[columns[-1]]]

    # Create a fake dataset for textattack
    fake_dataset = list(zip(batch_inputs, batch_output))

    # Attack the dataset
    outputs = list(self.attack.attack_dataset(fake_dataset))

    for i, output in enumerate(outputs):
        # Check if the goal succeeded
        if output.perturbed_result.goal_status == 0:
            # If success, fill out the skeleton batch
            for (
                key,
                val,
            ) in output.perturbed_result.attacked_text._text_input.items():
                # TODO(karan): support num_attacked_texts > 1
                skeleton_batches[0][key][i] = val

            # # Fill the perturbed output: *this was incorrect, removing this
            # statement*
            # # TODO(karan): delete this snippet
            # skeleton_batches[0][columns[-1]][i] = output.perturbed_result.output
        else:
            # Unable to attack the example: set its slice membership to zero
            slice_membership[i, 0] = 0

    return skeleton_batches, slice_membership

def _extend(self, rows):
    mode = 'ab' if PY2 else 'a'
    newline = dict() if PY2 else dict(newline='')
    dialect = keyfilter(to_csv_kwargs.__contains__, self.dialect)
    should_write_newline = self.last_char() != os.linesep
    f = self.open(self.path, mode, **newline)

    try:
        # we have data in the file, append a newline
        if should_write_newline:
            f.write(os.linesep)

        for df in map(partial(bz.into, pd.DataFrame),
                      partition_all(self.chunksize, iter(rows))):
            df.to_csv(f, index=False, header=None, **dialect)
    finally:
        try:
            f.close()
        except AttributeError:
            pass

def test_sig_at_beginning():
    """ Test that the function signature is at the beginning of the docstring
    and is followed by exactly one blank line.
    """
    cytoolz_dict = valfilter(isfrommod('cytoolz'), cytoolz.__dict__)
    cytoolz_dict = keyfilter(lambda x: x not in skip_sigs, cytoolz_dict)

    for key, val in cytoolz_dict.items():
        doclines = val.__doc__.splitlines()
        assert len(doclines) > 2, (
            'cytoolz.%s docstring too short:\n\n%s' % (key, val.__doc__))

        sig = '%s(' % aliases.get(key, key)
        assert sig in doclines[0], (
            'cytoolz.%s docstring missing signature at beginning:\n\n%s'
            % (key, val.__doc__))

        assert not doclines[1], (
            'cytoolz.%s docstring missing blank line after signature:\n\n%s'
            % (key, val.__doc__))

        assert doclines[2], (
            'cytoolz.%s docstring too many blank lines after signature:\n\n%s'
            % (key, val.__doc__))

def select_columns(self, columns: List[str]) -> Batch:
    """Select a subset of columns."""
    for col in columns:
        assert col in self._data
    return tz.keyfilter(lambda k: k in columns, self._data)

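A sketch of the same column selection on a plain dict-of-lists batch (the column names are illustrative):

import toolz as tz

data = {"text": ["a", "b"], "label": [0, 1], "cache": [{}, {}]}
subset = tz.keyfilter(lambda k: k in ["text", "label"], data)
assert subset == {"text": ["a", "b"], "label": [0, 1]}
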
def __init__(self, path, mode='rt', schema=None, columns=None, types=None,
             typehints=None, dialect=None, header=None, open=open,
             nrows_discovery=50, chunksize=1024,
             encoding=sys.getdefaultencoding(), **kwargs):
    if 'r' in mode and not os.path.isfile(path):
        raise ValueError('CSV file "%s" does not exist' % path)

    if schema is None and 'w' in mode:
        raise ValueError('Please specify schema for writable CSV file')

    self.path = path
    self.mode = mode
    self.open = {'gz': gzip.open, 'bz2': bz2.BZ2File}.get(ext(path), open)
    self.header = header
    self._abspath = os.path.abspath(path)
    self.chunksize = chunksize
    self.encoding = encoding

    sample = get_sample(self)
    self.dialect = dialect = discover_dialect(sample, dialect, **kwargs)

    if header is None:
        header = has_header(sample, encoding=encoding)
    elif isinstance(header, int):
        dialect['header'] = header
        header = True

    reader_dialect = keyfilter(read_csv_kwargs.__contains__, dialect)
    if not schema and 'w' not in mode:
        if not types:
            data = self._reader(skiprows=1 if header else 0,
                                nrows=nrows_discovery, as_recarray=True,
                                index_col=False,
                                header=0 if header else None,
                                **reader_dialect).tolist()
            types = discover(data)
            rowtype = types.subshape[0]
            if isinstance(rowtype[0], Tuple):
                types = types.subshape[0][0].dshapes
                types = [unpack(t) for t in types]
                types = [string if t == null else t for t in types]
                types = [t if isinstance(t, Option) or t == string
                         else Option(t) for t in types]
            elif (isinstance(rowtype[0], Fixed) and
                  isinstance(rowtype[1], CType)):
                types = int(rowtype[0]) * [rowtype[1]]
            else:
                raise ValueError("Could not discover schema from data.\n"
                                 "Please specify schema.")
        if not columns:
            if header:
                columns = first(self._reader(skiprows=0, nrows=1, header=None,
                                             **reader_dialect
                                             ).itertuples(index=False))
            else:
                columns = ['_%d' % i for i in range(len(types))]
        if typehints:
            types = [typehints.get(c, t) for c, t in zip(columns, types)]

        schema = dshape(Record(list(zip(columns, types))))

    self._schema = schema
    self.header = header

def keyfilter(self, predicate):
    return fdict(cytoolz.keyfilter(predicate, self))

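The wrapper re-wraps the plain dict returned by cytoolz.keyfilter so the result keeps the fdict type. A minimal sketch, assuming fdict is simply a dict subclass (the real class presumably carries more methods):

import cytoolz

class fdict(dict):
    def keyfilter(self, predicate):
        # cytoolz.keyfilter returns a plain dict; re-wrap to preserve the type.
        return fdict(cytoolz.keyfilter(predicate, self))

d = fdict(a=1, b=2, _private=3)
public = d.keyfilter(lambda k: not k.startswith("_"))
assert isinstance(public, fdict) and public == {"a": 1, "b": 2}
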
def evaluate(
    self,
    dataset: DataPanel,
    input_columns: List[str],
    output_columns: List[str],
    batch_size: int = 32,
    metrics: List[str] = None,
    coerce_fn: Callable = None,
):
    # TODO(karan): generalize to TF2

    # Reset the dataset format
    dataset.reset_format()
    dataset.set_format(columns=input_columns + output_columns)

    # TODO(karan): check that the DataPanel conforms to the task definition
    # TODO(karan): figure out how the output_columns will be used by the metrics
    pass

    predictions = []
    targets = []

    # Loop and apply the prediction function
    # TODO(karan): not using .map() here in order to get more fine-grained
    #  control over devices
    for idx in range(0, len(dataset), batch_size):
        # Create the batch
        batch = dataset[idx:idx + batch_size]

        # Predict on the batch
        prediction_dict = self.predict_batch(batch=batch,
                                             input_columns=input_columns)

        # Coerce the predictions
        if coerce_fn:
            prediction_dict = coerce_fn(prediction_dict)

        # Grab the raw target key/values
        target_dict = tz.keyfilter(lambda k: k in output_columns, batch)

        # TODO(karan): general version for non-classification problems
        # TODO(karan): move this to the right device
        if self.is_classifier:
            target_dict = tz.valmap(lambda v: torch.tensor(v), target_dict)

        # TODO(karan): incremental metric computation here
        # Append the predictions and targets
        predictions.append(prediction_dict)
        targets.append(target_dict)

    # Consolidate the predictions and targets
    if self.is_classifier:
        # TODO(karan): Need to store predictions and outputs from the model
        predictions = tz.merge_with(lambda v: torch.cat(v).to("cpu"),
                                    *predictions)
        targets = tz.merge_with(lambda v: torch.cat(v).to("cpu"), *targets)
    else:
        predictions = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *predictions)
        targets = tz.merge_with(
            lambda x: list(itertools.chain.from_iterable(x)), *targets)

    # Compute the metrics
    # TODO(karan): generalize this code to support metric computation for any task
    # Assumes classification, so the output_columns contains a single key for the
    # label
    if self.is_classifier:
        assert len(output_columns) == 1  # , "Only supports classification."
        num_classes = self.task.output_schema.features[
            list(self.task.output_schema.columns)[0]].num_classes

    labels = targets[list(targets.keys())[0]]

    if metrics is None:
        if self.task is None:
            raise ValueError(
                "Must specify metrics if model not associated with task")
        metrics = self.task.metrics

    pred = predictions["pred"].to(self.device)
    target = labels.to(self.device)

    evaluation_dict = {
        metric: compute_metric(metric, pred, target, num_classes)
        for metric in metrics
    }

    # Reset the data format
    dataset.reset_format()

    return evaluation_dict

def re_select(self, pattern):
    pattern = re.compile(pattern)
    return list(keyfilter(pattern.fullmatch, self._modules).values())

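re.fullmatch returns a Match object (truthy) or None, so a compiled pattern's fullmatch method works directly as a keyfilter predicate. A sketch with made-up module names:

import re
from toolz import keyfilter

modules = {"conv1": "m1", "conv2": "m2", "fc": "m3"}
pattern = re.compile(r"conv\d")
selected = list(keyfilter(pattern.fullmatch, modules).values())
assert selected == ["m1", "m2"]
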
def parse_gs_line(file_line):
    data = json.loads(file_line)
    return keyfilter(lambda k: k in ["text", "docId"], data)

def omit(blacklist, d):
    return keyfilter(lambda k: k not in blacklist, d)

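A quick usage check for omit as defined directly above (the data is illustrative); this is the complement of a whitelist filter:

# Keep everything *except* the blacklisted keys.
assert omit(["password"], {"user": "ada", "password": "x"}) == {"user": "ada"}
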
def uncached_example(cls, example: Dict, copy=True) -> Dict:
    """Return example with the "cache" and "slices" columns removed."""
    return tz.keyfilter(
        lambda k: k not in ["cache", "slices"],
        deepcopy(example) if copy else example,
    )

def uncached_batch(cls, batch: Batch, copy=True) -> Batch:
    """Return batch with the "cache" and "slices" columns removed."""
    return tz.keyfilter(lambda k: k not in ["cache", "slices"],
                        deepcopy(batch) if copy else batch)

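Both helpers above reduce to the same pattern: drop bookkeeping keys, optionally deep-copying first so the caller's data is untouched. A runnable sketch with an illustrative batch:

from copy import deepcopy
import toolz as tz

batch = {"text": ["a"], "cache": [{}], "slices": [[]]}
clean = tz.keyfilter(lambda k: k not in ["cache", "slices"], deepcopy(batch))
assert clean == {"text": ["a"]}
assert "cache" in batch  # the original batch is unchanged
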
def _intersection_impl(self, other):
    genes = set(self.gene2weight.keys()).intersection(
        set(other.gene2weight.keys()))
    return frozendict(
        keyfilter(lambda k: k in genes,
                  merge_with(max, self.gene2weight, other.gene2weight)))

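The intersection-with-max-weight pattern above, as a standalone sketch with hypothetical weights: merge_with(max) resolves each shared key to its larger weight, then keyfilter restricts the result to the common genes:

from toolz import keyfilter, merge_with

a = {"TP53": 0.9, "MYC": 0.2}
b = {"MYC": 0.7, "EGFR": 0.4}
genes = set(a) & set(b)
merged = keyfilter(lambda k: k in genes, merge_with(max, a, b))
assert merged == {"MYC": 0.7}
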
def get_retriable_missing(self) -> List[Hash32]:
    retriable = cytoolz.keyfilter(
        lambda k: time.time() - k > self.reply_timeout, self.missing)
    self.missing = cytoolz.dissoc(self.missing, *retriable.keys())
    return list(cytoolz.concat(retriable.values()))

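A sketch of the retry bookkeeping above, assuming (as the predicate suggests) that self.missing maps request timestamps to lists of hashes; toolz is used here as a drop-in for cytoolz:

import time
from toolz import keyfilter, dissoc, concat

reply_timeout = 5.0
now = time.time()
missing = {now - 10: ["h1", "h2"], now - 1: ["h3"]}  # request time -> hashes
# Entries older than the timeout are eligible for a retry.
retriable = keyfilter(lambda t: now - t > reply_timeout, missing)
missing = dissoc(missing, *retriable.keys())
assert list(concat(retriable.values())) == ["h1", "h2"]
assert list(missing.values()) == [["h3"]]
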