def plot_mode_connections_for_minima(p1, t1, config, max_task=None):
    """Plot the 'seq' mode-connection losses from task ``t1`` to later tasks.

    For every ``t2`` in ``t1 + 1 .. max_task`` (inclusive) this calls
    ``get_mode_connections(p1, t1, 'seq', t2, t1, config)``, keeps the
    ``'loss'`` entry of each result as one curve, and hands all curves to
    ``plot_multi_interpolations``. The ``x`` argument of the plot is the
    ``'ts'`` entry of the last result fetched.

    Args:
        p1: Forwarded unchanged to ``get_mode_connections``.
        t1: Starting task index; also used in the curve labels and in the
            output path.
        config: Mapping that provides ``'exp_dir'`` (prefix of the output
            path) and, when ``max_task`` is None, ``'num_tasks'``.
        max_task: Last target task, inclusive. Defaults to
            ``config['num_tasks']``.

    Returns:
        None. The result is whatever ``plot_multi_interpolations`` does with
        the path ``'<exp_dir>/mc_on_<t1>_max_<max_task>'``.
    """
    if max_task is None:
        max_task = config['num_tasks']

    segments = []  # stays empty if there is no task after t1
    seq_cons = []
    seq_labels = []
    for t2 in range(t1 + 1, max_task + 1):
        seq_con = get_mode_connections(p1, t1, 'seq', t2, t1, config)
        segments = seq_con['ts']
        seq_cons.append(seq_con['loss'])
        # Both subscripts are wrapped in braces: in TeX syntax an unbraced
        # subscript only binds the next token, so a multi-digit task index
        # would otherwise be split.
        seq_labels.append(
            r"$\hat{{w}}_{{{}}} \rightarrow \hat{{w}}_{{{}}}$".format(t1, t2)
        )

    save_path = '{}/mc_on_{}_max_{}'.format(config['exp_dir'], t1, max_task)
    plot_multi_interpolations(x=segments, ys=seq_cons, y_labels=seq_labels,
                              path=save_path)
# Example no. 2
# 0
def get_custom_mode_connections_for_minima(p1, t1, config,
                                           target_tasks=(5, 10, 15, 20)):
    """Plot the 'mtl' mode-connection losses from ``t1`` to selected tasks.

    For every ``t2`` in ``target_tasks`` that is not smaller than ``t1`` this
    calls ``get_mode_connections(p1, t1, 'mtl', t2, t1, config)``, keeps the
    ``'loss'`` entry of each result as one curve, and hands all curves to
    ``plot_multi_interpolations``. The ``x`` argument of the plot is the
    ``'ts'`` entry of the last result fetched.

    Args:
        p1: Forwarded unchanged to ``get_mode_connections``.
        t1: Starting task index; also used in the curve labels and in the
            output path.
        config: Mapping that provides ``'exp_dir'`` (prefix of the output
            path).
        target_tasks: Task indices to connect to. Defaults to the values
            that were previously hard-coded, ``(5, 10, 15, 20)``.

    Returns:
        None. The result is whatever ``plot_multi_interpolations`` does with
        the path ``'<exp_dir>/mc_on_<t1>_custom'``.
    """
    segments = []  # stays empty if every target task is smaller than t1
    mtl_cons = []
    mtl_labels = []
    for t2 in target_tasks:
        if t2 < t1:
            continue
        mtl_con = get_mode_connections(p1, t1, 'mtl', t2, t1, config)
        segments = mtl_con['ts']
        mtl_cons.append(mtl_con['loss'])
        # Subscripts are wrapped in braces: in TeX syntax an unbraced
        # subscript only binds the next token, so multi-digit task indices
        # such as 10 would otherwise be split.
        mtl_labels.append(
            r"$\hat{{w}}_{{{}}} \rightarrow w^*_{{{}}}$".format(t1, t2)
        )

    save_path = '{}/mc_on_{}_custom'.format(config['exp_dir'], t1)
    plot_multi_interpolations(x=segments,
                              ys=mtl_cons,
                              y_labels=mtl_labels,
                              path=save_path)