def main(args): config = get_config(args.model) net = load_model_path(args.model, config) t_len = 200 with torch.no_grad(): net.reset() ins = torch.zeros((12)) for j in range(t_len):
def main(args):
    """Scatter-plot produced vs. desired intervals of a timing (RSG) model.

    Loads the model at ``args.model``, runs 150 test trials and, for every
    context index ``0 .. net.args.T - 1``, plots against ``t_y = set - ready``:

    * black markers with a context-coloured edge: ``go - set`` taken from the
      trial's ``rsg`` tuple (the stimulus-defined target);
    * context-coloured markers: ``t_prod - set``, where ``t_prod`` is the first
      index at which the network output is ``>= 1``.

    Args:
        args: parsed CLI namespace; only ``args.model`` is read here.
    """
    config = get_config(args.model, to_bunch=True)
    net = load_model_path(args.model, config)
    data, _ = test_model(net, config, n_tests=150)

    n_contexts = net.args.T
    produced = [[] for _ in range(n_contexts)]
    target = [[] for _ in range(n_contexts)]
    for context, _idx, trial, _x, _y, out, _loss in data:
        y_ready, y_set, y_go = trial.rsg
        # Flattened like the np.argmax call it replaces; presumably `out` has a
        # single output channel so flat index == time step -- TODO confirm.
        crossed = np.asarray(out >= 1).ravel()
        # np.argmax of an all-False array is 0, which would plot a bogus
        # negative interval; treat "never crossed" as "crossed at trial end".
        y_prod = np.argmax(crossed) if crossed.any() else len(crossed)
        t_y = y_set - y_ready
        produced[context].append((t_y, y_prod - y_set))
        target[context].append((t_y, y_go - y_set))

    for context in range(n_contexts):
        if not produced[context]:
            # No test trial was drawn for this context; indexing would fail.
            continue
        t_y, t_target = np.array(target[context]).T
        _, t_prod = np.array(produced[context]).T
        plt.scatter(t_y, t_target, marker='o', c='black', s=20, edgecolors=cols[context])
        plt.scatter(t_y, t_prod, color=cols[context], s=15)
    plt.xlabel('desired t_p')
    plt.ylabel('produced t_p')
    plt.show()
def main(args):
    """Run the task-specific PCA analysis on a trained model's states.

    Loads the model at ``args.model``, draws one batch of ``n_reps`` trials
    from ``args.dataset`` (falling back to ``config.dataset`` when no dataset
    was given), records the network states on that batch with ``get_states``
    and passes them to the PCA routine matching the batch's trial type.

    Args:
        args: parsed CLI namespace; reads ``args.model`` and ``args.dataset``.
            ``args.dataset`` is overwritten with the config's dataset when it
            is empty or None.

    Raises:
        ValueError: if the trial type has no PCA routine (checked before the
            states are computed).
    """
    config = get_config(args.model, to_bunch=True)
    net = load_model_path(args.model, config)
    # `not` covers both an empty value and None (len(None) would raise).
    if not args.dataset:
        args.dataset = config.dataset

    n_reps = 1000
    # don't show these contexts
    context_filter = []
    _, loader = create_loaders(
        args.dataset, config, split_test=False,
        test_size=n_reps, context_filter=context_filter,
    )
    x, _, trials = next(iter(loader))

    # Pick the analysis first so an unsupported type fails fast instead of
    # silently doing nothing after the (potentially costly) state rollout.
    first_trial = trials[0]
    if isinstance(first_trial, RSG):
        pca_fn = pca_rsg
    elif isinstance(first_trial, (DelayProAnti, MemoryProAnti)):
        pca_fn = pca_dmpa
    else:
        raise ValueError(f'no PCA routine for trial type {type(first_trial).__name__}')

    A = get_states(net, x)
    pca_fn(args, A, trials, n_reps)
def main(args):
    """Plot the change in ``W_f`` weights between consecutive checkpoints.

    ``args.dir`` must contain an entry whose name starts with ``checkpoints``
    (a folder of saved models) and one whose name starts with ``config`` (a
    JSON config). Checkpoints are loaded in natural filename order. For each
    consecutive pair, ``np.linalg.norm`` of the difference of their
    ``W_f.weight`` arrays is recorded. The series is plotted and saved to
    ``figures/weight_changes.png``.

    Args:
        args: parsed CLI namespace; only ``args.dir`` is read.

    Raises:
        FileNotFoundError: if ``args.dir`` lacks a ``checkpoints*`` or
            ``config*`` entry.
    """
    import re

    def _natural_key(name):
        # re.split with a capturing group alternates text / digit-run tokens,
        # so odd positions are always digits: compare those numerically so that
        # e.g. model_200 sorts before model_1000.
        return [int(tok) if i % 2 else tok
                for i, tok in enumerate(re.split(r'(\d+)', name))]

    ckpt_folder = config_path = None
    with os.scandir(args.dir) as entries:
        for entry in entries:
            if entry.name.startswith('checkpoints'):
                ckpt_folder = entry.path
            elif entry.name.startswith('config'):
                config_path = entry.path
    if ckpt_folder is None or config_path is None:
        raise FileNotFoundError(
            f"{args.dir!r} must contain a 'checkpoints*' entry and a 'config*' file")

    with open(config_path, 'r') as f:
        config = json.load(f)

    # os.scandir yields entries in arbitrary order; differences are only
    # meaningful between checkpoints in training order. Assumes filenames
    # encode that order numerically -- NOTE(review): confirm naming scheme.
    with os.scandir(ckpt_folder) as entries:
        ckpt_paths = sorted((e.path for e in entries),
                            key=lambda p: _natural_key(os.path.basename(p)))

    norms = []
    prev_wf = None
    for path in ckpt_paths:
        print(path)
        model = load_model_path(path, config)
        wf = model.W_f.weight.data.numpy()
        if prev_wf is not None:
            norms.append(np.linalg.norm(wf - prev_wf))
        prev_wf = wf

    plt.plot(norms)
    os.makedirs('figures', exist_ok=True)
    plt.savefig('figures/weight_changes.png')
config = get_config(args.model) if args.noise != 0: J = model['W_f.weight'] v = J.std() shp = J.shape model['W_f.weight'] += torch.normal(0, v * .5, shp) J = model['W_ro.weight'] v = J.std() shp = J.shape model['W_ro.weight'] += torch.normal(0, v * .5, shp) config = fill_undefined_args(args, config, overwrite_none=True) net = load_model_path(args.model, config=config) if args.test_all: _, loss2 = test_model(net, config) print('avg summed loss (all):', loss2) if not args.no_plot: data, loss = test_model(net, config, n_tests=6) print('avg summed loss (plotted):', loss) run_id = '/'.join(args.model.split('/')[-3:-1]) fig, ax = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(12, 7)) if 'goals' in config.dataset: p_fn = get_potential(config)
dt = dt[dt.mnoise == 6] intervals = [{}, {}, {}] for j, dset in enumerate(dsets): subset = dt[dt.dset == dset] for iterr in range(len(subset)): job_id = subset.iloc[iterr].slurm_id model_folder = os.path.join('..', 'logs', run_id, str(job_id)) model_path = os.path.join(model_folder, 'model_best.pth') config = get_config(model_path, ctype='model', to_bunch=True) config.m_noise = 0 config.dataset = dset_map[config.dataset] net = load_model_path(model_path, config=config) data, loss = test_model(net, config, n_tests=200, dset_base='../') dset = load_rb(os.path.join('..', config.dataset)) distr = {} for k in range(len(data)): dset_idx, x, _, z, _ = data[k] r, s, g = dset[dset_idx][2] t_first = torch.nonzero(z >= 1) if len(t_first) > 0: t_first = t_first[0, 0] else: t_first = len(x)
with open(args.model, 'rb') as f: model = torch.load(f) if args.noise != 0: J = model['W_f.weight'] v = J.std() shp = J.shape model['W_f.weight'] += torch.normal(0, v * .5, shp) J = model['W_ro.weight'] v = J.std() shp = J.shape model['W_ro.weight'] += torch.normal(0, v * .5, shp) net = load_model_path(args.model, params={'dset': args.dataset, 'out_act': args.out_act}) dset = load_rb(args.dataset) data = test_model(net, dset, n_tests=0) run_id = '/'.join(args.model.split('/')[-3:-1]) fig, ax = plt.subplots(3,4,sharex=True, sharey=True, figsize=(12,7)) for i, ax in enumerate(fig.axes): ix, x, y, z, loss = data[i] xr = np.arange(len(x)) ax.axvline(x=0, color='dimgray', alpha = 1) ax.axhline(y=0, color='dimgray', alpha = 1) ax.grid(True, which='major', lw=1, color='lightgray', alpha=0.4) ax.spines['top'].set_visible(False)