Ejemplo n.º 1
0
def main(args):
    # Load the model config and the trained network from the checkpoint
    # path given on the command line.
    config = get_config(args.model)
    net = load_model_path(args.model, config)

    # number of timesteps to simulate
    t_len = 200
    
    with torch.no_grad():
        net.reset()

        # NOTE(review): a 12-dimensional zero input vector — presumably
        # matches the network's input size; confirm against the config.
        ins = torch.zeros((12))

        # NOTE(review): the loop body is missing — this fragment appears
        # to be truncated at the source; the simulation loop is incomplete.
        for j in range(t_len):
Ejemplo n.º 2
0
def main(args):
    """Scatter-plot desired vs. produced (and measured) t_p per context.

    Runs the loaded model on RSG trials and plots, for each context, the
    produced interval (colored) and the measured/ground-truth interval
    (black with colored edge) against the desired interval.
    """
    config = get_config(args.model, to_bunch=True)
    net = load_model_path(args.model, config)

    data, loss = test_model(net, config, n_tests=150)

    n_contexts = net.args.T
    # (t_desired, t_produced) and (t_desired, t_measured) pairs, per context
    prod_pairs = [[] for _ in range(n_contexts)]
    stim_pairs = [[] for _ in range(n_contexts)]

    for context, _idx, trial, _x, _y, out, _loss in data:
        t_ready, t_set, t_go = trial.rsg
        # first timestep at which the output crosses the threshold of 1
        t_prod = np.argmax(out >= 1)
        desired = t_set - t_ready
        prod_pairs[context].append((desired, t_prod - t_set))
        stim_pairs[context].append((desired, t_go - t_set))

    # transpose each context's pair list into ([desired...], [interval...])
    prod_pairs = [list(zip(*np.array(pairs))) for pairs in prod_pairs]
    stim_pairs = [list(zip(*np.array(pairs))) for pairs in stim_pairs]

    for ctx in range(n_contexts):
        plt.scatter(stim_pairs[ctx][0],
                    stim_pairs[ctx][1],
                    marker='o',
                    c='black',
                    s=20,
                    edgecolors=cols[ctx])
        plt.scatter(prod_pairs[ctx][0], prod_pairs[ctx][1], color=cols[ctx], s=15)

    plt.xlabel('desired t_p')
    plt.ylabel('produced t_p')

    plt.show()
Ejemplo n.º 3
0
def main(args):
    """Run a PCA visualization of network states, dispatched by task type."""
    config = get_config(args.model, to_bunch=True)
    net = load_model_path(args.model, config)

    # default to the dataset the model was trained/configured with
    if len(args.dataset) == 0:
        args.dataset = config.dataset

    n_reps = 1000
    # contexts listed here are excluded from the plot
    context_filter = []

    _, loader = create_loaders(
        args.dataset,
        config,
        split_test=False,
        test_size=n_reps,
        context_filter=context_filter)
    x, y, trials = next(iter(loader))
    A = get_states(net, x)

    # dispatch on the concrete trial class
    task_type = type(trials[0])
    if task_type == RSG:
        pca_rsg(args, A, trials, n_reps)
    elif task_type in (DelayProAnti, MemoryProAnti):
        pca_dmpa(args, A, trials, n_reps)
Ejemplo n.º 4
0
def main(args):
    """Plot the L2 norm of W_f weight changes between consecutive checkpoints.

    Expects args.dir to contain a 'checkpoints*' directory and a 'config*'
    JSON file. Saves the resulting plot to figures/weight_changes.png.
    """
    ckpt_folder = None
    config_path = None
    for entry in os.scandir(args.dir):
        base = os.path.basename(entry)
        if base.startswith('checkpoints'):
            ckpt_folder = entry
        elif base.startswith('config'):
            config_path = entry

    # fail loudly instead of with a NameError if the expected entries are missing
    if ckpt_folder is None or config_path is None:
        raise FileNotFoundError(
            'expected a checkpoints* directory and a config* file in '
            + str(args.dir))

    with open(config_path, 'r') as f:
        config = json.load(f)

    # os.scandir yields entries in arbitrary OS-dependent order; sort by path
    # so "consecutive" checkpoints are compared in a stable order.
    # NOTE(review): assumes checkpoint filenames sort chronologically — confirm.
    models = sorted(os.scandir(ckpt_folder), key=lambda e: e.path)

    norms = []
    last_Wf = None

    for s in models:
        print(s.path)
        model = load_model_path(s.path, config)
        Wf = model.W_f.weight.data.numpy()

        # norm of the difference against the previous checkpoint's W_f
        if last_Wf is not None:
            norms.append(np.linalg.norm(Wf - last_Wf))

        last_Wf = Wf

    plt.plot(norms)
    plt.savefig('figures/weight_changes.png')
Ejemplo n.º 5
0
# NOTE(review): `model` is used below but never defined in this fragment —
# presumably the checkpoint's state dict was loaded earlier (e.g. via
# torch.load(args.model)); this fragment appears truncated at the source.
config = get_config(args.model)

# Optionally perturb the trained weights: add Gaussian noise with std equal
# to half of each weight matrix's empirical std.
if args.noise != 0:
    J = model['W_f.weight']
    v = J.std()
    shp = J.shape
    model['W_f.weight'] += torch.normal(0, v * .5, shp)

    J = model['W_ro.weight']
    v = J.std()
    shp = J.shape
    model['W_ro.weight'] += torch.normal(0, v * .5, shp)

# let command-line arguments fill in / override unset config values
config = fill_undefined_args(args, config, overwrite_none=True)

net = load_model_path(args.model, config=config)

if args.test_all:
    _, loss2 = test_model(net, config)
    print('avg summed loss (all):', loss2)

if not args.no_plot:
    data, loss = test_model(net, config, n_tests=6)
    print('avg summed loss (plotted):', loss)

    # e.g. .../<run>/<job>/model.pth -> "<run>/<job>"
    run_id = '/'.join(args.model.split('/')[-3:-1])

    fig, ax = plt.subplots(2, 3, sharex=True, sharey=True, figsize=(12, 7))

    if 'goals' in config.dataset:
        p_fn = get_potential(config)
Ejemplo n.º 6
0
# keep only rows with mnoise == 6
# NOTE(review): `dt`, `dsets`, `run_id`, `dset_map` are defined outside this
# fragment; `dt` appears to be a pandas DataFrame of experiment runs.
dt = dt[dt.mnoise == 6]

# one interval dict per dataset slot
intervals = [{}, {}, {}]
for j, dset in enumerate(dsets):
    subset = dt[dt.dset == dset]

    for iterr in range(len(subset)):

        job_id = subset.iloc[iterr].slurm_id

        # locate the best checkpoint for this slurm job
        model_folder = os.path.join('..', 'logs', run_id, str(job_id))
        model_path = os.path.join(model_folder, 'model_best.pth')
        config = get_config(model_path, ctype='model', to_bunch=True)
        config.m_noise = 0
        # remap to the evaluation dataset
        config.dataset = dset_map[config.dataset]
        net = load_model_path(model_path, config=config)

        data, loss = test_model(net, config, n_tests=200, dset_base='../')
        # NOTE(review): this reassignment shadows the outer loop variable
        # `dset` (the dataset name) with the loaded dataset contents.
        dset = load_rb(os.path.join('..', config.dataset))

        distr = {}

        for k in range(len(data)):
            dset_idx, x, _, z, _ = data[k]
            # NOTE(review): presumably ready / set / go times for this
            # trial — confirm against the dataset format.
            r, s, g = dset[dset_idx][2]

            # first timestep where output z crosses 1; fall back to the
            # trial length if it never does
            t_first = torch.nonzero(z >= 1)
            if len(t_first) > 0:
                t_first = t_first[0, 0]
            else:
                t_first = len(x)
Ejemplo n.º 7
0
# load the raw checkpoint (state dict) from disk
with open(args.model, 'rb') as f:
    model = torch.load(f)

# Optionally perturb the trained weights: add Gaussian noise with std equal
# to half of each weight matrix's empirical std.
if args.noise != 0:
    J = model['W_f.weight']
    v = J.std()
    shp = J.shape
    model['W_f.weight'] += torch.normal(0, v * .5, shp)

    J = model['W_ro.weight']
    v = J.std()
    shp = J.shape
    model['W_ro.weight'] += torch.normal(0, v * .5, shp)

net = load_model_path(args.model, params={'dset': args.dataset, 'out_act': args.out_act})
dset = load_rb(args.dataset)
# NOTE(review): n_tests=0 presumably means "test on the whole dataset" —
# confirm against test_model's implementation.
data = test_model(net, dset, n_tests=0)

# e.g. .../<run>/<job>/model.pth -> "<run>/<job>"
run_id = '/'.join(args.model.split('/')[-3:-1])

fig, ax = plt.subplots(3,4,sharex=True, sharey=True, figsize=(12,7))

# one trial per subplot; note `ax` here shadows the axes array created above
for i, ax in enumerate(fig.axes):
    ix, x, y, z, loss = data[i]
    xr = np.arange(len(x))

    # axis lines through the origin plus a light grid
    ax.axvline(x=0, color='dimgray', alpha = 1)
    ax.axhline(y=0, color='dimgray', alpha = 1)
    ax.grid(True, which='major', lw=1, color='lightgray', alpha=0.4)
    ax.spines['top'].set_visible(False)