def update_adjustments(self, adjustments, method):
    """
    Merge ``adjustments`` with existing adjustments, handling index
    collisions according to ``method``.

    Parameters
    ----------
    adjustments : dict[int -> list[Adjustment]]
        The mapping of row indices to lists of adjustments that should be
        merged into the existing adjustments.
    method : {'append', 'prepend'}
        How to handle index collisions. If 'append', new adjustments will be
        applied after previously-existing adjustments. If 'prepend', new
        adjustments will be applied before previously-existing adjustments.

    Raises
    ------
    ValueError
        If ``method`` is not a key of ``_merge_methods``.
    """
    try:
        merge_func = _merge_methods[method]
    except KeyError:
        # The ValueError already names the bad method and the valid ones,
        # so the internal KeyError context is suppressed.
        raise ValueError(
            "Invalid merge method %s\n"
            "Valid methods are: %s" % (method, ', '.join(_merge_methods))
        ) from None
    # merge_with calls merge_func with the list of values collected for each
    # row index, existing adjustments passed before the new ones.
    self.adjustments = merge_with(
        merge_func,
        self.adjustments,
        adjustments,
    )
def combine(tokens):
    """Merge the token dicts, refusing any key that occurs more than once.

    Raises ``ValueError("Repeated chemical")`` when a key is found in more
    than one of the dicts in ``tokens``.
    """
    def single_value(collected):
        # Looking at two items is enough to detect a repeat.
        first_two = list(toolz.take(2, collected))
        if len(first_two) == 1:
            return first_two[0]
        raise ValueError("Repeated chemical")

    return toolz.merge_with(single_value, *tokens)
def zhongji(ip='', username='', password=''):
    """Collect link-aggregation membership from a device over telnet.

    Runs ``display cu section bbs-config | in link-aggregation``, pages
    through the output (answering the pager with a space), then quits.

    Returns
    -------
    list
        ``['success', groups, ip]``. ``groups`` maps the first field of each
        ``add-member`` line to a de-duplicated list: that field followed by
        ``'<second field>/<port>'`` for each comma-separated port in the
        third field.
        ``['fail', None, ip]`` if the session hits EOF or times out.
    """
    result = []
    child = None
    try:
        child = telnet(ip, username, password)
        child.sendline("display cu section bbs-config | in link-aggregation")
        while True:
            index = child.expect([hw_prompt, hw_pager], timeout=120)
            result.append(child.before)
            if index == 0:
                # Back at the prompt: leave the session and confirm.
                child.sendline('quit')
                child.expect(':')
                child.sendline('y')
                break
            # Pager prompt: request the next page.
            child.send(" ")
    except (pexpect.EOF, pexpect.TIMEOUT):
        return ['fail', None, ip]
    finally:
        # Close the session on success AND on failure; it used to leak
        # whenever expect() raised.
        if child is not None:
            child.close()

    # Drop the first and last lines (presumably the echoed command and the
    # trailing prompt) and strip the '\x1b[37D' escape the output contains.
    lines = ''.join(result).split('\r\n')[1:-1]
    rec = [
        line.replace('\x1b[37D', '').strip().split()[2:]
        for line in lines
        if 'add-member' in line
    ]

    def port(fields):
        ports = ['/'.join((fields[1], p)) for p in fields[2].split(',')]
        return [fields[0], *ports]

    # list() materialises the values: unique() is lazy, so without it every
    # value in the returned dict was a one-shot generator.
    merge_entries = compose(list, unique, concat)
    groups = {}
    for entry in (port(fields) for fields in rec):
        groups = merge_with(merge_entries, groups, {entry[0]: entry})
    return ['success', groups, ip]
def __add__(self, other):
    """
    Return a new UpdateDeltas holding ``self + other``.

    Another UpdateDeltas is combined key-wise with ``utils.smart_sum``;
    any other value is passed through ``self.apply`` with
    ``utils.smart_add``.
    """
    if not isinstance(other, UpdateDeltas):
        return self.apply(lambda delta: utils.smart_add(delta, other))
    summed = toolz.merge_with(utils.smart_sum, self.deltas, other.deltas)
    return UpdateDeltas(summed)
def collect(self):
    """Yield Prometheus metric families describing the scheduler.

    Yields, in order: connected clients, desired workers, workers by state,
    suspicious-task counters per task prefix, forgotten tasks, and tasks by
    state.
    """
    from prometheus_client.core import GaugeMetricFamily, CounterMetricFamily

    scheduler = self.server

    # The "fire-and-forget" pseudo-client is not counted.
    real_clients = sum(1 for name in scheduler.clients if name != "fire-and-forget")
    yield GaugeMetricFamily(
        "dask_scheduler_clients",
        "Number of clients connected.",
        value=real_clients,
    )
    yield GaugeMetricFamily(
        "dask_scheduler_desired_workers",
        "Number of workers scheduler needs for task graph.",
        value=scheduler.adaptive_target(),
    )

    workers_by_state = GaugeMetricFamily(
        "dask_scheduler_workers",
        "Number of workers known by scheduler.",
        labels=["state"],
    )
    for label, members in (
        ("connected", scheduler.workers),
        ("saturated", scheduler.saturated),
        ("idle", scheduler.idle),
    ):
        workers_by_state.add_metric([label], len(members))
    yield workers_by_state

    prefixes = scheduler.task_prefixes.values()
    # Per-state task counts summed over every task prefix.
    state_totals = toolz.merge_with(sum, (prefix.states for prefix in prefixes))

    suspicious = CounterMetricFamily(
        "dask_scheduler_tasks_suspicious",
        "Total number of times a task has been marked suspicious",
        labels=["task_prefix_name"],
    )
    for prefix in prefixes:
        suspicious.add_metric([prefix.name], prefix.suspicious)
    yield suspicious

    yield CounterMetricFamily(
        "dask_scheduler_tasks_forgotten",
        (
            "Total number of processed tasks no longer in memory and already "
            "removed from the scheduler job queue. Note task groups on the "
            "scheduler which have all tasks in the forgotten state are not included."
        ),
        value=state_totals.get("forgotten", 0.0),
    )

    tasks_by_state = GaugeMetricFamily(
        "dask_scheduler_tasks",
        "Number of tasks known by scheduler.",
        labels=["state"],
    )
    for state in ALL_TASK_STATES:
        tasks_by_state.add_metric([state], state_totals.get(state, 0.0))
    yield tasks_by_state
def __iadd__(self, other):
    """
    In-place ``+=``: fold ``other`` into this UpdateDeltas and return it.

    With another UpdateDeltas, ``self.deltas`` is rebound to the key-wise
    ``utils.smart_sum`` of both delta mappings. With any other value,
    ``self.iapply`` is called with ``utils.smart_add``.
    """
    if not isinstance(other, UpdateDeltas):
        self.iapply(lambda delta: utils.smart_add(delta, other))
        return self
    self.deltas = toolz.merge_with(utils.smart_sum, self.deltas, other.deltas)
    return self
def __add__(self, other):
    """
    Return a new UpdateDeltas holding ``self + other``.

    Another UpdateDeltas is combined key-wise with ``utils.smart_sum``;
    any other value is passed through ``self.apply`` with
    ``utils.smart_add``.
    """
    if isinstance(other, UpdateDeltas):
        merged = toolz.merge_with(utils.smart_sum, self.deltas, other.deltas)
        return UpdateDeltas(merged)

    def add_other(value):
        return utils.smart_add(value, other)

    return self.apply(add_other)
def merge(cls, args):
    """Recursively merge a sequence of values; later values take priority.

    - If any value is a dict, merge them key-wise with ``toolz.merge_with``,
      recursing into each key's values and building results with ``cls``.
    - Otherwise, if the last value is a list, it replaces the earlier
      values; each element goes through ``cls.merge`` so nested dicts are
      rebuilt with ``cls``.
    - Otherwise the last value wins.
    """
    if any(isinstance(a, dict) for a in args):
        return toolz.merge_with(cls.merge, *args, factory=cls)
    last = toolz.last(args)
    if isinstance(last, list):
        # TODO(kszucs): introduce a strategy argument to concatenate lists
        # instead of replacing
        # don't merge lists but needs to propagate factory
        return [cls.merge([a]) for a in last]
    # A non-list overriding an earlier list used to be iterated (strings
    # split into characters, scalars raised TypeError); it now simply wins.
    return last
def __imul__(self, other):
    """
    In-place ``*=``: multiply ``other`` into this UpdateDeltas and return it.

    With another UpdateDeltas, ``self.deltas`` is rebound to the key-wise
    ``utils.smart_product`` of both delta mappings. Keys present in only one
    operand reach ``smart_product`` as a one-element list; presumably that
    leaves the value unchanged (TODO confirm). With any other value,
    ``self.iapply`` is called with ``utils.smart_mul``.
    """
    if isinstance(other, UpdateDeltas):
        self.deltas = toolz.merge_with(utils.smart_product, self.deltas, other.deltas)
    else:
        self.iapply(lambda x: utils.smart_mul(x, other))
    return self
def test_curried_namespace():
    """Check that ``eth_utils.curried`` mirrors the public ``eth_utils`` API.

    Every callable public name of ``eth_utils`` must appear in
    ``eth_utils.curried``, wrapped in ``curry`` exactly when
    ``should_curry`` says so.
    """
    def should_curry(func):
        # Curry callables needing more than one positional argument (or of
        # unknown arity), or exactly one when has_keywords(func) is true.
        if not callable(func) or isinstance(func, curry):
            return False
        nargs = num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and has_keywords(func)

    def curry_namespace(ns):
        return {
            name: curry(f) if should_curry(f) else f
            for name, f in ns.items()
            if '__' not in name
        }

    all_auto_curried = curry_namespace(vars(eth_utils))
    inferred_namespace = valfilter(callable, all_auto_curried)
    curried_namespace = valfilter(callable, eth_utils.curried.__dict__)

    if inferred_namespace != curried_namespace:
        missing = set(inferred_namespace) - set(curried_namespace)
        if missing:
            to_insert = sorted("%s," % f for f in missing)
            raise AssertionError(
                'There are missing functions in eth_utils.curried:\n'
                + '\n'.join(to_insert))
        extra = set(curried_namespace) - set(inferred_namespace)
        if extra:
            raise AssertionError(
                'There are extra functions in eth_utils.curried:\n'
                + '\n'.join(sorted(extra)))
        # Key sets now match: pair each name's two values, keep mismatches.
        unequal = merge_with(list, inferred_namespace, curried_namespace)
        unequal = valfilter(lambda x: x[0] != x[1], unequal)
        to_curry = keyfilter(lambda x: should_curry(getattr(eth_utils, x)), unequal)
        if to_curry:
            to_curry_formatted = sorted('{0} = curry({0})'.format(f) for f in to_curry)
            raise AssertionError(
                'There are missing functions to curry in eth_utils.curried:\n'
                + '\n'.join(to_curry_formatted))
        elif unequal:
            not_to_curry_formatted = sorted(unequal)
            raise AssertionError(
                'Missing functions NOT to curry in eth_utils.curried:\n'
                + '\n'.join(not_to_curry_formatted))
        else:
            raise AssertionError("unexplained difference between %r and %r" % (
                inferred_namespace,
                curried_namespace,
            ))
def test_word_counts_attributes(text_corpora: TextCorpora, corpus_collection: Dict[str, Corpus]) -> None:
    """Per-category and combined word counts of TextCorpora agree with the
    counts of the individual corpora (combined = key-wise sum)."""
    per_category = {
        name: corpus_collection[name].word_counts(as_strings=True)
        for name in ('cat1', 'cat2')
    }
    combined = tlz.merge_with(sum, per_category['cat1'], per_category['cat2'])

    for name, expected in per_category.items():
        assert text_corpora.word_counts(corpora=name, as_strings=True) == expected
    assert text_corpora.word_counts(as_strings=True) == combined
    assert text_corpora.word_counts(corpora=['cat2', 'cat1'], as_strings=True) == combined
def main():
    """Compute a PHYLIP distance matrix from pairwise values files.

    Reads orthologs from ``args.orthos`` and every
    ``<spA>_<spB>_values.tsv.gz`` file in ``args.values``; other names are
    reported and skipped. It merges the values, filters them with
    ``args.mode``, computes distances with ``scaled_L2norm`` and writes the
    ``phylip`` matrix to ``args.outfile``.
    """
    parser = cli_parser()
    args = parser.parse_args()
    if args.progress:
        try:
            from tqdm import tqdm
            args.verbose = False
        except ImportError:
            print('error: -p/--progress needs tqdm to be installed.')
            sys.exit(1)

    orthos, groups = read_orthos(args.orthos)
    if args.verbose:
        print('Read orthologs.')

    species = set()
    values = dict()
    # Dots are escaped so only a literal '.tsv.gz' extension matches.
    name_pattern = re.compile(r'(?:\S+/)*(\w+)_(\w+)_values\.tsv\.gz', re.IGNORECASE)
    if args.progress:
        print('Reading values files...')
        pbar = tqdm(total=len(args.values))
    for filename in args.values:
        m = name_pattern.match(filename)
        if m is None:
            print('Ignored {}'.format(filename))
            if args.progress:
                # Count skipped files too so the bar reaches its total.
                pbar.update(1)
            continue
        sp_left, sp_right = m.group(1), m.group(2)
        species.add(sp_left)
        species.add(sp_right)
        this_values = read_values(orthos, groups, sp_left, sp_right, filename)
        values = merge_with(merge, values, this_values)
        if args.verbose:
            print(f'Read {filename}')
        elif args.progress:
            pbar.update(1)

    # At this point, we don't need the group number anymore
    values = filter_values(args.mode, species, values.values())
    if args.progress:
        pbar.close()
        print('Computing the distances...')
    elif args.verbose:
        print('Computing the distances...')

    distances = scaled_L2norm(species, values)
    matrix = phylip(species, distances)
    with open(args.outfile, 'w') as f:
        f.write(matrix)
        f.write('\n')
def compute_source_r2(self, batch):
    """Print skill scores of the model's apparent source for ``batch``.

    Each field is paired key-by-key ("true", predicted) with ``merge_with``.
    Three dicts are printed:

    1. mass-weighted R² via ``weighted_r2_score(..., dim=-3)``;
    2. R² of the mass-weighted sums over dim -3;
    3. bias of those sums, ``(pred.mean() - true.mean()) / 1000``.

    The "true" source is the first-axis finite difference of the prognostics
    (divided by ``self.time_step``) minus the known forcings times 86400.
    The forcings and the model source are averaged onto midpoints of that
    axis; presumably the axis is time (TODO confirm).
    """
    from .timestepper import Batch
    from toolz import merge_with
    from .loss import weighted_r2_score, r2_score

    def midpoint_average(arr):
        return (arr[1:] + arr[:-1]) / 2

    def mass_weighted_sum(arr):
        return (arr * self.mass).sum(-3)

    model_source = self.model(batch)

    # Apparent source: tendency of the prognostics minus the known forcing.
    batch = Batch(batch, self.prognostics)
    known_forcing = batch.get_known_forcings()
    prognostic_fields = batch.data[self.prognostics]
    tendency = prognostic_fields.apply(lambda arr: (arr[1:] - arr[:-1]) / self.time_step)
    forcing_mid = known_forcing.apply(midpoint_average)
    model_source = model_source.apply(midpoint_average)
    apparent_source = tendency - forcing_mid * 86400

    # compute the metrics
    def weighted_score(pair):
        truth, pred = pair
        return weighted_r2_score(truth, pred, self.mass, dim=-3).item()

    print(merge_with(weighted_score, apparent_source, model_source))

    # compute the r2 of the integral
    predicted_total = model_source.apply(mass_weighted_sum)
    true_total = apparent_source.apply(mass_weighted_sum)

    def scalar_score(pair):
        return r2_score(*pair).item()

    def mean_bias(pair):
        truth, pred = pair
        return (pred.mean() - truth.mean()).item() / 1000

    print(merge_with(scalar_score, true_total, predicted_total))
    print(merge_with(mean_bias, true_total, predicted_total))
def __mul__(self, other):
    """
    Return a new UpdateDeltas holding ``self * other``.

    Another UpdateDeltas is combined key-wise with ``utils.smart_product``;
    any other value is passed through ``self.apply`` with
    ``utils.smart_mul``.
    """
    if not isinstance(other, UpdateDeltas):
        return self.apply(lambda delta: utils.smart_mul(delta, other))
    # TODO this will currently make it such that if one instance
    # has updates and another doesn't, it will return the same value
    # (another approach would be returning 0 if the value isn't in
    # both)
    # TODO is multiply by another set of deltas ever desired?
    product = toolz.merge_with(utils.smart_product, self.deltas, other.deltas)
    return UpdateDeltas(product)
def test_curried_namespace():
    """Check that ``toolz.curried`` exposes exactly the expected callables.

    The expected namespace is built from ``toolz`` and
    ``toolz.curried.exceptions``. Functions are wrapped in ``toolz.curry``
    exactly when ``should_curry`` says so.
    """
    exceptions = import_module('toolz.curried.exceptions')

    def should_curry(func):
        if not callable(func) or isinstance(func, toolz.curry):
            return False
        nargs = toolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and toolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        curried = {}
        for name, f in ns.items():
            if '__' in name:
                continue
            curried[name] = toolz.curry(f) if should_curry(f) else f
        return curried

    from_toolz = curry_namespace(vars(toolz))
    from_exceptions = curry_namespace(vars(exceptions))
    namespace = toolz.valfilter(callable, toolz.merge(from_toolz, from_exceptions))
    curried_namespace = toolz.valfilter(callable, toolz.curried.__dict__)

    if namespace == curried_namespace:
        return

    missing = set(namespace) - set(curried_namespace)
    if missing:
        raise AssertionError(
            'There are missing functions in toolz.curried:\n %s'
            % ' \n'.join(sorted(missing)))
    extra = set(curried_namespace) - set(namespace)
    if extra:
        raise AssertionError(
            'There are extra functions in toolz.curried:\n %s'
            % ' \n'.join(sorted(extra)))

    # Key sets match here, so every merged value is an [expected, actual] pair.
    paired = toolz.merge_with(list, namespace, curried_namespace)
    unequal = toolz.valfilter(lambda pair: pair[0] != pair[1], paired)
    messages = []
    for name in sorted(unequal):
        if name in from_exceptions:
            messages.append('%s should come from toolz.curried.exceptions' % name)
        elif should_curry(getattr(toolz, name)):
            messages.append('%s should be curried from toolz' % name)
        else:
            messages.append('%s should come from toolz and NOT be curried' % name)
    raise AssertionError('\n'.join(messages))
def __mul__(self, other):
    """
    Return a new UpdateDeltas holding ``self * other``.

    Another UpdateDeltas is combined key-wise with ``utils.smart_product``;
    any other value is passed through ``self.apply`` with
    ``utils.smart_mul``.
    """
    if isinstance(other, UpdateDeltas):
        # TODO this will currently make it such that if one instance
        # has updates and another doesn't, it will return the same value
        # (another approach would be returning 0 if the value isn't in
        # both)
        # TODO is multiply by another set of deltas ever desired?
        product = toolz.merge_with(utils.smart_product, self.deltas, other.deltas)
        return UpdateDeltas(product)

    def multiply_by_other(value):
        return utils.smart_mul(value, other)

    return self.apply(multiply_by_other)
def data_loader(dataset, batch_size, is_shuffle=True, drop_last=False, epoch_ratio=1.0, multi_process=False):
    """Yield batches of ``dataset`` items, merged key-wise into lists.

    Index order is optionally shuffled and truncated to ``epoch_ratio`` of
    the dataset. With ``drop_last``, iteration stops at the first batch
    whose size differs from ``batch_size``. ``multi_process`` is accepted
    but unused.

    NOTE(review): assumes ``shuffle`` returns the shuffled sequence
    (``random.shuffle`` returns None) and that dataset items are dicts;
    confirm.
    """
    order = list(range(len(dataset)))
    if is_shuffle:
        order = shuffle(order)
    order = order[:int(len(order) * epoch_ratio)]

    for chunk in batch_holder(order, batch_size=batch_size)():
        indices = chunk[0]
        if drop_last and len(indices) != batch_size:
            break
        samples = [dataset[i] for i in indices]
        yield toolz.merge_with(lambda column: column, samples)
def test_curried_namespace():
    """Verify ``toolz.curried`` against a namespace rebuilt from ``toolz``
    and ``toolz.curried.exceptions``, curried where ``should_curry`` says so.
    """
    exceptions = import_module('toolz.curried.exceptions')

    def should_curry(func):
        if not callable(func) or isinstance(func, toolz.curry):
            return False
        nargs = toolz.functoolz.num_required_args(func)
        if nargs is None or nargs > 1:
            return True
        return nargs == 1 and toolz.functoolz.has_keywords(func)

    def curry_namespace(ns):
        return {
            name: (toolz.curry(f) if should_curry(f) else f)
            for name, f in ns.items()
            if '__' not in name
        }

    from_toolz = curry_namespace(vars(toolz))
    from_exceptions = curry_namespace(vars(exceptions))
    expected = toolz.valfilter(callable, toolz.merge(from_toolz, from_exceptions))
    actual = toolz.valfilter(callable, toolz.curried.__dict__)

    if expected == actual:
        return

    missing = set(expected) - set(actual)
    if missing:
        raise AssertionError('There are missing functions in toolz.curried:\n %s'
                             % ' \n'.join(sorted(missing)))
    extra = set(actual) - set(expected)
    if extra:
        raise AssertionError('There are extra functions in toolz.curried:\n %s'
                             % ' \n'.join(sorted(extra)))

    mismatched = toolz.valfilter(
        lambda pair: pair[0] != pair[1],
        toolz.merge_with(list, expected, actual),
    )

    def explain(name):
        if name in from_exceptions:
            return '%s should come from toolz.curried.exceptions' % name
        if should_curry(getattr(toolz, name)):
            return '%s should be curried from toolz' % name
        return '%s should come from toolz and NOT be curried' % name

    raise AssertionError('\n'.join(explain(name) for name in sorted(mismatched)))
def daily_position_pnl(self, dts):
    """
    Return the PnL for one day, keyed by asset sid.

    Covers both the currently held positions and the positions closed on
    ``dts``. Each position contributes
    ``amount * (last_sync_price - cost_basis)``. Contributions sharing a sid
    are summed; they previously overwrote each other, and the final
    ``merge_with(sum, ...)`` over a single dict could not merge anything.

    param dts: the day, as a ``%Y-%m-%d`` string or an object with
        ``strftime`` (e.g. a date or Timestamp)
    """
    assert not self._dirty_portfolio, 'stats is accurate after end_session'

    symbol_pnl = {}

    def accumulate(sid, position):
        pnl = position.amount * (position.last_sync_price - position.cost_basis)
        symbol_pnl[sid] = symbol_pnl.get(sid, 0) + pnl

    # open positions
    for asset, p in self.positions.items():
        accumulate(asset.sid, p)

    # positions closed on dts
    dts = dts if isinstance(dts, str) else dts.strftime('%Y-%m-%d')
    closed_position = self.position_tracker.record_closed_position[dts]
    for p in closed_position:
        accumulate(p.asset.sid, p)
    return symbol_pnl
def current_portfolio_weights(self):
    """
    Compute each sid's weight: its held value divided by
    ``self.portfolio_value``.

    A position's held value is ``last_sync_price * amount``. Positions whose
    assets share a sid (e.g. the same instrument under different pipeline
    tags) are summed into one weight. Previously the last such position
    silently replaced the others.

    NOTE(review): the earlier docstring mentioned a futures contract
    multiplier, but no multiplier is applied here. Confirm whether
    ``last_sync_price`` already accounts for it.

    Returns
    -------
    dict
        sid -> weight; empty when there are no positions.
    """
    if not self.positions:
        return pd.Series(dtype='float').to_dict()

    sid_values = {}
    for asset, position in self.positions.items():
        held = position.last_sync_price * position.amount
        sid_values[asset.sid] = sid_values.get(asset.sid, 0) + held
    weights = pd.Series(sid_values) / self.portfolio_value
    return weights.to_dict()
def zhongji(ip='', username='', password=''):
    """Collect link-aggregation membership from a device over telnet.

    Runs ``display cu section bbs-config | in link-aggregation``, pages
    through the output (answering the pager with a space), then quits.

    Returns
    -------
    list
        ``['success', groups, ip]``. ``groups`` maps the first field of each
        ``add-member`` line to a de-duplicated list: that field followed by
        ``'<second field>/<port>'`` for each comma-separated port in the
        third field.
        ``['fail', None, ip]`` if the session hits EOF or times out.
    """
    result = []
    child = None
    try:
        child = telnet(ip, username, password)
        child.sendline("display cu section bbs-config | in link-aggregation")
        while True:
            index = child.expect([hw_prompt, hw_pager], timeout=120)
            result.append(child.before)
            if index == 0:
                # Back at the prompt: leave the session and confirm.
                child.sendline('quit')
                child.expect(':')
                child.sendline('y')
                break
            # Pager prompt: request the next page.
            child.send(" ")
    except (pexpect.EOF, pexpect.TIMEOUT):
        return ['fail', None, ip]
    finally:
        # Close the session on success AND on failure; it used to leak
        # whenever expect() raised.
        if child is not None:
            child.close()

    # Drop the first and last lines (presumably the echoed command and the
    # trailing prompt) and strip the '\x1b[37D' escape the output contains.
    lines = ''.join(result).split('\r\n')[1:-1]
    rec = [
        line.replace('\x1b[37D', '').strip().split()[2:]
        for line in lines
        if 'add-member' in line
    ]

    def port(fields):
        ports = ['/'.join((fields[1], p)) for p in fields[2].split(',')]
        return [fields[0], *ports]

    # list() materialises the values: unique() is lazy, so without it every
    # value in the returned dict was a one-shot generator.
    merge_entries = compose(list, unique, concat)
    groups = {}
    for entry in (port(fields) for fields in rec):
        groups = merge_with(merge_entries, groups, {entry[0]: entry})
    return ['success', groups, ip]
def concat_dicts(seq, dim=1):
    """Join same-keyed tensors across the dicts in ``seq``.

    For each key, the tensors collected from the dicts are concatenated
    with ``torch.cat`` along ``dim``.
    """
    def cat_along_dim(tensors):
        return torch.cat(tensors, dim=dim)

    return merge_with(cat_along_dim, *seq)
def merge_frequencies(seqs):
    """Sum several frequency tables and return ``(key, total)`` pairs.

    Each element of ``seqs`` is anything ``dict()`` accepts (a mapping or
    an iterable of key/count pairs).
    """
    tables = (dict(seq) for seq in seqs)
    totals = merge_with(sum, tables)
    return [(key, total) for key, total in totals.items()]
def stack_dicts(delayeds):
    """Stack same-keyed tensors from ``delayeds`` along a new dim 1.

    ``delayeds`` is passed to ``merge_with`` as one argument, so it is
    presumably an iterable of dicts (TODO confirm).
    """
    def stack_on_dim1(tensors):
        return torch.stack(tensors, dim=1)

    return merge_with(stack_on_dim1, delayeds)
def combine(name, start, tokens):
    """Parse action: map the matched text to the key-wise sum of ``tokens``.

    Presumably ``name`` is the full input string and ``start`` the match
    location (pyparsing's ``(s, loc, toks)`` signature). The match end comes
    from ``pp.getTokensEndLoc()``, which must be called inside a parse
    action.
    """
    end = pp.getTokensEndLoc()
    matched_text = name[start:end]
    totals = toolz.merge_with(sum, *tokens)
    return {matched_text: totals}
def smart_dict_merge(*args):
    """Merge dicts key-wise, resolving each key with ``merge_or_last``.

    Intended behaviour: list values are merged, otherwise the last value
    wins. That policy lives in ``merge_or_last``.
    """
    merged = toolz.merge_with(merge_or_last, *args)
    return merged
def main():
    """Report how informative the merged per-group values are.

    Reads orthologs and the ``<spA>_<spB>_values.tsv.gz`` files in
    ``args.values``, then merges and filters the values with ``args.mode``.
    It reports ``TotalSize``, ``InformativesSize``,
    ``PercentageInformative`` and ``NonInformativeValue``. A group is
    informative when its values are not all identical. The report is
    printed, or written to ``args.output`` as ``.json`` or ``.tsv``.
    """
    parser = cli_parser()
    args = parser.parse_args()
    if args.progress:
        try:
            from tqdm import tqdm
            args.verbose = False
        except ImportError:
            print('error: -p/--progress needs tqdm to be installed.')
            sys.exit(1)

    orthos, groups = read_orthos(args.orthos)
    if args.verbose:
        print('Read orthologs.')

    species = set()
    values = dict()
    # Raw string: '\S' and '\w' are regex escapes (the non-raw literal
    # triggers an invalid-escape warning); dots escaped to match literally.
    name_pattern = re.compile(r'(?:\S+/)*(\w+)_(\w+)_values\.tsv\.gz', re.IGNORECASE)
    if args.progress:
        print('Reading values files...')
        pbar = tqdm(total=len(args.values))
    for filename in args.values:
        m = name_pattern.match(filename)
        if m is None:
            print('Ignored {}'.format(filename))
            if args.progress:
                # Count skipped files too so the bar reaches its total.
                pbar.update(1)
            continue
        sp_left, sp_right = m.group(1), m.group(2)
        species.add(sp_left)
        species.add(sp_right)
        this_values = read_values(orthos, groups, sp_left, sp_right, filename)
        values = merge_with(merge, values, this_values)
        if args.verbose:
            print(f'Read {filename}')
        elif args.progress:
            pbar.update(1)

    if args.progress:
        pbar.close()

    res = {}
    # At this point, we don't need the group number anymore
    values = filter_values(args.mode, species, values.values())
    total = len(values)
    res['TotalSize'] = total
    informatives = [v for v in values if len(set(v.values())) != 1]
    res['InformativesSize'] = len(informatives)
    # An empty selection used to raise ZeroDivisionError.
    if total:
        res['PercentageInformative'] = float(len(informatives)) / float(total) * 100.0
    else:
        res['PercentageInformative'] = float('nan')
    # The shared value of the first non-informative group, NaN if none.
    res['NonInformativeValue'] = next(
        (next(iter(v.values())) for v in values if len(set(v.values())) == 1),
        float('nan'),
    )

    if args.output is None:
        for k, v in res.items():
            print(f'{k}: {v}')
        return

    _, ext = os.path.splitext(args.output)
    if ext == '.json':
        data = json.dumps(res)
    elif ext == '.tsv':
        data = '\n'.join([f'{k}\t{v}' for k, v in res.items()])
    else:
        print(f'Error: unknown output format extension: {ext}')
        sys.exit(1)
    with open(args.output, 'w') as f:
        f.write(data)
        f.write('\n')
def merge_with(fn, *dicts, **kwargs):
    """Like ``toolz.merge_with``, but refuse a call with no dictionaries.

    Keyword arguments (e.g. ``factory=``) are forwarded to
    ``toolz.merge_with``.

    Raises
    ------
    TypeError
        If no dictionaries are given.
    """
    if not dicts:
        raise TypeError("No input")
    return toolz.merge_with(fn, *dicts, **kwargs)
def concat_dicts(delayeds, dim=1):
    """Join same-keyed tensors across the dicts in ``delayeds``.

    For each key, the collected tensors are concatenated with ``torch.cat``
    along ``dim``.
    """
    def cat_along_dim(tensors):
        return torch.cat(tensors, dim=dim)

    return merge_with(cat_along_dim, *delayeds)
def merge_with(func, d, *dicts, **kwargs):
    """Thin wrapper over ``toolz.merge_with`` that needs at least one dict.

    Because ``d`` is a required positional parameter, calling with no
    dictionaries raises ``TypeError`` from Python itself.
    """
    all_dicts = (d,) + dicts
    return toolz.merge_with(func, *all_dicts, **kwargs)
def main():
    """Write one PHYLIP distance matrix per orthology group.

    Reads orthologs and the ``<spA>_<spB>_values.tsv.gz`` files in
    ``args.values``, then merges the values and filters them with
    'intersection'. For every remaining group it writes
    ``phylip(species, scaled_L2norm(species, [group]))``. The output is
    either appended to the single file ``args.outdir`` (``--one-file``) or
    written to ``args.outdir/matrix_<i>.phylip``; in the second case the
    directory is created and must not already exist.
    """
    parser = cli_parser()
    args = parser.parse_args()
    if args.progress:
        try:
            from tqdm import tqdm
            args.verbose = False
        except ImportError:
            print('error: -p/--progress needs tqdm to be installed.')
            sys.exit(1)

    if not args.one_file:
        try:
            os.mkdir(args.outdir)
            if args.verbose:
                print(f'directory {args.outdir} created.')
        except FileExistsError:
            print('error: {} already exists.'.format(args.outdir))
            sys.exit(1)

    orthos, groups = read_orthos(args.orthos)
    if args.verbose:
        print('Read orthologs.')

    species = set()
    values = dict()
    # Raw string: '\S' and '\w' are regex escapes; dots escaped to match
    # literally.
    name_pattern = re.compile(r'(?:\S+/)*(\w+)_(\w+)_values\.tsv\.gz', re.IGNORECASE)
    if args.progress:
        print('Reading values files...')
        pbar = tqdm(total=len(args.values))
    for filename in args.values:
        m = name_pattern.match(filename)
        if m is None:
            print('Ignored {}'.format(filename))
            if args.progress:
                # Count skipped files too so the bar reaches its total.
                pbar.update(1)
            continue
        sp_left, sp_right = m.group(1), m.group(2)
        species.add(sp_left)
        species.add(sp_right)
        this_values = read_values(orthos, groups, sp_left, sp_right, filename)
        values = merge_with(merge, values, this_values)
        if args.verbose:
            print(f'Read {filename}')
        elif args.progress:
            pbar.update(1)

    # At this point, we don't need the group number anymore.
    # The filtered result must be kept: it used to be discarded, so the loop
    # below iterated the dict's keys (group numbers), not the value mappings.
    values = filter_values('intersection', species, values.values())

    if args.progress:
        pbar.close()
        print('Computing the distances...')
        pbar = tqdm(total=len(values))
    elif args.verbose:
        print('Computing the distances...')

    out = open(args.outdir, 'w') if args.one_file else None
    try:
        for i, v in enumerate(values):
            distances = scaled_L2norm(species, [v])
            matrix = phylip(species, distances)
            if out is not None:
                out.write(matrix)
                out.write('\n')
            else:
                filename = f'{args.outdir}/matrix_{i}.phylip'
                with open(filename, 'w') as f:
                    f.write(matrix)
                    f.write('\n')
                # Only report per-group files; in one-file mode 'filename'
                # would be a stale input name.
                if args.verbose:
                    print(f' Written {filename}')
            if args.progress:
                pbar.update(1)
    finally:
        # The single output file is closed even if a matrix fails to build.
        if out is not None:
            out.close()

    if args.progress:
        pbar.close()
        print()  # To have a pretty line in the console
def merge_with(func, *dicts, **kwargs):
    """Like ``toolz.merge_with``, but refuse a call with no dictionaries.

    Keyword arguments (e.g. ``factory=``) are forwarded to
    ``toolz.merge_with``.

    Raises
    ------
    TypeError
        If no dictionaries are given.
    """
    if len(dicts) == 0:
        raise TypeError("No input")
    return toolz.merge_with(func, *dicts, **kwargs)
def schedule_batch(_batch_dicts):
    """Merge a batch of dicts and hand the result to the background worker.

    The dicts are merged key-wise with ``custom_merge``. A non-empty result
    is passed to ``batch_update_async.delay`` (presumably a Celery task);
    an empty result sends nothing.

    NOTE: this function does not clear ``_batch_dicts``. Callers that reuse
    the container must reset it themselves.
    """
    _batch = merge_with(custom_merge, *_batch_dicts)
    if _batch:
        batch_update_async.delay(_batch)
def stack_dicts(seq):
    """Stack same-keyed tensors from the dicts in ``seq`` along a new dim 1.

    ``seq`` is passed to ``merge_with`` as one argument, so it is presumably
    an iterable of dicts (TODO confirm).
    """
    def stack_on_dim1(tensors):
        return torch.stack(tensors, dim=1)

    return merge_with(stack_on_dim1, seq)
def merge_with(fn, *dicts, **kwargs):
    """Like ``toolz.merge_with``, but refuse a call with no dictionaries.

    Keyword arguments are forwarded to ``toolz.merge_with``.

    Raises
    ------
    TypeError
        If no dictionaries are given (message ``"No input"``).
    """
    if not dicts:
        raise TypeError("No input")
    return toolz.merge_with(fn, *dicts, **kwargs)
return word.lower().rstrip(",.!)-*_?:;$'-\"").lstrip("-*'\"(_$'")

# NOTE(review): the `return` above is the tail of a `stem` definition whose
# header is outside this view.

# Word frequencies of one file, applied right to left: open the file, split
# each line into words, flatten, normalise each word with stem(), tally.
# map(stem) is called with a single argument, so `map` here must be a curried
# version (presumably toolz.curried) rather than the builtin; confirm.
wordcount = compose(frequencies, map(stem), concat, map(str.split), open)

# NOTE(review): the script below uses Python 2 `print` statements.
if __name__ == '__main__':
    start_time = datetime.datetime.now()

    # Filenames for thousands of books from which we'd like to count words
    # NOTE(review): range(1) yields only 'book_0.txt'.
    filenames = ['book_%d.txt' % i for i in range(1)]

    # Start with sequential map for development
    pmap = map

    # Advance to Multiprocessing map for heavy computation on single machine
    # from multiprocessing import Pool
    # p = Pool(8)
    # pmap = p.map

    # Finish with distributed parallel map for big data
    # from IPython.parallel import Client
    # p = Client()[:]
    # pmap = p.map_sync

    # Sum the per-file frequency dicts key-wise into one total.
    total = merge_with(sum, pmap(wordcount, filenames))
    print 'total:', total

    end_time = datetime.datetime.now()
    c = end_time - start_time
    # timedelta.seconds/.microseconds are components (days are not included).
    print '%d seconds %d microseconds' % (c.seconds, c.microseconds)
def _cross_validate(estimator: BaseEstimator, model_id: int, grid_search_context: Dict[str, Any]) -> Dict[str, Any]:
    """K-fold cross validate ``estimator`` on the training data in the context.

    For each ``KFold`` split the estimator is fitted on the fold's training
    rows. It is then scored with ``_evaluate_model`` on those rows
    ("cross_validation_training" prefix) and on the held-out rows
    ("cross_validation" prefix). Per-fold result dicts are merged key-wise.

    Returns
    -------
    dict
        For every result key ``k``: ``k + "_all"`` holds the list of
        per-fold values, and ``k`` holds their arithmetic mean.
    """
    n_splits = grid_search_context['cross_validation']
    X_train = grid_search_context['X_train']
    y_train = grid_search_context['y_train']
    fit_params = grid_search_context['fit_params']

    fold_results = []
    for train_idx, test_idx in KFold(n_splits=n_splits).split(X_train):
        logger.info(f"Training model {model_id} on cross validation training set.")
        fit_started = time()
        estimator.fit(X_train.iloc[train_idx], y_train.iloc[train_idx], **fit_params)
        fit_finished = time()
        fit_seconds = fit_finished - fit_started
        logger.info(
            f"Completed training model {model_id} on cross validation "
            f"training set. Took {fit_seconds:.3f} seconds.")

        logger.info(f"Evaluating model {model_id} on cross validation training set.")
        train_scores = _evaluate_model(
            estimator, X_train.iloc[train_idx], y_train.iloc[train_idx],
            grid_search_context, "cross_validation_training")
        train_secs = train_scores["cross_validation_training_total_prediction_time"]
        train_rows = train_scores["cross_validation_training_total_prediction_records"]
        logger.info(
            f"Completed evaluating model {model_id} on cross validation "
            f"training set. Took {train_secs:.3f} seconds for {train_rows} records.")

        logger.info(f"Evaluating model {model_id} on cross validation test set.")
        test_scores = _evaluate_model(
            estimator, X_train.iloc[test_idx], y_train.iloc[test_idx],
            grid_search_context, "cross_validation")
        test_secs = test_scores["cross_validation_total_prediction_time"]
        test_rows = test_scores["cross_validation_total_prediction_records"]
        logger.info(
            f"Completed evaluating model {model_id} on cross validation "
            f"test set. Took {test_secs:.3f} seconds for {test_rows} records.")

        fold_results.append({
            "cross_validation_training_time_total": fit_seconds,
            **train_scores,
            **test_scores,
        })

    # Merge the results: each key maps to the list of its per-fold values.
    per_key = merge_with(identity, *fold_results)
    # Per-fold lists first, then averages (averages win on any key clash).
    cv_results = {key + "_all": folds for key, folds in per_key.items()}
    cv_results.update({key: sum(folds) / len(folds) for key, folds in per_key.items()})

    logger.info(f"Cross validation for model {model_id} completed.")
    return cv_results
# NOTE(review): these first lines are the tail of a generator whose header
# (and the initialisation of qt / sl) is outside this view; presumably it is
# the `run(model, x, t0, n, dt)` called below. Confirm.
t = t0
for i in tqdm(range(n)):
    # Advance one step, feeding the forcings evaluated at the current time t.
    out = step_model(model.step, dt, x['layer_mass'], qt, sl, x['FQT'](t), x['FSL'](t), x['U'](t), x['V'](t), x['SST'](t), x['SOLIN'](t))
    # Carry the updated state into the next step.
    qt = out['qt']
    sl = out['sl']
    t += dt
    out['time'] = t
    yield out

dt = 30  # seconds
day = 86400
nt = 40 * day // dt  # number of steps covering 40 days
nout = int(3600 // dt)  # keep one step out of every 3600 s worth of steps
model_path = "./10_test_db/4.pkl"
data_path = "./all.1.zarr"
model = load_model(model_path)
x = load_data(data_path)
# Keep every nout-th step, then stack each output key across the kept steps
# with np.stack (new leading axis).
output = merge_with(np.stack, take_nth(nout, run(model, x, 0.0, nt, dt=dt)))
plt.contourf(output['time'], x['p'], output['qt'].squeeze().T)
# Reversed y-limits put larger x['p'] values at the bottom (presumably
# pressure, so the plot reads upward through the column; confirm units).
plt.ylim([1000, 10])
plt.colorbar()
plt.show()