Exemplo n.º 1
0
def _forecast_loss(trial, loss, depth, hist_length):
    """Extrapolate a trial's validation loss out to ``depth``.

    A line is fitted through the last ``hist_length`` validation points.
    The projection starts from the fitted value at the last recorded x and
    adds the slope summed as a geometric series with ratio 0.8 over the
    remaining ``depth - last_x`` steps.  That projection is averaged with
    the current ``loss`` and then capped at ``loss``, so the forecast is
    never worse than what the trial has already achieved.
    """
    # Each validation_stats entry is unpacked as a 4-tuple; only the first
    # two columns (x position and validation loss) are used here.
    v_x, vloss, _, _ = zip(*trial['result']['validation_stats'])

    slope, intercept, _, _, _ = linregress(v_x[-hist_length:],
                                           vloss[-hist_length:])
    last_x = v_x[-1]
    projected = 0.5 * (loss + intercept + slope * last_x + slope *
                       (1 - 0.8**(depth - last_x)) / (1 - 0.8))
    return min(projected, loss)


def configure_next_level(lvl: int, depth: int, budget: int = 50) -> list:
    """Split the trials of level ``lvl - 1`` into "continue" and "carry over".

    Trials stored under ``covid-{lvl-1}`` are ranked by their forecast loss
    (see ``_forecast_loss``); trials without a loss are ignored.

    * The best-ranked trials are returned as new trial documents until the
      epoch budget is used up.  Each one costs ``depth - epochs_completed``.
      At least one trial is always selected, because the budget is checked
      only after a trial has been added.
    * Every remaining trial is copied unchanged, apart from its ids and
      ``exp_key``, into the ``covid-{lvl}`` experiment under a fresh tid.

    Args:
        lvl: Level being configured; the source experiment is ``lvl - 1``.
        depth: Target number of epochs for this level.
        budget: Total number of additional epochs that may be scheduled.

    Returns:
        A list of ``(spec, result, misc)`` tuples with ``spec`` set to
        ``None``, ``result`` set to ``{'status': 'new'}`` and ``misc`` a
        deep copy of the source trial's ``misc``.
    """
    new_exp_key = f'covid-{lvl}'

    src_trials = MongoTrials('mongo://localhost:1234/covid/jobs',
                             exp_key=f'covid-{lvl-1}')
    all_trials = MongoTrials('mongo://localhost:1234/covid/jobs')
    dest_trials = MongoTrials('mongo://localhost:1234/covid/jobs',
                              exp_key=new_exp_key)

    hist_length = {2: 3, 3: 5, 5: 8}.get(depth, 10)

    # Fetch these once: they do not change while this function runs, and
    # recomputing losses() per iteration made the loops quadratic.
    src_docs = src_trials.trials
    losses = src_trials.losses()

    forward_losses = [
        None if loss is None
        else _forecast_loss(trial, loss, depth, hist_length)
        for trial, loss in zip(src_docs, losses)
    ]

    # Best forecast first.  Trials without a loss get an infinite sort key
    # and are then dropped, since neither phase below uses them.
    sort_keys = [np.inf if fl is None else fl for fl in forward_losses]
    ranked = [idx for idx in np.argsort(sort_keys)
              if losses[idx] is not None]

    last_tid = max(all_trials.tids, default=0)

    # Phase 1: schedule the most promising trials until the budget runs out.
    result_docs = []
    promoted = 0
    for idx in ranked:
        promoted += 1
        # First field of the last training_loss_hist entry; 0 when the
        # trial has no such history.
        epochs_completed = src_docs[idx]['result'].get(
            'training_loss_hist', [(0, np.inf)])[-1][0]

        misc = copy.deepcopy(src_docs[idx]['misc'])
        result_docs.append((None, {'status': 'new'}, misc))

        budget -= depth - epochs_completed
        if budget <= 0:
            break

    # Phase 2: copy in the ones that aren't worth exploring further,
    # worst forecast first.
    available_tids = []
    for idx in reversed(ranked[promoted:]):
        if not available_tids:
            # NOTE(review): last_tid is passed as the argument here; confirm
            # that new_trial_ids expects this and not a count of ids.
            available_tids = dest_trials.new_trial_ids(last_tid)

        tid = available_tids.pop(0)
        last_tid = tid

        cpy = copy.deepcopy(src_docs[idx])
        cpy['exp_key'] = new_exp_key
        cpy['tid'] = tid
        cpy['misc']['tid'] = tid
        cpy['misc']['idxs'] = {k: [tid] for k in cpy['misc']['idxs']}

        # Drop the source document's _id before inserting the copy.
        del cpy['_id']

        dest_trials.insert_trial_doc(cpy)

    return result_docs