def get_mode_connections(p1, t1, p2, t2, eval_task, config):
    """Measure mode connectivity between two saved models and persist it.

    Both checkpoints are loaded from ``config['exp_dir']`` and flattened, then
    handed to ``calculate_mode_connectivity`` together with the validation
    loader of ``eval_task``. The three returned values are saved and returned.

    Args:
        p1: Policy identifier of the first checkpoint.
        t1: Task identifier of the first checkpoint.
        p2: Policy identifier of the second checkpoint.
        t2: Task identifier of the second checkpoint.
        eval_task: Key into ``loaders['sequential']`` selecting the task whose
            ``'val'`` loader is used for evaluation.
        config: Mapping that must contain ``'exp_dir'``; it is also forwarded
            unchanged to ``calculate_mode_connectivity``.

    Returns:
        dict: ``{'loss': ..., 'acc': ..., 'ts': ...}`` as produced by
        ``calculate_mode_connectivity``.
    """
    exp_dir = config['exp_dir']

    # Load and flatten the two endpoints in order: (t1, p1) first, then (t2, p2).
    endpoint_a, endpoint_b = (
        flatten_params(load_task_model_by_policy(task, policy, exp_dir))
        for task, policy in ((t1, p1), (t2, p2))
    )

    val_loader = loaders['sequential'][eval_task]['val']
    loss, acc, ts = calculate_mode_connectivity(
        endpoint_a, endpoint_b, val_loader, config
    )

    out_path = '{}/mc_{}_{}_to_{}_{}_on_{}'.format(
        exp_dir, p1, t1, p2, t2, eval_task
    )
    result = dict(loss=loss, acc=acc, ts=ts)
    save_np_arrays(result, path=out_path)
    return result
def plot_loss_plane(w, eval_loader, path, w_labels, config):
    """Evaluate the loss on a 2-D plane through three weight vectors and plot it.

    The plane is anchored at ``w[0]``. Its first axis is the unit vector from
    ``w[0]`` toward ``w[2]``; its second axis is the part of ``w[1] - w[0]``
    orthogonal to the first axis, normalised to unit length. The loss is
    evaluated on a 15 x 15 grid that extends 20% beyond the unit square
    spanned by those two directions.

    Args:
        w: Indexable holding at least three flat weight vectors (``w[0]``,
            ``w[1]``, ``w[2]``). Presumably 1-D float NumPy arrays of equal
            length -- TODO confirm against callers. Every entry of ``w`` is
            projected onto the plane for the ``'coords'`` output.
        eval_loader: Data loader forwarded to ``eval_single_epoch``.
        path: Output location forwarded to ``save_np_arrays`` and
            ``plot_contour``.
        w_labels: Labels forwarded to ``plot_contour``.
        config: Mapping; only ``config['exp_dir']`` is read here.

    Returns:
        dict: ``'grid'`` is a ``(15, 15, 2)`` array of plane coordinates,
        ``'values'`` is the ``(15, 15)`` array of losses, and ``'coords'`` is
        the stacked ``get_xy`` projection of every vector in ``w``.

    Note:
        If ``w[2] == w[0]``, or the three points are collinear, one of the
        basis norms is zero and the normalisation below divides by zero.
    """
    origin = w[0]

    # Orthonormal basis (u, v) of the plane via one Gram-Schmidt step.
    u = w[2] - origin
    dx = np.linalg.norm(u)
    u /= dx
    v = w[1] - origin
    v -= np.dot(u, v) * u
    dy = np.linalg.norm(v)
    v /= dy

    m = load_task_model_by_policy(0, 'init', config['exp_dir'])
    m.eval()

    # np.stack requires a real sequence; passing a generator is rejected by
    # current NumPy releases, so build a list first.
    coords = np.stack([get_xy(p, origin, u, v) for p in w])

    G = 15
    margin = 0.2
    alphas = np.linspace(0.0 - margin, 1.0 + margin, G)
    betas = np.linspace(0.0 - margin, 1.0 + margin, G)

    tr_loss = np.zeros((G, G))
    grid = np.zeros((G, G, 2))

    for i, alpha in enumerate(alphas):
        x = alpha * dx
        # The offset along u is constant for the whole row; compute it once.
        row_base = origin + x * u
        for j, beta in enumerate(betas):
            y = beta * dy
            p = row_base + y * v
            m = assign_weights(m, p).to(DEVICE)
            tr_loss[i, j] = eval_single_epoch(m, eval_loader)['loss']
            grid[i, j] = [x, y]

    contour = {'grid': grid, 'values': tr_loss, 'coords': coords}
    save_np_arrays(contour, path=path)
    # NOTE: the dataset is hard-coded to 'mnist'; the original code had
    # config['dataset'] commented out here, so that behaviour is preserved.
    plot_contour(grid, tr_loss, coords, log_alpha=-5.0, N=7, path=path,
                 w_labels=w_labels, dataset='mnist')
    return contour
def plot_cka(p1, t1, p2, t2, eval_task, config):
    """Compute CKA scores between two saved models, save them and plot a heat map.

    Args:
        p1: Policy identifier of the first model; ``'mtl'`` selects the
            ``w^*`` axis label, anything else the ``\\hat{w}`` label.
        t1: Task identifier of the first model (also used as label subscript).
        p2: Policy identifier of the second model (same labelling rule).
        t2: Task identifier of the second model.
        eval_task: Key into ``loaders['sequential']`` selecting the task whose
            ``'val'`` loader feeds ``calculate_CKA``.
        config: Mapping; only ``config['exp_dir']`` is read.

    Returns:
        dict: ``{'scores': ..., 'keys': ...}`` as returned by
        ``calculate_CKA``.
    """

    def _axis_label(policy, task):
        # NOTE(review): the subscript placeholder is not wrapped in braces, so
        # a multi-digit task id may render with only its first digit
        # subscripted -- confirm against the plotting backend.
        template = r'$w^*_{}$' if policy == 'mtl' else r'$\hat{{w}}_{}$'
        return template.format(task)

    exp_dir = config['exp_dir']
    model_a = load_task_model_by_policy(t1, p1, exp_dir)
    model_b = load_task_model_by_policy(t2, p2, exp_dir)

    # NOTE(review): eval_task is not part of the file name, so runs that differ
    # only in eval_task write to the same path -- confirm this is intended.
    out_path = '{}/cka_on_{}_{}_vs_{}_{}'.format(exp_dir, p1, t1, p2, t2)

    val_loader = loaders['sequential'][eval_task]['val']
    scores, keys = calculate_CKA(model_a, model_b, val_loader, num_batches=50)

    result = dict(scores=scores, keys=keys)
    save_np_arrays(result, path=out_path)

    # The first model labels the y-axis, the second the x-axis.
    y_label = _axis_label(p1, t1)
    x_label = _axis_label(p2, t2)
    plot_heat_map(scores, keys, out_path, x_label, y_label)
    return result