def _extrapolated_loss(loss, validation_stats, depth, hist_length, decay=0.8):
    """Return an optimistic estimate of a trial's loss if trained up to ``depth``.

    A straight line is fitted (``linregress``) to the last ``hist_length``
    validation points. Starting from the fitted value at the last observed x,
    the slope is extrapolated over the remaining ``depth - x_last`` steps,
    damped geometrically by ``decay``. The estimate is averaged with the
    observed ``loss`` and capped at ``loss``, so it never looks worse than
    what was measured.

    Args:
        loss: the trial's reported loss.
        validation_stats: sequence of 4-tuples whose first two fields are used
            as (x, validation loss) -- presumably (epoch, val_loss); confirm.
        depth: target x value (training length) of the next level.
        hist_length: how many trailing validation points the fit uses.
        decay: per-step damping factor of the extrapolated slope.

    Returns:
        ``min(0.5 * (loss + fitted_last + damped_gain), loss)``. Returns
        ``loss`` unchanged when there are fewer than two distinct x values to
        fit, because ``linregress`` cannot fit a line in that case.
    """
    if not validation_stats:
        return loss
    v_x, vloss, _, _ = zip(*validation_stats)
    xs, ys = v_x[-hist_length:], vloss[-hist_length:]
    if len(set(xs)) < 2:
        # linregress raises ValueError when all x values are identical
        # (including the single-point case); fall back to the observed loss.
        return loss
    slope, intercept, _, _, _ = linregress(xs, ys)
    # slope * (1 - decay**n) / (1 - decay): slope accumulated over n remaining
    # steps, each step's gain shrinking by a factor of `decay`.
    damped_gain = slope * (1 - decay ** (depth - v_x[-1])) / (1 - decay)
    return min(0.5 * (loss + intercept + slope * v_x[-1] + damped_gain), loss)


def _retag_trial_doc(doc, exp_key, tid):
    """Return a deep copy of trial ``doc`` moved to experiment ``exp_key`` as ``tid``."""
    cpy = copy.deepcopy(doc)
    cpy['exp_key'] = exp_key
    cpy['tid'] = tid
    cpy['misc']['tid'] = tid
    # hyperopt pairs misc['idxs'][k] with misc['vals'][k]; an empty idxs entry
    # marks a hyperparameter inactive in this trial (conditional space) and
    # must stay empty, otherwise idxs/vals fall out of alignment.
    cpy['misc']['idxs'] = {
        k: ([tid] if v else []) for k, v in cpy['misc']['idxs'].items()
    }
    # Drop the source document's Mongo id, presumably so the insert creates a
    # new document instead of colliding with the original's `_id`.
    del cpy['_id']
    return cpy


def configure_next_level(lvl: int, depth: int, budget: int = 50,
                         mongo_url: str = 'mongo://localhost:1234/covid/jobs') -> list:
    """Promote the most promising trials of level ``lvl - 1`` into level ``lvl``.

    Trials of experiment ``covid-{lvl-1}`` are ranked by an extrapolated loss
    (see ``_extrapolated_loss``). Trials whose loss is ``None`` rank last and
    are never promoted or copied.

    Best-ranked trials are promoted in order. Each promoted trial charges
    ``depth - epochs_completed`` against ``budget``. Promotion stops right
    after the trial that drives the budget to ``<= 0``, so at least one trial
    is promoted if any has a loss. Every remaining trial is deep-copied into
    experiment ``covid-{lvl}`` with a freshly allocated tid, as a finished
    trial that is not worth exploring further.

    Args:
        lvl: level being configured; reads ``covid-{lvl-1}``, writes ``covid-{lvl}``.
        depth: training length (in the units of ``validation_stats`` x /
            ``training_loss_hist`` epochs) targeted by level ``lvl``.
        budget: total extra training that promoted trials may consume.
        mongo_url: MongoTrials connection string for the jobs collection.

    Returns:
        A list of ``(None, {'status': 'new'}, misc)`` tuples, one per promoted
        trial, where ``misc`` is a deep copy of the source trial's ``misc``.
        Presumably the caller turns these into new jobs -- confirm at call site.

    Side effects:
        Inserts the non-promoted trial docs into the ``covid-{lvl}`` experiment.
    """
    new_exp_key = f'covid-{lvl}'
    src_trials = MongoTrials(mongo_url, exp_key=f'covid-{lvl-1}')
    all_trials = MongoTrials(mongo_url)
    dest_trials = MongoTrials(mongo_url, exp_key=new_exp_key)

    # Snapshot once: losses() builds a fresh list over all trials per call.
    src_docs = src_trials.trials
    losses = src_trials.losses()

    hist_length = {2: 3, 3: 5, 5: 8}.get(depth, 10)
    forward_losses = [
        np.inf if loss is None else _extrapolated_loss(
            loss, trial['result']['validation_stats'], depth, hist_length)
        for trial, loss in zip(src_docs, losses)
    ]
    ranking = np.argsort(forward_losses).tolist()

    # Promote best-first until the budget is spent.
    result_docs = []
    n_consumed = 0
    for idx in ranking:
        n_consumed += 1
        if losses[idx] is None:
            continue
        epochs_completed = src_docs[idx]['result'].get(
            'training_loss_hist', [(0, np.inf)])[-1][0]
        misc = copy.deepcopy(src_docs[idx]['misc'])
        result_docs.append((None, {'status': 'new'}, misc))
        budget -= (depth - epochs_completed)
        if budget <= 0:
            break

    # Copy in the ones that aren't worth exploring further (worst-first, as
    # before, so tids are handed out in the same order).
    last_tid = max(all_trials.tids, default=0)
    available_tids = []
    for idx in reversed(ranking[n_consumed:]):
        if losses[idx] is None:
            continue
        if not available_tids:
            # NOTE(review): the argument is the last used tid here; some
            # hyperopt versions treat new_trial_ids' argument as a *count* of
            # ids to reserve -- verify against the installed MongoTrials.
            available_tids = list(dest_trials.new_trial_ids(last_tid))
        tid = available_tids.pop(0)
        last_tid = tid
        dest_trials.insert_trial_doc(_retag_trial_doc(src_docs[idx], new_exp_key, tid))
    return result_docs