Example #1
    def __init__(self,
                 inp_dim,
                 rnn_dim,
                 out_dim,
                 basis=None,
                 train_mask=None,
                 activity_reg=0,
                 col_len=None,
                 decode_func=cos_sin_decode):
        self.col_len = col_len
        self.decode_func = decode_func

        if train_mask is None:
            train_mask = np.ones(inp_dim, dtype=bool)
        self.dec = nn.Linear(rnn_dim, out_dim, bias=True)
        self.rnn = nn.RNN(inp_dim, rnn_dim, 1, nonlinearity='relu')
        self.net = students.GenericRNN(self.rnn,
                                       students.GausId(out_dim),
                                       decoder=self.dec,
                                       z_dist=students.GausId(rnn_dim),
                                       beta=activity_reg)

        if basis is not None:
            with torch.no_grad():
                # n_out, n_in and N come from the enclosing scope; build the
                # input weights from the basis columns plus one uniform column
                inp_w = np.append(basis[:, n_out:n_out + n_in - 1],
                                  np.ones((N, 1)) / np.sqrt(N),
                                  axis=-1)
                self.net.rnn.weight_ih_l0.copy_(torch.tensor(inp_w).float())
                self.net.rnn.weight_ih_l0.requires_grad = False
        with torch.no_grad():
            # zero the gradient on the masked input columns so that, after the
            # identity initialization below, those weights stay fixed
            weight_mask = torch.ones_like(self.net.rnn.weight_ih_l0)
            weight_mask[:, train_mask] = 0
            ident_weights = torch.tensor(np.identity(rnn_dim),
                                         dtype=torch.float)
            self.net.rnn.weight_ih_l0[:, train_mask] = ident_weights
            # hooks should return a new tensor, not modify grad in place
            self.net.rnn.weight_ih_l0.register_hook(
                lambda grad: grad * weight_mask)
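
A minimal construction sketch for the snippet above. The enclosing class name (WrappedRNN here) is hypothetical, and the module imports (numpy as np, torch, torch.nn as nn, students) are assumptions, since the listing only shows the __init__ body. Note that with the default all-ones train_mask, the identity assignment requires inp_dim == rnn_dim:

# WrappedRNN is a hypothetical name for the class whose __init__ is shown above
model = WrappedRNN(inp_dim=16, rnn_dim=16, out_dim=2, activity_reg=0.1)
# the masked input columns start as the identity, and the gradient hook
# keeps them fixed during training
print(model.net.rnn.weight_ih_l0[:4, :4])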
Example #2
    def __init__(self, c=None, n=None, overlap=0, d=None, use_mse=False):
        """overlap is given as the log2 of the dot product on their +/-1 representation"""
        super(RandomDichotomies, self).__init__()

        if d is None:
            if c is None:
                raise ValueError('Must supply either (c,n), or d')
            if n > c:
                raise ValueError(
                    'Cannot have more dichotomies than conditions!!')

            if overlap == 0:
                # generate uncorrelated dichotomies, only works for powers of 2
                H = la.hadamard(c)[:, 1:]
                pos = np.nonzero(
                    H[:, np.random.choice(c - 1, n, replace=False)] > 0)
                self.positives = [pos[0][pos[1] == i] for i in range(n)]
            elif overlap == 1:
                prot = 2 * (np.random.permutation(c) >= (c / 2)) - 1
                pos = np.where(prot > 0)[0]
                neg = np.where(prot < 0)[0]
                idx = np.random.choice((c // 2)**2, n - 1, replace=False)
                # print(idx)
                swtch = np.stack((pos[idx % (c // 2)], neg[idx // (c // 2)])).T
                # print(swtch)
                ps = np.ones((n - 1, 1)) * prot
                ps[np.arange(n - 1), swtch[:, 0]] *= -1
                ps[np.arange(n - 1), swtch[:, 1]] *= -1
                pos = [np.nonzero(p > 0)[0] for p in ps]
                pos.append(np.nonzero(prot > 0)[0])
                self.positives = pos
            else:
                # fail loudly instead of leaving self.positives undefined
                raise NotImplementedError('only overlap 0 or 1 is implemented')
        else:
            self.positives = d
            n = len(self.positives)
            if c is None:
                c = 2 * len(self.positives[0])

        self.__name__ = 'RandomDichotomies_%d-%d-%d' % (c, n, overlap)
        self.num_var = n
        self.dim_output = n
        self.num_cond = c

        if use_mse:
            self.obs_distribution = students.GausId(n)
        else:
            self.obs_distribution = students.Bernoulli(n)
        self.link = None
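
A short usage sketch, assuming this module's imports (numpy as np, scipy.linalg as la, students). With overlap=0 the dichotomies come from columns of a Hadamard matrix, so c must be a size scipy.linalg.hadamard accepts (a power of 2):

task = RandomDichotomies(c=8, n=3, overlap=0)
for p in task.positives:
    print(p)  # each array lists the 4 of 8 conditions labelled positive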
Example #3
    def __init__(self, d, function_class, use_mse=False):
        """overlap is given as the log2 of the dot product on their +/-1 representation"""
        super(LogicalFunctions, self).__init__()

        self.__name__ = 'LogicalFunctions_%dbit-%d' % (len(d), function_class)
        self.num_var = 1
        self.dim_output = 1
        self.num_cond = 2**len(d)

        self.bits = d
        self.function_class = function_class

        self.positives = [
            np.nonzero(self(np.arange(self.num_cond)).squeeze().numpy())[0]
        ]
        # print(self(np.arange(self.num_cond)))
        # print(self.positives)

        if use_mse:
            self.obs_distribution = students.GausId(1)
        else:
            self.obs_distribution = students.Bernoulli(1)
        self.link = None
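
A usage sketch. It assumes d takes the same form as the positives lists elsewhere in this module (arrays of condition indices per bit), and that the class inherits a __call__ evaluating the selected Boolean function over condition indices; the constructor itself relies on that via self(...). The function_class value is an arbitrary example:

bits = StandardBinary(3).positives             # the 3 standard bit dichotomies (next example)
task = LogicalFunctions(d=bits, function_class=0)
print(task.num_cond, task.positives)           # 8 conditions; positives of the chosen function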
Example #4
    def __init__(self, n, q=None, use_mse=False):
        """overlap is given as the log2 of the dot product on their +/-1 representation"""
        super(StandardBinary, self).__init__()

        if q is None:
            q = n

        bits = np.nonzero(
            1 -
            np.mod(np.arange(2**n)[:, None] // (2**np.arange(n)[None, :]), 2))
        pos_conds = np.split(bits[0][np.argsort(bits[1])], n)[:q]

        self.positives = pos_conds
        self.__name__ = 'StandardBinary%d-%d' % (n, q)
        self.num_var = q
        self.dim_output = q
        self.num_cond = 2**n

        if use_mse:
            # q output variables, matching dim_output above
            self.obs_distribution = students.GausId(q)
        else:
            self.obs_distribution = students.Bernoulli(q)
        self.link = None
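
For concreteness, the bit-extraction line can be checked standalone; for n = 2, variable i is "positive" on the conditions whose i-th bit is 0:

import numpy as np
n = 2
bits = np.nonzero(
    1 - np.mod(np.arange(2**n)[:, None] // (2**np.arange(n)[None, :]), 2))
pos_conds = np.split(bits[0][np.argsort(bits[1])], n)
print(pos_conds)  # [array([0, 2]), array([0, 1])]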
Example #5
        # skip runs without a parsed index, or with index > 10
        if len(rg) == 0:
            continue
        init = int(rg[0])
        if init > 10:
            continue

        with open(SAVE_DIR + FOLDERS + met_files[j], 'rb') as fh:
            metrics = pickle.load(fh)
        with open(SAVE_DIR + FOLDERS + arg_files[j], 'rb') as fh:
            args = pickle.load(fh)

        net = students.MultiGLM(
            students.Feedforward([100, N, N], ['ReLU', 'ReLU']),
            students.Feedforward([N, N_out], [None]), students.GausId(N_out))

        net.load(SAVE_DIR + FOLDERS + f)

        # if metrics['test_perf'][-1,...].min() > maxmin:
        #     maxmin = metrics['test_perf'][-1,...].min()
        #     best_net = model
        #     this_arg = args

        for key, val in metrics.items():
            if len(val) == 1000:
                continue
            if key not in all_metrics:
                shp = (num, ) + np.squeeze(np.array(val)).shape
                all_metrics[key] = np.full(shp, np.nan)
            all_metrics[key][j, ...] = np.squeeze(val)
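
The lazily-allocated, NaN-padded dictionary at the end is the key trick: it lets runs of different shapes, or missing runs, coexist in one array per metric, so nan-aware reductions work over an incomplete sweep. A self-contained sketch:

import numpy as np

num = 3                                  # runs expected in the sweep
runs = [{'loss': [0.9, 0.5, 0.2]},       # run 0
        {'loss': [1.1, 0.6, 0.3]}]       # run 1 (run 2 never finished)
all_metrics = {}
for j, metrics in enumerate(runs):
    for key, val in metrics.items():
        if key not in all_metrics:
            shp = (num,) + np.squeeze(np.array(val)).shape
            all_metrics[key] = np.full(shp, np.nan)
        all_metrics[key][j, ...] = np.squeeze(val)
print(np.nanmean(all_metrics['loss'], axis=0))  # run 2 stays NaN and is ignored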
Example #6
                        for c,n in zip(cond_set,succ_counts)])

# argsort of an argsort gives each element's rank, i.e. the inverse
# permutation: this restores the condition-sorted samples to trial order
unscramble = np.argsort(np.argsort(succ_conds))
successor_idx = samps[unscramble]
targets = output_states[successor_idx,:]

# targets = output_state


#%%
N = 100

# net = students.Feedforward([inputs.shape[1],100,2],['ReLU',None])
net = students.MultiGLM(students.Feedforward([inputs.shape[1], N], ['ReLU']),
                        students.Feedforward([N, targets.shape[1]], [None]),
                        students.GausId(targets.shape[1]))
# net = students.MultiGLM(students.Feedforward([inputs.shape[1],N], ['ReLU']),
#                         students.Feedforward([N,targets.shape[1]], [None]),
#                         students.Bernoulli(targets.shape[1]))
# net = students.MultiGLM(students.Feedforward([inputs.shape[1],N], ['ReLU']),
#                         students.Feedforward([N, p], [None]),
#                         students.Categorical(p))

n_trn = int(0.5 * targets.shape[0])
trn = np.random.choice(targets.shape[0], n_trn, replace=False)
tst = np.random.choice(np.setdiff1d(range(targets.shape[0]), trn),
                       int(0.5 * n_trn), replace=False)

optimizer = optim.Adam(net.parameters(), lr=1e-4)
dset = torch.utils.data.TensorDataset(torch.tensor(inputs[trn, :]).float(),
                                      targets[trn, :].float())
dl = torch.utils.data.DataLoader(dset, batch_size=64, shuffle=True)
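
The training step itself is not shown. The loader above would typically feed a loop like the following sketch; the loss line is hypothetical, since MultiGLM's objective (a log-likelihood under the GausId observation model) is not visible in this snippet:

for epoch in range(1000):
    for x, y in dl:
        optimizer.zero_grad()
        loss = -net(x, y)   # hypothetical call: assumes net returns a log-likelihood
        loss.backward()
        optimizer.step()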
Example #7
    swap_prob = 0.15

    z_prior = None
    # z_prior = students.GausId(N)

    dec = nn.Linear(N, n_out, bias=True)
    with torch.no_grad():
        # pin the readout to the first n_out basis directions and freeze it
        dec.weight.copy_(torch.tensor(basis[:, :n_out].T).float())
        dec.weight.requires_grad = False

    rnn = nn.RNN(n_in + N, N, 1, nonlinearity='relu')

    # net = students.GenericRNN(rnn, students.GausId(outs.shape[0]), fix_decoder=True, decoder=dec, z_dist=z_prior)
    net = students.GenericRNN(rnn,
                              students.GausId(n_out),
                              decoder=dec,
                              z_dist=students.GausId(N),
                              beta=0)
    # net = students.GenericRNN(rnn, students.GausId(outs.shape[0]), fix_decoder=False)

    # with torch.no_grad():
    #     net.rnn.inp2hid.weight.copy_(torch.tensor(basis[:,:5]).float())
    #     net.rnn.inp2hid.weight.requires_grad = False
    #     net.rnn.inp2hid.bias.requires_grad = False
    with torch.no_grad():
        # input weights: the next n_in basis columns for the task inputs, plus
        # an identity block for the N extra (presumably fed-back) input units
        inp_w = np.append(basis[:, n_out:n_out + n_in], np.eye(N), axis=-1)
        net.rnn.weight_ih_l0.copy_(torch.tensor(inp_w).float())
        net.rnn.weight_ih_l0.requires_grad = False
        # net.rnn.bias_ih_l0.requires_grad = False
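
Since several weights are frozen here (dec.weight and weight_ih_l0), it is worth passing only the trainable parameters when the optimizer is built, so the frozen readout and input weights cannot drift. A sketch, assuming torch.optim is imported as optim:

trainable = (p for p in net.parameters() if p.requires_grad)
optimizer = optim.Adam(trainable, lr=1e-3)  # excludes dec.weight and weight_ih_l0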